Training an EBM on a Gaussian Mixture¶
This tutorial demonstrates how to train an energy-based model (EBM) on a 2D Gaussian mixture distribution using the TorchEBM library. We'll build a simple MLP-based energy function, train it with Contrastive Divergence, and visualize the results.
Key Concepts Covered
- Building an MLP-based energy function
- Training with Contrastive Divergence
- Sampling with Langevin dynamics
- Visualizing the energy landscape and samples
Overview¶
Energy-based models provide a flexible framework for modeling complex probability distributions. This tutorial focuses on a simple but illustrative example: learning a 2D Gaussian mixture distribution. This is a good starting point because:
- It's easy to visualize in 2D
- It has multimodal structure that challenges simple models
- We can generate synthetic training data with known properties
Prerequisites¶
Before starting, make sure you have TorchEBM installed:
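A standard pip install should work (assuming the package is published on PyPI under the name `torchebm`):

```bash
pip install torchebm
```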
We'll also use the following libraries:
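For this tutorial, that means PyTorch for the model and sampler, and Matplotlib for the plots:

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
```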
Step 1: Define the Energy Function¶
We'll create a simple MLP (Multi-Layer Perceptron) energy function by subclassing BaseEnergyFunction:
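A minimal sketch of such an energy network is shown below. The import path `torchebm.core.BaseEnergyFunction` and the assumption that the base class behaves like an `nn.Module` whose `forward` returns one scalar energy per input are based on typical TorchEBM usage; check the API reference if your version differs.

```python
from torchebm.core import BaseEnergyFunction  # import path assumed; see the TorchEBM API docs


class MLPEnergy(BaseEnergyFunction):
    """A small MLP mapping 2D points to scalar energies."""

    def __init__(self, input_dim: int = 2, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One scalar energy per sample: shape (batch_size,)
        return self.net(x).squeeze(-1)
```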
This energy function maps input points to scalar energy values; lower energy corresponds to higher probability density under the model's Boltzmann distribution, p(x) ∝ exp(-E(x)).
Step 2: Create the Dataset¶
TorchEBM provides built-in datasets for testing and development. Let's use the GaussianMixtureDataset:
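A sketch of the data generation step. The constructor arguments and the `get_data()` accessor shown here are assumptions about `GaussianMixtureDataset`'s interface, guided by the four-mode mixture this tutorial describes; adjust them to the actual API.

```python
from torchebm.datasets import GaussianMixtureDataset  # import path assumed

# Four Gaussian modes in 2D; argument names below are assumptions.
dataset = GaussianMixtureDataset(
    n_samples=2000,
    n_components=4,
    std=0.1,
    seed=42,
)
real_data = dataset.get_data()  # assumed accessor returning a (n_samples, 2) tensor
```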
We can visualize the generated data to see what our target distribution looks like:
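A plain Matplotlib scatter plot is enough for a quick look (assuming `real_data` is an `(n_samples, 2)` tensor as above):

```python
data_np = real_data.detach().cpu().numpy()

plt.figure(figsize=(5, 5))
plt.scatter(data_np[:, 0], data_np[:, 1], s=5, alpha=0.5)
plt.title("Samples from the target Gaussian mixture")
plt.xlabel("x1")
plt.ylabel("x2")
plt.show()
```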
Step 3: Define the Visualization Function¶
We'll create a function to visualize the energy landscape and generated samples during training:
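One way to implement such a helper is to evaluate the energy on a grid, draw a filled contour plot of the negative energy (so bright regions correspond to high density), and overlay the real data in white and model samples in red. This is a generic Matplotlib sketch rather than a TorchEBM utility:

```python
@torch.no_grad()
def plot_energy_landscape(energy_fn, real_data, model_samples=None,
                          grid_range=3.0, resolution=200):
    """Contour plot of the energy landscape with data (white) and model samples (red)."""
    device = next(energy_fn.parameters()).device
    xs = torch.linspace(-grid_range, grid_range, resolution)
    xx, yy = torch.meshgrid(xs, xs, indexing="ij")
    grid = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=1).to(device)

    # Negative energy so that bright regions correspond to low energy / high density.
    neg_energy = -energy_fn(grid).reshape(resolution, resolution).cpu()

    plt.figure(figsize=(6, 6))
    plt.contourf(xx.numpy(), yy.numpy(), neg_energy.numpy(), levels=50, cmap="viridis")
    data_np = real_data.detach().cpu().numpy()
    plt.scatter(data_np[:, 0], data_np[:, 1], s=5, c="white", alpha=0.5, label="data")
    if model_samples is not None:
        samples_np = model_samples.detach().cpu().numpy()
        plt.scatter(samples_np[:, 0], samples_np[:, 1], s=5, c="red", alpha=0.5,
                    label="model samples")
    plt.legend(loc="upper right")
    plt.title("Energy landscape (bright = low energy)")
    plt.show()
```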
Visualizing Analytical Energy Functions
For more detail on visualizing analytical energy functions and related techniques, see our Energy Visualization Guide. You can find visualizations of the 2D toy datasets in the Datasets examples.
Step 4: Set Up the Training Components¶
Now we'll set up the model, sampler, loss function, and optimizer:
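A sketch of that setup. The class names LangevinDynamics and ContrastiveDivergence, their import paths, and the argument names (step_size, noise_scale, k_steps, persistent) are assumptions based on the concepts described in this tutorial; check the TorchEBM API reference for the exact signatures.

```python
from torch.optim import Adam

from torchebm.samplers import LangevinDynamics      # import path assumed
from torchebm.losses import ContrastiveDivergence   # import path assumed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

energy_fn = MLPEnergy(input_dim=2, hidden_dim=64).to(device)

# Langevin dynamics sampler used to draw negative samples from the model.
sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01,   # argument name assumed
    noise_scale=0.1,  # argument name assumed
)

# Contrastive Divergence loss; persistent=True enables PCD (see the note below).
loss_fn = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,
    k_steps=10,       # MCMC steps per update; argument name assumed
    persistent=True,
)

optimizer = Adam(energy_fn.parameters(), lr=1e-3)
```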
Langevin Dynamics
Langevin dynamics is a sampling method that uses gradient information to explore the energy landscape. It adds noise to the gradient updates, allowing the sampler to overcome energy barriers and explore multimodal distributions.
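Concretely, a single Langevin update nudges a sample downhill on the energy and adds Gaussian noise: x ← x − (η/2)·∇E(x) + √η·ε, with ε ~ N(0, I) and step size η. A minimal hand-rolled version of that update (independent of TorchEBM's built-in sampler) looks like this:

```python
def langevin_step(x, energy_fn, step_size=0.01):
    """One Langevin dynamics update: a gradient step on the energy plus injected noise."""
    x = x.detach().requires_grad_(True)
    grad = torch.autograd.grad(energy_fn(x).sum(), x)[0]
    noise = torch.randn_like(x)
    return (x - 0.5 * step_size * grad + step_size ** 0.5 * noise).detach()
```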
Persistent Contrastive Divergence
Setting persistent=True enables Persistent Contrastive Divergence (PCD), which maintains a set of persistent chains between parameter updates. This can lead to better exploration of the energy landscape and improved training stability, especially for complex distributions.
Step 5: The Training Loop¶
Now we're ready to train our energy-based model:
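A sketch of the loop, minibatching directly from the `real_data` tensor. It assumes the ContrastiveDivergence loss object returns a scalar loss when called on a batch of real data, and that the sampler exposes a `sample()` method; both are assumptions about the TorchEBM API, so adapt the calls to your installed version.

```python
batch_size = 256
n_epochs = 200

for epoch in range(n_epochs):
    perm = torch.randperm(real_data.shape[0])
    epoch_loss, n_batches = 0.0, 0

    for i in range(0, real_data.shape[0], batch_size):
        batch = real_data[perm[i:i + batch_size]].to(device)

        optimizer.zero_grad()
        loss = loss_fn(batch)  # assumed to return a scalar loss; some versions may also return negatives
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch + 1:3d} | loss: {epoch_loss / n_batches:.4f}")
        # Visualize progress with the helper from Step 3; sample() argument names are assumed.
        samples = sampler.sample(dim=2, n_samples=1000, n_steps=500)
        plot_energy_landscape(energy_fn, real_data, samples)
```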
Training Progress Visualization¶
As training progresses, we can see how the energy landscape evolves to capture the four-mode structure of our target distribution. The brighter regions in the contour plot represent areas of higher probability density (lower energy), and the red points show samples drawn from the model.
Step 6: Final Evaluation¶
After training, we generate a final set of samples from our model for evaluation:
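For example, using the same (assumed) sampler interface as during training:

```python
# Draw a larger, longer-run batch of samples from the trained model;
# sample() argument names are assumptions about the TorchEBM API.
final_samples = sampler.sample(dim=2, n_samples=2000, n_steps=1000)
plot_energy_landscape(energy_fn, real_data, final_samples)
```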
Understanding the Results¶
Our model has successfully learned the four-mode structure of the target distribution. The contour plot shows four distinct regions of low energy (high probability) corresponding to the four Gaussian components.
The red points (samples from our model) closely match the distribution of the white points (real data), indicating that our energy-based model has effectively captured the target distribution.
This example demonstrates the core workflow for training energy-based models with TorchEBM:
- Define an energy function
- Set up a sampler for generating negative samples
- Use Contrastive Divergence for training
- Monitor progress through visualization
Tips for Training EBMs¶
When training your own energy-based models, consider these tips:
- Sampling Parameters: The step size and noise scale of the Langevin dynamics sampler are critical. Values that are too large can make sampling unstable, while values that are too small may result in poor mixing.
- CD Steps: The number of MCMC steps in Contrastive Divergence affects the quality of negative samples. More steps generally give better results but increase computation time.
- Learning Rate: Energy-based models can be sensitive to the learning rate. Start with a small learning rate and increase it gradually if needed.
- Neural Network Architecture: The choice of architecture and activation function can affect the smoothness of the energy landscape.
Complete Code¶
The complete script for this example simply combines the snippets from the steps above: the MLP energy function, the Gaussian mixture dataset, the visualization helper, the sampler/loss/optimizer setup, the training loop, and the final evaluation.
Conclusion¶
In this tutorial, we've learned how to:
- Define a simple energy-based model using an MLP
- Generate synthetic data from a 2D Gaussian mixture using TorchEBM's dataset utilities
- Train the model using Contrastive Divergence and Langevin dynamics
- Visualize the energy landscape and generated samples throughout training
Energy-based models provide a powerful and flexible framework for modeling complex probability distributions. While we've focused on a simple 2D example, the same principles apply to more complex, high-dimensional distributions.