torchebm.core.base_model
AckleyModel
Bases: BaseModel
Energy-based model for the Ackley function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
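a | float | Amplitude of the exponential term in the Ackley function. | 20.0 |
b | float | Decay rate applied to the root-mean-square of the input inside the exponential term. | 0.2 |
c | float | Frequency of the cosine term. | 2 * pi |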
Source code in torchebm/core/base_model.py
forward(x)
Computes the Ackley energy.
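This page does not state the Ackley formula. The defaults (a = 20, b = 0.2, c = 2π) match the standard Ackley benchmark, so the sketch below assumes that standard form, applied row-wise to a batch of shape (batch_size, n). The helper name `ackley_energy` is hypothetical and not part of the library; compare it with the library source before relying on it.

```python
import math

import torch


def ackley_energy(
    x: torch.Tensor, a: float = 20.0, b: float = 0.2, c: float = 2 * math.pi
) -> torch.Tensor:
    """Assumed standard Ackley function, one energy per row of x (batch_size, n)."""
    n = x.shape[-1]
    # First term: exponential of the b-scaled root-mean-square of the coordinates
    rms = torch.sqrt((x ** 2).sum(dim=-1) / n)
    # Second term: exponential of the mean of cos(c * x_i)
    cos_mean = torch.cos(c * x).sum(dim=-1) / n
    return -a * torch.exp(-b * rms) - torch.exp(cos_mean) + a + math.e


print(ackley_energy(torch.zeros(3, 2)))  # approximately zero: the global minimum is at the origin
print(ackley_energy(torch.ones(1, 2)))   # positive away from the origin
```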
BaseModel
Bases: DeviceMixin, Module, ABC
Abstract base class for energy-based models (EBMs).
This class provides a unified interface for defining EBMs. An EBM assigns each input an energy \(E(x)\) equal to the negative log of its unnormalized probability density, so that \(p(x) \propto \exp(-E(x))\). Both analytical models and trainable neural networks are supported.
Subclasses must implement the forward(x) method and can optionally override the gradient(x) method for analytical gradients.
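A minimal sketch of the subclassing contract described above. `SketchBaseModel` and `QuadraticEnergy` are hypothetical stand-ins that only mirror the described interface (an abstract forward(x), an autograd-based default gradient(x), and an optional analytical override). They omit the real class's DeviceMixin, dtype, and mixed-precision handling, so treat this as an illustration rather than the library's implementation.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class SketchBaseModel(nn.Module, ABC):
    """Hypothetical stand-in for BaseModel; not the library class."""

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one scalar energy per sample, shape (batch_size,)."""

    def gradient(self, x: torch.Tensor) -> torch.Tensor:
        # Default: differentiate the summed batch energy with autograd
        x = x.detach().requires_grad_(True)
        (grad,) = torch.autograd.grad(self.forward(x).sum(), x)
        return grad


class QuadraticEnergy(SketchBaseModel):
    """Hypothetical trainable EBM: E(x) = 0.5 * ||W x||^2 with learnable W."""

    def __init__(self, dim: int):
        super().__init__()
        self.W = nn.Parameter(torch.eye(dim) + 0.1 * torch.randn(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * (x @ self.W.T).pow(2).sum(dim=-1)

    def gradient(self, x: torch.Tensor) -> torch.Tensor:
        # Optional analytical override: the gradient of 0.5 * ||W x||^2 is W^T W x,
        # written for row vectors as (x W^T) W
        return (x @ self.W.T) @ self.W


torch.manual_seed(0)
model = QuadraticEnergy(dim=2)
x = torch.randn(4, 2)
print(model(x).shape)  # torch.Size([4])
print(torch.allclose(model.gradient(x), SketchBaseModel.gradient(model, x), atol=1e-5))  # True
```

Calling the base-class gradient explicitly is a quick way to check that an analytical override agrees with the autograd default.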
__call__(x, *args, **kwargs)
Alias for the forward method for standard PyTorch module usage.
__init__(dtype=torch.float32, use_mixed_precision=False, *args, **kwargs)
Initializes the model with the given dtype and mixed-precision setting.
forward(x) abstractmethod
Computes the scalar energy value for each input sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input tensor of shape (batch_size, *input_dims). | required |
Returns:
| Type | Description |
|---|---|
Tensor | Tensor of scalar energy values with shape (batch_size,). |
gradient(x)
Computes the gradient of the energy function with respect to the input, \(\nabla_x E(x)\).
This default implementation uses torch.autograd. Subclasses can override it for analytical gradients.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x | Tensor | Input tensor of shape (batch_size, *input_dims). | required |
Returns:
| Type | Description |
|---|---|
Tensor | Gradient tensor with the same shape as x. |
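One common way to get per-sample gradients from a single backward pass is to sum the batch energies first: each \(E(x_j)\) depends only on its own row, so the derivative of the sum with respect to row \(j\) is \(\nabla_x E(x_j)\). The sketch below assumes this approach (the summing step is an assumption about how the default works) and checks it against a closed-form gradient. `autograd_gradient` and the quartic energy are hypothetical examples, not library code.

```python
import torch


def autograd_gradient(energy_fn, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-sample gradient of a batched energy function."""
    x = x.detach().requires_grad_(True)
    # Rows do not interact, so d(sum_j E(x_j)) / d x_j equals grad E(x_j)
    (grad,) = torch.autograd.grad(energy_fn(x).sum(), x)
    return grad


def quartic_energy(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical example energy: E(x) = 0.25 * sum_i x_i^4
    return 0.25 * (x ** 4).sum(dim=-1)


torch.manual_seed(0)
x = torch.randn(5, 3)
g = autograd_gradient(quartic_energy, x)
print(g.shape)                    # torch.Size([5, 3])
print(torch.allclose(g, x ** 3))  # True: the closed-form gradient is x_i^3
```

An analytical override of gradient(x) for this energy would return x ** 3 directly and skip the backward pass.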
DoubleWellModel
Bases: BaseModel
Energy-based model for a double-well potential.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
barrier_height | float | The height of the energy barrier between the wells. | 2.0 |
b | float | The position of the wells (default is 1.0, creating wells at ±1). | 1.0 |
forward(x)
Computes the double-well energy \(E(x) = h \sum_{i=1}^{n} (x_i^2 - b^2)^2\), where \(h\) is barrier_height.
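Evaluating the formula above at a few points shows where the wells and the barrier sit. A sketch, with `double_well_energy` as a hypothetical helper that follows the stated formula on a batch of shape (batch_size, n):

```python
import torch


def double_well_energy(
    x: torch.Tensor, barrier_height: float = 2.0, b: float = 1.0
) -> torch.Tensor:
    # h * sum_i (x_i^2 - b^2)^2, one value per row of x
    return barrier_height * ((x ** 2 - b ** 2) ** 2).sum(dim=-1)


# 1-D points: left well, top of the barrier, right well
x = torch.tensor([[-1.0], [0.0], [1.0]])
print(double_well_energy(x))  # tensor([0., 2., 0.])
```

At the origin each dimension contributes \(h b^4\), so the barrier per dimension equals barrier_height only when b = 1.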
GaussianModel
Bases: BaseModel
Energy-based model for a Gaussian distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
mean | Tensor | The mean vector (μ) of the Gaussian distribution. | required |
cov | Tensor | The covariance matrix (Σ) of the Gaussian distribution. | required |
forward(x)
Computes the Gaussian energy: \(E(x) = \frac{1}{2} (x - \mu)^{\top} \Sigma^{-1} (x - \mu)\).
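The quadratic form can be evaluated without forming \(\Sigma^{-1}\) explicitly by solving \(\Sigma z = x - \mu\) for each sample. A sketch under that choice (the library may compute it differently); `gaussian_energy` is a hypothetical helper:

```python
import torch


def gaussian_energy(x: torch.Tensor, mean: torch.Tensor, cov: torch.Tensor) -> torch.Tensor:
    # E(x) = 0.5 * (x - mu)^T Sigma^{-1} (x - mu), one value per row of x
    diff = x - mean                           # (batch_size, n)
    # Solve Sigma z = diff^T for all samples at once instead of inverting Sigma
    z = torch.linalg.solve(cov, diff.T).T     # (batch_size, n)
    return 0.5 * (diff * z).sum(dim=-1)


mean = torch.tensor([1.0, -1.0])
cov = torch.tensor([[2.0, 0.0], [0.0, 0.5]])
x = torch.tensor([[1.0, -1.0], [3.0, 0.0]])
print(gaussian_energy(x, mean, cov))  # tensor([0., 2.])
```

The first sample sits at the mean, so its energy is zero; for the second, \(\tfrac{1}{2}(2^2/2 + 1^2/0.5) = 2\).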
HarmonicModel
Bases: BaseModel
Energy-based model for a harmonic oscillator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
k | float | The spring constant. | 1.0 |
forward(x)
Computes the harmonic oscillator energy: \(\frac{1}{2} k \sum_{i=1}^{n} x_i^{2}\).
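With spring constant \(k\), \(\exp(-E(x))\) is proportional to the isotropic Gaussian \(\mathcal{N}(0, k^{-1} I)\), so the harmonic energy should differ from the negative log-density only by a constant. The sketch below checks this with torch.distributions.MultivariateNormal; `harmonic_energy` is a hypothetical helper following the formula above.

```python
import math

import torch
from torch.distributions import MultivariateNormal


def harmonic_energy(x: torch.Tensor, k: float = 1.0) -> torch.Tensor:
    # 0.5 * k * sum_i x_i^2, one value per row of x (batch_size, n)
    return 0.5 * k * (x ** 2).sum(dim=-1)


torch.manual_seed(0)
k, n = 4.0, 3
x = torch.randn(6, n)

# exp(-E) with spring constant k is proportional to N(0, I / k)
dist = MultivariateNormal(torch.zeros(n), covariance_matrix=torch.eye(n) / k)
offset = harmonic_energy(x, k) + dist.log_prob(x)

# E(x) = -log p(x) - 0.5 * log det(2*pi*Sigma), which here is a constant:
expected = -0.5 * n * (math.log(2 * math.pi) - math.log(k))
print(torch.allclose(offset, torch.full_like(offset, expected), atol=1e-4))  # True
```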
RastriginModel
Bases: BaseModel
Energy-based model for the Rastrigin function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
a | float | Amplitude of the cosine modulation in the Rastrigin function. | 10.0 |
forward(x)
Computes the Rastrigin energy.
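This page does not give the Rastrigin formula. The default a = 10 matches the standard benchmark \(E(x) = a n + \sum_{i=1}^{n} \left[x_i^2 - a \cos(2\pi x_i)\right]\), which the sketch below assumes; `rastrigin_energy` is a hypothetical helper, so check it against the library source before relying on it.

```python
import math

import torch


def rastrigin_energy(x: torch.Tensor, a: float = 10.0) -> torch.Tensor:
    # Assumed standard Rastrigin: a*n + sum_i [x_i^2 - a*cos(2*pi*x_i)], one value per row
    n = x.shape[-1]
    return a * n + (x ** 2 - a * torch.cos(2 * math.pi * x)).sum(dim=-1)


x = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.5, 0.5]])
# Approximately [0, 1, 40.5]: global minimum at the origin, a local minimum
# near the integer point (1, 0), and a high value between minima
print(rastrigin_energy(x))
```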
RosenbrockModel
Bases: BaseModel
Energy-based model for the Rosenbrock function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
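a | float | The constant \(a\) in the \((a - x_i)^2\) term of the Rosenbrock energy. | 1.0 |
b | float | The coefficient \(b\) scaling the \((x_{i+1} - x_i^2)^2\) term. | 100.0 |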
forward(x)
Computes the Rosenbrock energy: \(\sum_{i=1}^{n-1} \left[ b(x_{i+1} - x_i^2)^2 + (a - x_i)^2 \right]\).
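A vectorized sketch of the formula above pairs each coordinate with its successor by slicing, so no Python loop over \(i\) is needed. `rosenbrock_energy` is a hypothetical helper; it assumes inputs of shape (batch_size, n) with n ≥ 2.

```python
import torch


def rosenbrock_energy(x: torch.Tensor, a: float = 1.0, b: float = 100.0) -> torch.Tensor:
    # sum_{i=1}^{n-1} [ b (x_{i+1} - x_i^2)^2 + (a - x_i)^2 ], one value per row
    head, tail = x[..., :-1], x[..., 1:]  # x_i and x_{i+1} for i = 1..n-1
    return (b * (tail - head ** 2) ** 2 + (a - head) ** 2).sum(dim=-1)


x = torch.tensor([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]])
print(rosenbrock_energy(x))  # tensor([0., 2.])
```

With the default a = 1, the all-ones point has zero energy, while the origin pays \((1 - 0)^2\) for each of its \(n - 1\) leading coordinates.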