Models and Energy Functions
At the core of any energy-based model (EBM) is the energy function, \( E_{\theta}(x) \), which assigns a scalar energy to each data point \( x \). It defines a probability distribution \( p_{\theta}(x) = \frac{e^{-E_{\theta}(x)}}{Z(\theta)} \), where \( Z(\theta) = \int e^{-E_{\theta}(x)} \, dx \) is the normalizing constant (partition function), so regions of low energy correspond to high probability.
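To make the relationship between energy and probability concrete, the sketch below normalizes a one-dimensional energy numerically. The quadratic energy, the grid bounds, and the use of the trapezoidal rule are illustrative assumptions chosen so that the result can be checked against a known value; in realistic, high-dimensional models the normalizing constant is generally intractable.

```python
import math

import torch


def energy(x: torch.Tensor) -> torch.Tensor:
    # Illustrative energy E(x) = x^2 / 2 (a standard normal, up to a constant).
    return 0.5 * x ** 2


# Dense 1D grid; the tails beyond |x| = 10 contribute negligibly to the integral.
xs = torch.linspace(-10.0, 10.0, 20001, dtype=torch.float64)
unnormalized = torch.exp(-energy(xs))

# Approximate the normalizing constant Z = integral of exp(-E(x)) dx.
Z = torch.trapezoid(unnormalized, xs)
density = unnormalized / Z

print(f"Z = {Z.item():.6f}, sqrt(2*pi) = {math.sqrt(2 * math.pi):.6f}")
print("density peaks at x =", xs[density.argmax()].item())
```

The lowest-energy point of the grid is also the point of highest density, which is the sense in which low energy corresponds to high probability.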
In TorchEBM, all energy functions are implemented as torch.nn.Module subclasses that inherit from the torchebm.core.BaseModel class.
Defining a Custom Model
You can create a custom energy function by subclassing BaseModel and implementing the forward() method. Here is an example of a simple energy function based on a Multi-Layer Perceptron (MLP).
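A minimal sketch of such a model might look like the following. The class name `MLPModel`, the layer sizes, and the SiLU activation are illustrative choices rather than part of the library. The `try`/`except` fallback to `torch.nn.Module` is an assumption added only so that the snippet runs without TorchEBM installed, and the sketch also assumes that `BaseModel` can be initialized without arguments.

```python
import torch
import torch.nn as nn

try:
    # Import path as given in the text above.
    from torchebm.core import BaseModel
except ImportError:
    # Assumption for illustration: fall back to nn.Module if TorchEBM is absent.
    BaseModel = nn.Module


class MLPModel(BaseModel):
    """Hypothetical example: maps each input point to one scalar energy."""

    def __init__(self, input_dim: int = 2, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, input_dim) -> (batch,): drop the trailing singleton dimension.
        return self.net(x).squeeze(-1)


model = MLPModel(input_dim=2, hidden_dim=64)
x = torch.randn(8, 2)
energies = model(x)
print(energies.shape)  # one energy value per input point
```

Returning a tensor of shape `(batch,)` rather than `(batch, 1)` keeps each energy a plain scalar per sample, which matches the definition of \( E_{\theta}(x) \) as a scalar-valued function.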
Built-in Analytical Models
TorchEBM also provides several pre-built analytical models for common distributions and testing scenarios. Because their energies are known in closed form, they are useful for research and for checking how samplers and training algorithms behave on a landscape whose structure is known in advance.
GaussianModel
This model implements the energy function for a multivariate Gaussian distribution.
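The constructor arguments of TorchEBM's own `GaussianModel` are not shown here, so the sketch below uses a hypothetical standalone class, `GaussianEnergySketch`, to illustrate the underlying formula \( E(x) = \tfrac{1}{2}(x - \mu)^{\top} \Sigma^{-1} (x - \mu) \). It then cross-checks the result against `torch.distributions.MultivariateNormal`: the energy and the negative log-density should differ only by the constant \( \log Z \).

```python
import torch
import torch.nn as nn


class GaussianEnergySketch(nn.Module):
    """Hypothetical stand-in: E(x) = 0.5 * (x - mean)^T cov^{-1} (x - mean)."""

    def __init__(self, mean: torch.Tensor, cov: torch.Tensor):
        super().__init__()
        self.register_buffer("mean", mean)
        # Precompute the precision matrix (inverse covariance) once.
        self.register_buffer("precision", torch.linalg.inv(cov))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        delta = x - self.mean
        # Batched quadratic form: sum_ij delta_i * P_ij * delta_j for each row.
        return 0.5 * torch.einsum("bi,ij,bj->b", delta, self.precision, delta)


mean = torch.tensor([1.0, -1.0])
cov = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
model = GaussianEnergySketch(mean, cov)

x = torch.randn(5, 2)
energy = model(x)

# E(x) + log p(x) should be the same constant, -log Z, for every sample.
reference = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
offset = energy + reference.log_prob(x)
print(offset)  # entries agree up to floating-point error
```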
DoubleWellModel
This model creates a double-well potential, which is useful for testing a sampler's ability to cross energy barriers.
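As a rough illustration of what a double-well potential looks like, the sketch below uses the common form \( E(x) = h \sum_i (x_i^2 - 1)^2 \). Both this functional form and the `barrier_height` parameter are assumptions made for the example; the exact formula and arguments used by TorchEBM's `DoubleWellModel` may differ, and `DoubleWellSketch` is a hypothetical standalone class.

```python
import torch
import torch.nn as nn


class DoubleWellSketch(nn.Module):
    """Hypothetical stand-in: E(x) = h * sum_i (x_i^2 - 1)^2."""

    def __init__(self, barrier_height: float = 2.0):
        super().__init__()
        self.barrier_height = barrier_height

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each coordinate has minima at -1 and +1, separated by a barrier at 0.
        return self.barrier_height * ((x ** 2 - 1.0) ** 2).sum(dim=-1)


model = DoubleWellSketch(barrier_height=2.0)

# One-dimensional inputs: left well, top of the barrier, right well.
points = torch.tensor([[-1.0], [0.0], [1.0]])
print(model(points))  # tensor([0., 2., 0.])
```

In this form, a sampler started in one well has to pass through a region of energy `barrier_height` to reach the other, so raising that value makes crossings rarer.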
Visualizing Energy Landscapes
Plotting the energy landscape shows where the low-energy modes lie and how high the barriers between them are. Here's how you can visualize the 2D landscape of the DoubleWellModel.
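One way to do this is to evaluate the energy on a regular grid and draw a filled contour plot. The sketch below uses a hypothetical `DoubleWellSketch` class (an assumed form, \( E(x) = h \sum_i (x_i^2 - 1)^2 \)) in place of the library model so that it runs on its own; `energy_grid` and `plot_landscape` are illustrative helper names, not TorchEBM functions. Any module that maps a `(batch, 2)` tensor to a `(batch,)` tensor of energies could be passed in instead.

```python
import torch
import torch.nn as nn


class DoubleWellSketch(nn.Module):
    """Hypothetical stand-in: E(x) = h * sum_i (x_i^2 - 1)^2."""

    def __init__(self, barrier_height: float = 2.0):
        super().__init__()
        self.barrier_height = barrier_height

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.barrier_height * ((x ** 2 - 1.0) ** 2).sum(dim=-1)


def energy_grid(model: nn.Module, lim: float = 2.0, steps: int = 101):
    """Evaluate the model on a steps x steps grid over [-lim, lim]^2."""
    xs = torch.linspace(-lim, lim, steps)
    ys = torch.linspace(-lim, lim, steps)
    X, Y = torch.meshgrid(xs, ys, indexing="xy")
    # Flatten the grid into a batch of 2D points, then restore the grid shape.
    points = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
    with torch.no_grad():
        Z = model(points).reshape(X.shape)
    return X, Y, Z


def plot_landscape(X: torch.Tensor, Y: torch.Tensor, Z: torch.Tensor) -> None:
    """Draw a filled contour plot of the energy (requires matplotlib)."""
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(6, 5))
    contour = ax.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=50, cmap="viridis")
    fig.colorbar(contour, ax=ax, label="energy")
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    ax.set_title("Double-well energy landscape")
    plt.show()


X, Y, Z = energy_grid(DoubleWellSketch(barrier_height=2.0))
print(Z.shape, Z.min().item(), Z.max().item())
```

Calling `plot_landscape(X, Y, Z)` then renders the figure. With the assumed energy, the plot shows four minima at \( (\pm 1, \pm 1) \) separated by ridges along the axes.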
TorchEBM includes a variety of other analytical models, such as RosenbrockModel, AckleyModel, and RastriginModel, which are commonly used for benchmarking optimization and sampling algorithms. You can visualize them using the same technique.
