Getting Started with TorchEBM¶
This guide provides a hands-on introduction to TorchEBM. You'll learn how to install the library, understand its core components, and train your first Energy-Based Model (EBM) on a synthetic dataset.
1. Installation¶
TorchEBM can be installed from PyPI. Ensure you have PyTorch installed first.
Prerequisites
- Python 3.8+
- PyTorch 1.10.0+
- CUDA (optional, but highly recommended for performance)
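With the prerequisites in place, installation is a single command (the PyPI package name `torchebm` is assumed here; check the project page if it differs):

```bash
pip install torchebm
```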
2. The Core Concepts¶
An Energy-Based Model defines a probability distribution over data \(x\) through an energy function \(E(x)\). The probability is defined as \(p(x) = \frac{e^{-E(x)}}{Z}\), where \(Z\) is the normalizing constant (partition function) and lower energy corresponds to higher probability.
TorchEBM is built around two key components:
- Energy Functions: These are learnable functions (often neural networks) that map input data to a scalar energy value.
- Samplers: These are algorithms, typically based on Markov Chain Monte Carlo (MCMC), used to draw samples from the probability distribution defined by the energy function.
Let's explore these concepts with code.
Concept 1: The Energy Function¶
An energy function is a `torch.nn.Module` that takes a tensor `x` of shape `(batch_size, *dims)` and returns a tensor of energy values of shape `(batch_size,)`.
TorchEBM provides several pre-built energy functions for testing and experimentation. Here's how to use the `GaussianEnergy` function, which models a multivariate normal distribution.
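A minimal sketch is shown below; the import path and the constructor arguments `mean` and `cov` are assumptions based on the class name above, so consult the API reference if your installed version differs.

```python
import torch

# The exact import path may differ between TorchEBM versions.
from torchebm.core import GaussianEnergy

# A 2D Gaussian with mean at the origin and identity covariance.
energy_fn = GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2))

# Evaluate the energy of a batch of points: shape (batch_size, 2) -> (batch_size,).
points = torch.tensor([[0.0, 0.0], [1.0, 1.0], [3.0, 3.0]])
energies = energy_fn(points)
print(energies)  # energies grow as points move away from the mean
```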
The point `[0.0, 0.0]` is the mean of the distribution and thus has the lowest energy. As points move away from the mean, their energy increases.
Concept 2: The Sampler¶
Samplers generate data points from the distribution defined by an energy function. They typically start from random initial points and iteratively move them toward lower-energy (higher-probability) regions, injecting noise along the way so that the samples cover the whole distribution rather than collapsing onto its modes.
Let's use the `LangevinDynamics` sampler to draw samples from our `GaussianEnergy` distribution.
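A sketch of how this might look; the argument names used here (`step_size`, `dim`, `n_samples`, `n_steps`) and the import path are assumptions, so check the sampler's documentation for the exact signature.

```python
import torch
from torchebm.core import GaussianEnergy
from torchebm.samplers import LangevinDynamics  # import path may differ

energy_fn = GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2))

# Langevin dynamics follows the negative energy gradient with added Gaussian noise.
sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01)

# Draw 1000 approximate samples in 2 dimensions, refining each chain for 500 steps.
samples = sampler.sample(dim=2, n_samples=1000, n_steps=500)
print(samples.shape)  # expected: torch.Size([1000, 2])
```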
You have now sampled from your first energy-based model! These samples approximate a 2D Gaussian distribution.
3. Training Your First EBM¶
Now let's put everything together and train an EBM with a neural network as the energy function. The goal is to train the model to represent a synthetic "two moons" dataset.
Step 1: Create a Dataset¶
First, we'll generate a `TwoMoonsDataset` and create a `DataLoader` to iterate through it in batches.
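A sketch of the data setup; the constructor arguments shown for `TwoMoonsDataset` (`n_samples`, `noise`) and its import path are assumptions.

```python
from torch.utils.data import DataLoader
from torchebm.datasets import TwoMoonsDataset  # import path may differ

# 3000 two-dimensional points arranged as two interleaving half-circles.
dataset = TwoMoonsDataset(n_samples=3000, noise=0.05)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
```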
Step 2: Define a Neural Energy Function¶
Next, we'll create a simple Multi-Layer Perceptron (MLP) to serve as our energy function. This network will take 2D points as input and output a single energy value for each.
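A simple version in plain PyTorch is shown below; depending on your TorchEBM version, the library may expect energy functions to subclass its own base class rather than `torch.nn.Module` directly.

```python
import torch
import torch.nn as nn

class MLPEnergy(nn.Module):
    """A small MLP mapping 2D points to scalar energy values."""

    def __init__(self, input_dim: int = 2, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Squeeze the trailing dimension so the output has shape (batch_size,).
        return self.net(x).squeeze(-1)

energy_model = MLPEnergy()
```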
Step 3: Set up the Training Components¶
To train the EBM, we need three things:
- A Loss Function: We'll use `ContrastiveDivergence`, a standard loss function for EBMs. It works by pushing down the energy of real data ("positive" samples) and pushing up the energy of generated data ("negative" samples).
- A Sampler: The loss function needs a sampler to generate the negative samples. We'll use `LangevinDynamics` again.
- An Optimizer: A standard PyTorch optimizer like `Adam`.
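Wiring these together might look like the following sketch; as before, the import paths and constructor arguments (including `k_steps`, standing in for the number of Langevin steps used to generate negative samples) are assumptions.

```python
import torch
from torchebm.samplers import LangevinDynamics     # import paths may differ
from torchebm.losses import ContrastiveDivergence

# Sampler for the negative phase, driven by the current energy model.
sampler = LangevinDynamics(energy_function=energy_model, step_size=0.01)

# Contrastive divergence (CD-k): start chains at the data and run k Langevin steps.
loss_fn = ContrastiveDivergence(
    energy_function=energy_model,
    sampler=sampler,
    k_steps=10,
)

optimizer = torch.optim.Adam(energy_model.parameters(), lr=1e-3)
```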
Step 4: The Training Loop¶
Now we'll write a standard PyTorch training loop. For each batch of real data, we calculate the contrastive divergence loss and update the model's weights.
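A minimal loop is sketched below, assuming the loss object is callable on a batch of real data; some implementations also return the negative samples alongside the loss, which the sketch accounts for.

```python
n_epochs = 100  # illustrative; adjust for your dataset and hardware

for epoch in range(n_epochs):
    epoch_loss = 0.0
    for batch in dataloader:
        optimizer.zero_grad()

        # Compute the contrastive divergence loss for this batch of real data.
        out = loss_fn(batch)
        loss = out[0] if isinstance(out, tuple) else out

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}: avg loss = {epoch_loss / len(dataloader):.4f}")
```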
Next Steps¶
Congratulations on training your first Energy-Based Model with TorchEBM!
- Learn more about the different Samplers available.
- Explore other Loss Functions for training EBMs.
- See how to create Custom Neural Networks for more complex energy functions.
- Check out the Visualization guide to see how you can plot your energy landscapes and samples.