TorchEBM provides components for 🔬 sampling, 🧠 inference, and 📊 model training.
What is TorchEBM?¶
Energy-Based Models (EBMs) offer a powerful and flexible framework for generative modeling: each data point is assigned a scalar energy, which defines an unnormalized probability. Lower energy corresponds to higher probability.
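Formally, an energy function \( E_\theta(x) \) defines the model distribution

\[
p_\theta(x) = \frac{\exp\big(-E_\theta(x)\big)}{Z_\theta}, \qquad Z_\theta = \int \exp\big(-E_\theta(x)\big)\, dx,
\]

where the normalizing constant \( Z_\theta \) is typically intractable; this is why EBM training and sampling rely on the tools described below.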
TorchEBM simplifies working with EBMs in PyTorch. It provides a suite of tools designed for researchers and practitioners, enabling efficient implementation and exploration of:
- Defining complex energy functions: Easily create custom energy landscapes using PyTorch modules.
- Training: Loss functions and training procedures for EBM parameter estimation, including score matching and contrastive divergence variants.
- Sampling: Algorithms to draw samples from the learned distribution \( p(x) \).
Core Components¶
TorchEBM is structured around several key components:
- Energy Functions: Implement energy functions using `BaseEnergyFunction`. Includes predefined analytical functions (Gaussian, Double Well) and supports custom neural network architectures.
- Samplers: MCMC samplers such as Langevin Dynamics (`LangevinDynamics`), Hamiltonian Monte Carlo, and more, for generating samples from the energy distribution.
- Loss Functions: Comprehensive loss functions for EBM training, including Contrastive Divergence, Score Matching, and Noise Contrastive Estimation.
- Datasets: Helper functions to generate synthetic datasets (e.g., `make_gaussian_mixture`), useful for testing, debugging, and visualization.
- Visualization: Tools for visualizing energy landscapes, sampling processes, and training progression to better understand model behavior.
- Accelerated Computing: CUDA implementations of key algorithms for dramatically faster sampling and training on GPU hardware.
Quick Start¶
Install the library using pip:
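Assuming the package is published on PyPI under the name `torchebm`:

```
pip install torchebm
```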
Here's a minimal example of defining an energy function and a sampler:
Create and Sample from Energy Models

```python
import torch

from torchebm.core import GaussianEnergy
from torchebm.samplers import LangevinDynamics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define an (analytical) energy function -> next example: trainable
energy_fn = GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2), device=device)

# Define a sampler
sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01, device=device)

# Generate samples
initial_points = torch.randn(500, 2, device=device)
samples = sampler.sample(x=initial_points, n_steps=100)

print(f"Output shape: {samples.shape}")  # Output shape: torch.Size([500, 2])
```
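For intuition, Langevin dynamics draws samples by repeatedly taking a small gradient step downhill on the energy and adding Gaussian noise. A common form of the update (the exact discretization used by `LangevinDynamics` may differ) is

\[
x_{t+1} = x_t - \frac{\epsilon}{2}\, \nabla_x E(x_t) + \sqrt{\epsilon}\, \eta_t, \qquad \eta_t \sim \mathcal{N}(0, I).
\]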
Training and Visualization Example¶
Training EBMs typically involves adjusting the energy function's parameters so that observed data points have lower energy than samples generated by the model. Contrastive Divergence (CD) is a common approach.
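In sketch form, the CD gradient contrasts the energies of real data ("positive" samples) with those of model samples ("negative" samples) produced by a short MCMC chain:

\[
\nabla_\theta \mathcal{L}_{\text{CD}} \approx \mathbb{E}_{x^+ \sim p_{\text{data}}}\big[\nabla_\theta E_\theta(x^+)\big] - \mathbb{E}_{x^- \sim p_\theta}\big[\nabla_\theta E_\theta(x^-)\big],
\]

so optimization pushes energy down on the data and up on the model's own samples.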
Here's an example of setting up training using `ContrastiveDivergence` and `LangevinDynamics`:
Train an EBM

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchebm.core import BaseEnergyFunction
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset

# A trainable EBM
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)  # a scalar energy value per sample

energy_fn = MLPEnergy(input_dim=2).to(device)  # device from the previous example

cd_loss_fn = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,  # from the previous example
    n_steps=10,       # MCMC steps for generating negative samples
)

optimizer = optim.Adam(energy_fn.parameters(), lr=0.001)

mixture_dataset = GaussianMixtureDataset(
    n_samples=500, n_components=4, std=0.1, seed=123
).get_data()
dataloader = DataLoader(mixture_dataset, batch_size=32, shuffle=True)

# Training loop
EPOCHS = 10
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    for i, batch_data in enumerate(dataloader):
        batch_data = batch_data.to(device)

        optimizer.zero_grad()
        loss, neg_samples = cd_loss_fn(batch_data)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.6f}")
```
Visualizing the learned energy landscape during training can be insightful. The visualization below shows the evolution of an MLP-based energy function trained on a 2D Gaussian mixture:
Training Progression (Gaussian Mixture Example)

This visualization demonstrates how the model learns regions of low energy (high probability density, warmer colors) corresponding to the data distribution (white points), while assigning higher energy elsewhere. Red points are samples generated from the EBM at that training stage.

Latest Release
TorchEBM is currently in early development. Check our GitHub repository for the latest updates and features.
Example Analytical Energy Landscapes¶
Toy Examples
These are some of TorchEBM's built-in analytical toy energy landscapes, used for functionality and performance testing.
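As a rough sketch of how such a landscape can be rendered, the snippet below evaluates the `GaussianEnergy` function from the Quick Start on a 2D grid and draws a contour plot with matplotlib. This is plain matplotlib rather than a TorchEBM visualization API, and it assumes the energy function accepts a batch of points and defaults to CPU when no `device` is given:

```python
import torch
import matplotlib.pyplot as plt

from torchebm.core import GaussianEnergy

# Analytical energy function from the Quick Start example
# (assumed to default to CPU when no device is passed)
energy_fn = GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2))

# Evaluate the energy on a 2D grid of points
xs = torch.linspace(-3, 3, 200)
grid_x, grid_y = torch.meshgrid(xs, xs, indexing="ij")
points = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=1)

with torch.no_grad():
    energies = energy_fn(points).reshape(200, 200)

# Contour plot of the energy landscape
plt.contourf(grid_x.numpy(), grid_y.numpy(), energies.numpy(), levels=50, cmap="viridis")
plt.colorbar(label="Energy")
plt.title("Gaussian energy landscape")
plt.show()
```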
Community & Contribution¶
TorchEBM is an open-source project developed with the research community in mind.
- Bug Reports & Feature Requests: Please use the GitHub Issues.
- Contributing Code: We welcome contributions! Please see the Contributing Guidelines. Consider following the Commit Conventions.
- Show Support: If you find TorchEBM helpful for your work, consider starring the repository on GitHub!
Citation¶
Please consider citing the TorchEBM repository if it contributes to your research:
```bibtex
@misc{torchebm_library_2025,
  author = {Ghaderi, Soran and Contributors},
  title  = {TorchEBM: A PyTorch Library for Training Energy-Based Models},
  year   = {2025},
  url    = {https://github.com/soran-ghaderi/torchebm},
}
```
License¶
TorchEBM is available under the MIT License. See the LICENSE file for details.