TorchEBM: A PyTorch Framework for Energy-Based Modeling

Overview
TorchEBM is a high-performance PyTorch library that makes Energy-Based Models (EBMs) accessible and efficient for researchers and practitioners alike. The framework provides comprehensive components for sampling, inference, and model training.
Key Resources:
- Documentation
- GitHub Repository
- PyPI Package
What are Energy-Based Models?
Energy-Based Models (EBMs) offer a powerful and flexible framework for generative modeling by assigning a scalar "energy" to each data point, where lower energy corresponds to higher probability. EBMs define a probability distribution as:
\[p(x) = \frac{e^{-E(x)}}{Z}\]
where $E(x)$ is the energy function and $Z = \int e^{-E(x)}\,dx$ is the partition function that normalizes the distribution.
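For intuition, here is a minimal plain-PyTorch sketch (not TorchEBM API) of how an energy function induces an unnormalized density:

import torch

def energy(x):
    # standard Gaussian energy: E(x) = ||x||^2 / 2
    return 0.5 * (x ** 2).sum(dim=-1)

x = torch.randn(4, 2)            # batch of four 2D points
p_tilde = torch.exp(-energy(x))  # proportional to p(x); Z is generally intractable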
Core Components
TorchEBM is structured around six key components, each designed for specific aspects of energy-based modeling:
1. Energy Functions
Implement energy functions using BaseEnergyFunction. The library includes both analytical and neural network-based energy functions:
Analytical Energy Functions
TorchEBM provides several built-in analytical energy landscapes for testing and research:
- GaussianEnergy: $E(x) = \frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)$
- DoubleWellEnergy: $E(x) = h \sum_{i=1}^n [(x_i^2 - 1)^2]$
- RastriginEnergy: $E(x) = an + \sum_{i=1}^n [x_i^2 - a\cos(2\pi x_i)]$
- RosenbrockEnergy: $E(x) = \sum_{i=1}^{n-1} [a(x_{i+1} - x_i^2)^2 + (x_i - 1)^2]$
- AckleyEnergy: Complex multi-modal energy landscapes
- HarmonicEnergy: Simple quadratic potentials
from torchebm.core import GaussianEnergy, DoubleWellEnergy
import torch

# Gaussian energy function
energy_fn = GaussianEnergy(
    mean=torch.zeros(2),
    cov=torch.eye(2)
)

# Double well energy
double_well = DoubleWellEnergy(barrier_height=2.0)
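Analytical energies are evaluated batch-wise; assuming they follow the same callable convention as the neural energy class shown later (a forward() returning one scalar per sample), a quick sanity check looks like:

x = torch.randn(8, 2)
print(energy_fn(x).shape)  # expected: torch.Size([8]) — one energy value per point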
2. Samplers
MCMC samplers for generating samples from energy distributions:
Langevin Dynamics
Implements the discretized Langevin stochastic differential equation for sampling, $x_{t+1} = x_t - \eta \nabla_x E(x_t) + \sqrt{2\eta}\,\epsilon_t$, where $\eta$ is the step size and $\epsilon_t \sim \mathcal{N}(0, I)$:
from torchebm.samplers import LangevinDynamics

sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01,
    device=device
)

# Generate samples
initial_points = torch.randn(500, 2, device=device)
samples = sampler.sample(x=initial_points, n_steps=100)
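Conceptually, each Langevin iteration is a noisy gradient step on the energy. A minimal plain-PyTorch sketch of a single update (illustrative only, not the library's implementation):

def langevin_step(x, energy_fn, step_size):
    # one step of x_{t+1} = x_t - eta * grad E(x_t) + sqrt(2 * eta) * noise
    x = x.detach().requires_grad_(True)
    grad = torch.autograd.grad(energy_fn(x).sum(), x)[0]
    noise = torch.randn_like(x)
    return (x - step_size * grad + (2 * step_size) ** 0.5 * noise).detach()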
Hamiltonian Monte Carlo
Augments the state with momentum variables for more efficient exploration than plain Langevin dynamics:
from torchebm.samplers import HamiltonianMonteCarlo

hmc_sampler = HamiltonianMonteCarlo(
    energy_function=energy_fn,
    step_size=0.1,
    n_leapfrog_steps=10,
    device=device
)
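Internally, HMC-style samplers simulate Hamiltonian dynamics with leapfrog integration; here is a minimal sketch of one leapfrog step (illustrative, not the library's code), where grad_energy computes $\nabla_x E(x)$:

def leapfrog_step(x, p, grad_energy, step_size):
    p = p - 0.5 * step_size * grad_energy(x)  # half-step on momentum
    x = x + step_size * p                     # full step on position
    p = p - 0.5 * step_size * grad_energy(x)  # half-step on momentum
    return x, p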
3. Loss Functions
Comprehensive loss functions for EBM training:
Contrastive Divergence
- ContrastiveDivergence: Standard CD algorithm
- PersistentContrastiveDivergence: Persistent CD for better training
- ParallelTemperingCD: Enhanced with parallel tempering
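Conceptually, CD-k pushes the energy of data down and the energy of k-step MCMC negatives up. A hedged sketch of the objective, reusing the sampler interface shown above (not necessarily how ContrastiveDivergence is implemented internally):

def cd_loss_sketch(energy_fn, sampler, x_data, k=10):
    # negatives: k MCMC steps initialized at the data (the classic CD-k choice)
    x_neg = sampler.sample(x=x_data.detach(), n_steps=k)
    return energy_fn(x_data).mean() - energy_fn(x_neg.detach()).mean()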
Score Matching
- ScoreMatching: Standard score matching
- SlicedScoreMatching: Scalable variant for high dimensions
- DenosingScoreMatching: Denoising score matching
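For intuition, a plain-PyTorch sketch of the denoising score matching objective with Gaussian noise (sigma and the helper name are illustrative, not library parameters):

def dsm_loss_sketch(energy_fn, x, sigma=0.1):
    x_noisy = (x + sigma * torch.randn_like(x)).requires_grad_(True)
    # model score is -grad E; keep the graph so the loss can backprop to parameters
    score = -torch.autograd.grad(energy_fn(x_noisy).sum(), x_noisy, create_graph=True)[0]
    target = -(x_noisy.detach() - x) / sigma**2  # score of the Gaussian corruption kernel
    return ((score - target) ** 2).sum(dim=-1).mean()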
4. Datasets
Helper functions for generating synthetic datasets useful for testing and visualization:
from torchebm.datasets import (
    CheckerboardDataset, CircleDataset, EightGaussiansDataset,
    GaussianMixtureDataset, GridDataset, PinwheelDataset,
    SwissRollDataset, TwoMoonsDataset
)
# Create a Gaussian mixture dataset
dataset = GaussianMixtureDataset(n_samples=1000, n_components=4)
data = dataset.get_data()
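get_data() returns a tensor of sample points, so a quick shape check (expected output assumes the 2D mixture above):

print(data.shape)  # expected: torch.Size([1000, 2])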
5. Visualization
Tools for visualizing energy landscapes, sampling processes, and training progression:
from torchebm.utils import Visualization
# Visualize energy landscape and samples
Visualization.plot_energy_landscape(energy_fn, samples)
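If you prefer building a plot by hand, here is a minimal matplotlib sketch of a 2D landscape (assuming the CPU energy_fn from section 1; this bypasses TorchEBM's helper entirely):

import torch
import matplotlib.pyplot as plt

xs = torch.linspace(-3, 3, 200)
X, Y = torch.meshgrid(xs, xs, indexing="ij")
grid = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)
with torch.no_grad():
    Z = energy_fn(grid).reshape(200, 200)
plt.contourf(X.numpy(), Y.numpy(), Z.numpy(), levels=50, cmap="viridis")
plt.colorbar(label="energy")
plt.show()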
6. CUDA Acceleration
GPU implementations of key algorithms for dramatically faster sampling and training on CUDA hardware.
Quick Start Example
Here's a complete example of creating and sampling from an energy model:
import torch
from torchebm.core import GaussianEnergy
from torchebm.samplers import LangevinDynamics
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define an analytical energy function
energy_fn = GaussianEnergy(
    mean=torch.zeros(2),
    cov=torch.eye(2),
    device=device
)

# Define a sampler
sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01,
    device=device
)
# Generate samples
initial_points = torch.randn(500, 2, device=device)
samples = sampler.sample(x=initial_points, n_steps=100)
print(f"Output batch_shape: {samples.shape}")
# Output batch_shape: torch.Size([500, 2])
Training Energy-Based Models
Training EBMs typically involves adjusting the energy function's parameters so that observed data points have lower energy than samples generated by the model. Here's an example using Contrastive Divergence:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchebm.core import BaseEnergyFunction
from torchebm.samplers import LangevinDynamics
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset
# Define a neural energy function
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).squeeze(-1)
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
energy_model = MLPEnergy(input_dim=2).to(device)
# Sampler for negative sampling
sampler = LangevinDynamics(
    energy_function=energy_model,
    step_size=0.01,
    device=device
)

# Loss function
cd_loss = ContrastiveDivergence(
    energy_function=energy_model,
    sampler=sampler,
    k_steps=10
)
# Training setup
optimizer = optim.Adam(energy_model.parameters(), lr=1e-3)
dataset = GaussianMixtureDataset(n_samples=5000, n_components=4)
dataloader = DataLoader(dataset.get_data(), batch_size=64, shuffle=True)
# Training loop
for epoch in range(100):
    epoch_loss = 0.0
    for batch_data in dataloader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        loss = cd_loss(batch_data)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: Loss = {epoch_loss/len(dataloader):.6f}")
Training Visualization
The library provides visualization tools for understanding model training. The figure below shows the evolution of an MLP-based energy function trained on a 2D Gaussian mixture:
[Figure: learned energy landscape, data points (white), and model samples (red) at successive training stages]
These visualizations demonstrate how the model learns regions of low energy (high probability density, warmer colors) corresponding to the data distribution (white points), while assigning higher energy elsewhere. Red points are samples generated from the EBM at each training stage.
Example Energy Landscapes
TorchEBM includes several analytical energy functions for testing and benchmarking:
[Figures: example analytical energy landscapes]
API Reference Structure
The library is organized into several main modules:
torchebm/
├── core/        # Base classes and energy functions
│   ├── BaseEnergyFunction
│   ├── GaussianEnergy, DoubleWellEnergy, etc.
│   ├── BaseTrainer, BaseOptimizer
│   └── BaseScheduler variants
├── samplers/    # MCMC sampling algorithms
│   ├── LangevinDynamics
│   └── HamiltonianMonteCarlo
├── losses/      # Training objectives
│   ├── ContrastiveDivergence variants
│   └── ScoreMatching variants
├── datasets/    # Synthetic data generators
│   └── Various 2D datasets
├── models/      # Neural network architectures
└── utils/       # Visualization and utilities
Installation
Install TorchEBM using pip:
pip install torchebm
For the latest development version:
git clone https://github.com/soran-ghaderi/torchebm.git
cd torchebm
pip install -e .
Examples and Tutorials
The library includes comprehensive examples for:
- Energy Functions: Working with analytical and neural energy functions
- Datasets: Generating and using synthetic datasets
- Samplers: Langevin Dynamics and Hamiltonian Monte Carlo tutorials
- Training EBMs: Complete training workflows for learning Gaussian mixtures
- Visualization: Creating energy landscape plots and training progression
Visit the Examples Section for detailed tutorials and code examples.
System Requirements
- Python: β₯ 3.8
- PyTorch: β₯ 1.9.0
- CUDA: β₯ 11.0 (optional, for GPU acceleration)
Contributing
TorchEBM is an open-source project welcoming contributions:
- Bug Reports & Feature Requests
- Contributing Guidelines
- Pull Requests
Citation
If you use TorchEBM in your research, please cite:
@misc{torchebm_library_2025,
  author = {Ghaderi, Soran and Contributors},
  title = {{TorchEBM}: A PyTorch Library for Training Energy-Based Models},
  year = {2025},
  url = {https://github.com/soran-ghaderi/torchebm},
}
License
TorchEBM is available under the MIT License. See the LICENSE file for details.
Documentation: soran-ghaderi.github.io/torchebm
Maintainer: Soran Ghaderi