torchebm.interpolants
Stochastic interpolants for generative modeling.
Interpolants define conditional probability paths between source (noise) and target (data) distributions, parameterized by schedules α(t) and σ(t).
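A minimal sketch of how a pair of schedules turns into a path, written for a single scalar noise/data pair in plain Python. It is not torchebm's API. It assumes the convention \(x_t = \alpha(t)\, x_1 + \sigma(t)\, x_0\), with \(x_0\) noise and \(x_1\) data. It also returns the conditional velocity \(\dot{\alpha}(t)\, x_1 + \dot{\sigma}(t)\, x_0\), which is obtained by differentiating \(x_t\) in \(t\). The helper `interpolate_scalar` and the two linear schedule functions are hypothetical.

```python
from typing import Callable, Tuple

# A schedule maps t to (value, derivative), matching the shape of the
# (alpha(t), alpha_dot(t)) and (sigma(t), sigma_dot(t)) tuples documented below.
Schedule = Callable[[float], Tuple[float, float]]


def interpolate_scalar(
    x0: float, x1: float, t: float, alpha: Schedule, sigma: Schedule
) -> Tuple[float, float]:
    """Hypothetical helper returning (x_t, u_t) for noise x0 and data x1."""
    a, da = alpha(t)
    s, ds = sigma(t)
    x_t = a * x1 + s * x0    # point on the conditional path
    u_t = da * x1 + ds * x0  # its time derivative (conditional velocity)
    return x_t, u_t


def linear_alpha(t: float) -> Tuple[float, float]:
    return t, 1.0


def linear_sigma(t: float) -> Tuple[float, float]:
    return 1.0 - t, -1.0


print(interpolate_scalar(x0=-1.0, x1=3.0, t=0.25,
                         alpha=linear_alpha, sigma=linear_sigma))
```

To swap in a different interpolant, pass a different pair of schedule functions. The path formula itself does not change.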
CosineInterpolant
Bases: BaseInterpolant
Cosine (geodesic variance preserving) interpolant.
Also known as the GVP interpolant. It uses trigonometric schedules so that the interpolated sample keeps unit variance along the entire path, provided the noise and data are independent and each has unit variance.
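The interpolation is defined as:
\[ x_t = \alpha(t)\, x_1 + \sigma(t)\, x_0, \qquad \alpha(t) = \sin(\pi t / 2), \quad \sigma(t) = \cos(\pi t / 2), \]
where \(x_0\) is a noise sample and \(x_1\) is a data sample.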
This satisfies \(\alpha(t)^2 + \sigma(t)^2 = 1\).
Example
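A scalar sketch of the cosine schedule in plain Python. It uses `math` in place of torch tensors, the function names are hypothetical, and it is not torchebm's API. It evaluates \(\alpha(t)\), \(\sigma(t)\) and their derivatives, then checks the identity \(\alpha(t)^2 + \sigma(t)^2 = 1\).

```python
import math
from typing import Tuple


def cosine_alpha(t: float) -> Tuple[float, float]:
    """alpha(t) = sin(pi t / 2) and its derivative (pi / 2) cos(pi t / 2)."""
    return math.sin(math.pi * t / 2), (math.pi / 2) * math.cos(math.pi * t / 2)


def cosine_sigma(t: float) -> Tuple[float, float]:
    """sigma(t) = cos(pi t / 2) and its derivative -(pi / 2) sin(pi t / 2)."""
    return math.cos(math.pi * t / 2), -(math.pi / 2) * math.sin(math.pi * t / 2)


for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    a, _ = cosine_alpha(t)
    s, _ = cosine_sigma(t)
    # Variance-preserving identity: alpha^2 + sigma^2 = 1 for every t.
    print(f"t={t:.2f}  alpha={a:.4f}  sigma={s:.4f}  sum_sq={a * a + s * s:.6f}")
```

As \(t\) goes from 0 to 1, \(\alpha\) rises from 0 to 1 and \(\sigma\) falls from 1 to 0, while the sum of squares stays fixed at 1.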
Source code in torchebm/interpolants/cosine.py
compute_alpha_t(t)
Compute \(\alpha(t) = \sin(\pi t / 2)\) and its derivative \(\dot{\alpha}(t) = (\pi/2)\cos(\pi t / 2)\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor. | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (α(t), α̇(t)). |
compute_d_alpha_alpha_ratio_t(t)
Compute \(\dot{\alpha}(t) / \alpha(t) = (\pi/2) \cot(\pi t / 2)\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor. | required |
Returns:
| Type | Description |
|---|---|
Tensor | The ratio, clamped for numerical stability because it diverges as \(t \to 0\). |
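
The ratio \((\pi/2)\cot(\pi t / 2)\) grows without bound as \(t \to 0\), so some form of clamping is needed. The sketch below shows one option: clamp \(t\) away from zero before evaluating the closed form. The `eps` value, the choice to clamp \(t\) rather than the output, and the function name are all assumptions. They are not torchebm's documented behaviour.

```python
import math


def cosine_d_alpha_alpha_ratio(t: float, eps: float = 1e-5) -> float:
    """(pi/2) * cot(pi t / 2), with t clamped to [eps, 1] (eps is hypothetical)."""
    t = min(max(t, eps), 1.0)
    return (math.pi / 2) / math.tan(math.pi * t / 2)


# Away from t = 0 the closed form agrees with alpha_dot / alpha.
t = 0.3
alpha = math.sin(math.pi * t / 2)
d_alpha = (math.pi / 2) * math.cos(math.pi * t / 2)
print(cosine_d_alpha_alpha_ratio(t), d_alpha / alpha)

# At t = 0 the unclamped ratio would divide by zero; the clamped one stays finite.
print(cosine_d_alpha_alpha_ratio(0.0))
```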
compute_sigma_t(t)
Compute \(\sigma(t) = \cos(\pi t / 2)\) and its derivative \(\dot{\sigma}(t) = -(\pi/2)\sin(\pi t / 2)\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor. | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (σ(t), σ̇(t)). |
LinearInterpolant
Bases: BaseInterpolant
Linear interpolant between noise and data distributions.
Also known as the optimal transport (OT) or rectified flow interpolant.
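The interpolation is defined as:
\[ x_t = \alpha(t)\, x_1 + \sigma(t)\, x_0 = t\, x_1 + (1 - t)\, x_0, \]
where \(x_0\) is a noise sample and \(x_1\) is a data sample,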
with \(\alpha(t) = t\) and \(\sigma(t) = 1 - t\).
Example
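A scalar sketch of the linear interpolant in plain Python. The function name is hypothetical and this is not the torchebm API. It assumes \(x_0\) denotes noise and \(x_1\) denotes data. Because \(\dot{\alpha}(t) = 1\) and \(\dot{\sigma}(t) = -1\), the conditional velocity \(\dot{\alpha}(t)\, x_1 + \dot{\sigma}(t)\, x_0 = x_1 - x_0\) is the same at every \(t\). In other words, each noise-data pair is joined by a straight line traversed at constant speed.

```python
from typing import Tuple


def linear_interpolate(x0: float, x1: float, t: float) -> Tuple[float, float]:
    """Return (x_t, u_t) for alpha(t) = t, sigma(t) = 1 - t."""
    x_t = t * x1 + (1.0 - t) * x0
    u_t = 1.0 * x1 + (-1.0) * x0  # alpha_dot * x1 + sigma_dot * x0 = x1 - x0
    return x_t, u_t


x0, x1 = 2.0, 6.0  # noise sample, data sample
for t in (0.0, 0.5, 1.0):
    # x_t moves from x0 to x1; the velocity stays x1 - x0 at every t.
    print(t, linear_interpolate(x0, x1, t))
```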
Source code in torchebm/interpolants/linear.py
compute_alpha_t(t)
Compute \(\alpha(t) = t\) and \(\dot{\alpha}(t) = 1\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor. | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (α(t), α̇(t)). |
compute_d_alpha_alpha_ratio_t(t)
Compute \(\dot{\alpha}(t) / \alpha(t) = 1/t\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor. | required |
Returns:
| Type | Description |
|---|---|
Tensor | The ratio \(1/t\), clamped for numerical stability because it diverges as \(t \to 0\). |
compute_sigma_t(t)
Compute \(\sigma(t) = 1 - t\) and \(\dot{\sigma}(t) = -1\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor. | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (σ(t), σ̇(t)). |
VariancePreservingInterpolant
Bases: BaseInterpolant
Variance-preserving (VP) interpolant with a linear \(\beta\) schedule.
Corresponds to the noise schedule used in DDPM and score-based diffusion models.
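The forward process uses a linear \(\beta\) schedule and satisfies the variance-preserving constraint \(\alpha(t)^2 + \sigma(t)^2 = 1\).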
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
sigma_min | float | Minimum noise rate \(\beta_{\min}\) of the linear \(\beta\) schedule. | 0.1 |
sigma_max | float | Maximum noise rate \(\beta_{\max}\) of the linear \(\beta\) schedule. | 20.0 |
Example
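A scalar sketch of a VP schedule in plain Python. This is not torchebm's implementation. It assumes the standard linear-\(\beta\) VP form, run in reversed time so that \(t = 1\) is data, with `sigma_min` and `sigma_max` acting as \(\beta_{\min}\) and \(\beta_{\max}\). Under that assumption, \(\alpha(t) = \exp\!\big(-\tfrac{1}{2}\int_0^{1-t} \beta(s)\,ds\big)\) with \(\beta(s) = \beta_{\min} + s(\beta_{\max} - \beta_{\min})\), and \(\sigma(t) = \sqrt{1 - \alpha(t)^2}\). The class `VPScheduleSketch` and its methods are hypothetical.

```python
import math
from typing import Tuple


class VPScheduleSketch:
    """Hypothetical scalar VP schedule with a linear beta schedule.

    Assumes beta(s) = beta_min + s * (beta_max - beta_min) with s = 1 - t,
    so log alpha(t) = -0.5 * integral_0^{1-t} beta(s) ds.
    """

    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def log_mean_coeff(self, t: float) -> float:
        s = 1.0 - t
        return -0.25 * s**2 * (self.beta_max - self.beta_min) - 0.5 * s * self.beta_min

    def d_log_mean_coeff(self, t: float) -> float:
        # d/dt of log_mean_coeff; equals beta(1 - t) / 2.
        s = 1.0 - t
        return 0.5 * s * (self.beta_max - self.beta_min) + 0.5 * self.beta_min

    def alpha(self, t: float) -> Tuple[float, float]:
        a = math.exp(self.log_mean_coeff(t))
        return a, a * self.d_log_mean_coeff(t)  # alpha_dot = alpha * d(log alpha)/dt

    def sigma(self, t: float) -> Tuple[float, float]:
        # From alpha^2 + sigma^2 = 1; the derivative is undefined where sigma = 0 (t = 1).
        a, da = self.alpha(t)
        s = math.sqrt(1.0 - a * a)
        return s, -a * da / s


vp = VPScheduleSketch()
for t in (0.0, 0.5, 0.99):
    a, _ = vp.alpha(t)
    s, _ = vp.sigma(t)
    print(f"t={t:.2f}  alpha={a:.4f}  sigma={s:.4f}")
```

The loop stops at \(t = 0.99\). At \(t = 1\), \(\sigma(t) = 0\) in this sketch, so \(\dot{\sigma}(t)\) cannot be computed there.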
Source code in torchebm/interpolants/variance_preserving.py
compute_alpha_t(t)
Compute \(\alpha(t)\) and its derivative for the VP schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor. | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (α(t), α̇(t)). |
compute_d_alpha_alpha_ratio_t(t)
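Compute \(\dot{\alpha}(t) / \alpha(t)\) directly as the time derivative of the log mean coefficient, using \(\dot{\alpha}(t) / \alpha(t) = \tfrac{d}{dt} \log \alpha(t)\).
This avoids dividing \(\dot{\alpha}(t)\) by \(\alpha(t)\) and is therefore more numerically stable.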
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor. | required |
Returns:
| Type | Description |
|---|---|
Tensor | The ratio, equal to the time derivative of the log mean coefficient (`d_log_mean_coeff`). |
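
Assume the standard linear-\(\beta\) VP form in reversed time: \(\beta(s) = \beta_{\min} + s(\beta_{\max} - \beta_{\min})\) with \(s = 1 - t\), where `sigma_min` and `sigma_max` play the roles of \(\beta_{\min}\) and \(\beta_{\max}\). Under that assumption, differentiating the log mean coefficient gives a closed form that never divides by \(\alpha(t)\). This matters because \(\alpha(t)\) is smallest at \(t = 0\) and shrinks rapidly as \(\beta_{\max}\) grows.

```latex
\begin{aligned}
\log \alpha(t) &= -\tfrac{1}{4}(1 - t)^2 (\beta_{\max} - \beta_{\min}) - \tfrac{1}{2}(1 - t)\,\beta_{\min}, \\
\frac{\dot{\alpha}(t)}{\alpha(t)} = \frac{d}{dt}\log \alpha(t)
  &= \tfrac{1}{2}(1 - t)(\beta_{\max} - \beta_{\min}) + \tfrac{1}{2}\beta_{\min}
   = \tfrac{1}{2}\,\beta(1 - t).
\end{aligned}
```

The result is affine in \(t\) and bounded on \([0, 1]\), between \(\tfrac{1}{2}\beta_{\min}\) and \(\tfrac{1}{2}\beta_{\max}\).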
compute_drift(x, t)
Compute the drift for the VP schedule using the \(\beta\) parameterization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Current state of shape (batch_size, ...). | required |
t | Tensor | Time values of shape (batch_size,). | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (drift_mean, drift_var). |
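
A scalar sketch of the drift, assuming \(\beta(t) = \beta_{\min} + (1 - t)(\beta_{\max} - \beta_{\min})\) with `sigma_min` and `sigma_max` as \(\beta_{\min}\) and \(\beta_{\max}\). It further assumes that `drift_mean` is the VP-SDE drift \(-\tfrac{1}{2}\beta\,x\) and that `drift_var` is \(\tfrac{1}{2}\beta\). That second quantity is half the squared diffusion coefficient of \(dx = -\tfrac{1}{2}\beta\,x\,d\tau + \sqrt{\beta}\,dW\) in diffusion time \(\tau = 1 - t\). The function name `vp_drift` is hypothetical and this is not torchebm's implementation.

```python
from typing import Tuple


def vp_drift(x: float, t: float, beta_min: float = 0.1,
             beta_max: float = 20.0) -> Tuple[float, float]:
    """Hypothetical scalar VP drift returning (drift_mean, drift_var).

    beta is evaluated at reversed time 1 - t, so t = 1 (data) uses beta_min.
    """
    beta_t = beta_min + (1.0 - t) * (beta_max - beta_min)
    drift_mean = -0.5 * beta_t * x  # linear pull of x towards zero
    drift_var = 0.5 * beta_t        # half the squared diffusion coefficient
    return drift_mean, drift_var


print(vp_drift(x=2.0, t=1.0))  # beta_t equals beta_min at the data end
print(vp_drift(x=2.0, t=0.0))  # beta_t equals beta_max at the noise end
```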
compute_sigma_t(t)
Compute \(\sigma(t)\) and its derivative for the VP schedule.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
t | Tensor | Time tensor. | required |
Returns:
| Type | Description |
|---|---|
Tuple[Tensor, Tensor] | Tuple of (σ(t), σ̇(t)). |
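
A variance-preserving schedule satisfies \(\alpha(t)^2 + \sigma(t)^2 = 1\). As a result, \(\sigma(t)\) and its derivative follow from \(\alpha(t)\) and \(\dot{\alpha}(t)\) alone. Take the nonnegative root for \(\sigma(t)\) and differentiate the constraint. The resulting expression holds wherever \(\sigma(t) \neq 0\):

```latex
\sigma(t) = \sqrt{1 - \alpha(t)^2}, \qquad
2\,\alpha(t)\,\dot{\alpha}(t) + 2\,\sigma(t)\,\dot{\sigma}(t) = 0
\;\Longrightarrow\;
\dot{\sigma}(t) = -\frac{\alpha(t)\,\dot{\alpha}(t)}{\sigma(t)}.
```

The same identity also recovers the cosine interpolant's derivative: \(-\sin(\pi t/2)\,(\pi/2)\cos(\pi t/2) / \cos(\pi t/2) = -(\pi/2)\sin(\pi t/2)\).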