TorchEBM 🍓: A PyTorch Framework for Energy-Based Modeling

Overview

TorchEBM is a high-performance PyTorch library that makes Energy-Based Models (EBMs) accessible and efficient for researchers and practitioners alike. The framework provides comprehensive components for 🔬 sampling, 🧠 inference, and 📊 model training.

Key Resources:

  • GitHub repository: https://github.com/soran-ghaderi/torchebm
  • Documentation: soran-ghaderi.github.io/torchebm

What are Energy-Based Models?

Energy-Based Models (EBMs) offer a powerful and flexible framework for generative modeling by assigning a scalar energy $E(x)$ to each data point. The energy acts as an unnormalized negative log-probability, so lower energy corresponds to higher probability. EBMs define a probability distribution as:

\[p(x) = \frac{e^{-E(x)}}{Z}\]

where $E(x)$ is the energy function and $Z = \int e^{-E(x)}\,dx$ is the partition function (normalizing constant). Since $Z$ is intractable for all but the simplest models, training and sampling rely on techniques that avoid computing it directly, such as MCMC sampling and score matching.
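As a concrete illustration of this definition, the snippet below evaluates an unnormalized density under a hand-rolled quadratic energy in plain PyTorch. It is independent of TorchEBM's API, and the function name is purely illustrative:

import torch

def quadratic_energy(x: torch.Tensor) -> torch.Tensor:
    # E(x) = 0.5 * ||x||^2: the energy of a standard Gaussian (up to Z)
    return 0.5 * (x ** 2).sum(dim=-1)

x = torch.randn(4, 2)                                    # batch of 4 points in 2D
unnormalized_density = torch.exp(-quadratic_energy(x))   # e^{-E(x)}, missing only 1/Z
print(unnormalized_density.shape)                        # torch.Size([4])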

Core Components

TorchEBM is structured around six key components, each designed for specific aspects of energy-based modeling:

1. Energy Functions

Custom energy functions are implemented by subclassing BaseEnergyFunction. The library includes both analytical and neural network-based energy functions:

Analytical Energy Functions

TorchEBM provides several built-in analytical energy landscapes for testing and research:

  • GaussianEnergy: $E(x) = \frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)$
  • DoubleWellEnergy: $E(x) = h \sum_{i=1}^n [(x_i^2 - 1)^2]$
  • RastriginEnergy: $E(x) = an + \sum_{i=1}^n [x_i^2 - a\cos(2\pi x_i)]$
  • RosenbrockEnergy: $E(x) = \sum_{i=1}^{n-1} [a(x_{i+1} - x_i^2)^2 + (x_i - 1)^2]$
  • AckleyEnergy: a complex, highly multi-modal energy landscape
  • HarmonicEnergy: a simple quadratic potential

from torchebm.core import GaussianEnergy, DoubleWellEnergy
import torch

# Gaussian energy function
energy_fn = GaussianEnergy(
    mean=torch.zeros(2),
    cov=torch.eye(2)
)

# Double well energy
double_well = DoubleWellEnergy(barrier_height=2.0)
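Energy functions operate on batches of points. Following the convention of the neural energy example later in this post (shape (B, D) in, shape (B,) out), evaluating one should look like the following; the exact return convention is an assumption here:

points = torch.randn(10, 2)     # batch of 10 two-dimensional points
energies = energy_fn(points)    # expected: tensor of 10 scalar energies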

2. Samplers

MCMC samplers for drawing samples from the distribution defined by an energy function:

Langevin Dynamics

Langevin dynamics samples by discretizing the overdamped Langevin stochastic differential equation: each step moves against the energy gradient and injects Gaussian noise,

\[x_{t+1} = x_t - \frac{\eta}{2} \nabla_x E(x_t) + \sqrt{\eta}\, \epsilon_t, \quad \epsilon_t \sim \mathcal{N}(0, I)\]

where $\eta$ is the step size:

from torchebm.samplers import LangevinDynamics

sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01,
    device=device
)

# Generate samples
initial_points = torch.randn(500, 2, device=device)
samples = sampler.sample(x=initial_points, n_steps=100)
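To make the update rule concrete, here is a minimal sketch of one Langevin iteration in plain PyTorch (not TorchEBM's internal implementation), using autograd to obtain $\nabla_x E(x)$:

def langevin_step(energy_fn, x, step_size):
    # Compute the energy gradient via autograd
    x = x.detach().requires_grad_(True)
    grad = torch.autograd.grad(energy_fn(x).sum(), x)[0]
    # Descend the energy and add Gaussian noise scaled by sqrt(step_size)
    noise = torch.randn_like(x) * (step_size ** 0.5)
    return (x - 0.5 * step_size * grad + noise).detach()

x = torch.randn(500, 2, device=device)
for _ in range(100):
    x = langevin_step(energy_fn, x, step_size=0.01)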

Hamiltonian Monte Carlo

HMC augments the state with auxiliary momentum variables and simulates Hamiltonian dynamics with leapfrog integration, which lets it take larger, less correlated steps than plain Langevin dynamics:

from torchebm.samplers import HamiltonianMonteCarlo

hmc_sampler = HamiltonianMonteCarlo(
    energy_function=energy_fn,
    step_size=0.1,
    n_leapfrog_steps=10,
    device=device
)
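Assuming HamiltonianMonteCarlo exposes the same sample interface as LangevinDynamics above (an unverified but natural assumption):

initial_points = torch.randn(500, 2, device=device)
hmc_samples = hmc_sampler.sample(x=initial_points, n_steps=100)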

3. Loss Functions

Comprehensive loss functions for EBM training:

Contrastive Divergence

  • ContrastiveDivergence: Standard CD algorithm
  • PersistentContrastiveDivergence: Persistent CD for better training
  • ParallelTemperingCD: Enhanced with parallel tempering
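All three variants estimate the gradient of the negative log-likelihood, which contrasts the energy of real data against the energy of model samples:

\[\nabla_\theta \mathcal{L} \approx \mathbb{E}_{x^+ \sim p_{\text{data}}}\left[\nabla_\theta E_\theta(x^+)\right] - \mathbb{E}_{x^- \sim p_\theta}\left[\nabla_\theta E_\theta(x^-)\right]\]

Minimizing this pushes the energy of data down and the energy of model samples up; the negatives $x^-$ come from running the sampler for $k$ MCMC steps (the k_steps parameter in the training example below).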

Score Matching

  • ScoreMatching: Standard score matching
  • SlicedScoreMatching: Scalable variant for high dimensions
  • DenoisingScoreMatching: Denoising score matching
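Score matching avoids the partition function entirely by fitting the model score $s_\theta(x) = -\nabla_x E_\theta(x)$ to the data score. The standard (Hyvärinen) objective is

\[J(\theta) = \mathbb{E}_{p_{\text{data}}}\left[\operatorname{tr}\left(\nabla_x s_\theta(x)\right) + \tfrac{1}{2}\lVert s_\theta(x)\rVert^2\right]\]

The sliced and denoising variants replace the expensive trace term with random projections and noise perturbations, respectively, which is what lets them scale to high dimensions.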

4. Datasets

Helper functions for generating synthetic datasets useful for testing and visualization:

from torchebm.datasets import (
    CheckerboardDataset, CircleDataset, EightGaussiansDataset,
    GaussianMixtureDataset, GridDataset, PinwheelDataset,
    SwissRollDataset, TwoMoonsDataset
)

# Create a Gaussian mixture dataset
dataset = GaussianMixtureDataset(n_samples=1000, n_components=4)
data = dataset.get_data()
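The returned data is a plain 2D tensor, so it can be inspected with standard tools. For example, a quick scatter plot with matplotlib (not a TorchEBM API; assumes get_data() returns a CPU tensor):

import matplotlib.pyplot as plt

data = GaussianMixtureDataset(n_samples=1000, n_components=4).get_data()
plt.scatter(data[:, 0], data[:, 1], s=2)
plt.title("Gaussian mixture (4 components)")
plt.show()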

5. Visualization

Tools for visualizing energy landscapes, sampling processes, and training progression:

from torchebm.utils import Visualization

# Visualize energy landscape and samples
Visualization.plot_energy_landscape(energy_fn, samples)

6. CUDA Acceleration

GPU implementations of key algorithms accelerate sampling and training on CUDA hardware. MCMC sampling is a natural fit: it repeatedly evaluates energies and gradients over large batches of chains, which parallelizes well on a GPU.
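A rough way to measure the speedup on your own hardware is to time the documented sampling call on CPU and GPU. This sketch reuses only the constructors shown elsewhere in this post; actual timings depend on your setup:

import time
import torch
from torchebm.core import GaussianEnergy
from torchebm.samplers import LangevinDynamics

def time_sampling(device: torch.device) -> float:
    energy_fn = GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2), device=device)
    sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01, device=device)
    x = torch.randn(10_000, 2, device=device)
    if device.type == "cuda":
        torch.cuda.synchronize()   # ensure pending GPU work doesn't skew the timer
    start = time.perf_counter()
    sampler.sample(x=x, n_steps=500)
    if device.type == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - start

print(f"CPU: {time_sampling(torch.device('cpu')):.2f}s")
if torch.cuda.is_available():
    print(f"GPU: {time_sampling(torch.device('cuda')):.2f}s")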

Quick Start Example

Here’s a complete example of creating and sampling from an energy model:

import torch
from torchebm.core import GaussianEnergy
from torchebm.samplers import LangevinDynamics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define an analytical energy function
energy_fn = GaussianEnergy(
    mean=torch.zeros(2), 
    cov=torch.eye(2), 
    device=device
)

# Define a sampler
sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01,
    device=device
)

# Generate samples
initial_points = torch.randn(500, 2, device=device)
samples = sampler.sample(x=initial_points, n_steps=100)

print(f"Output batch_shape: {samples.shape}")
# Output batch_shape: torch.Size([500, 2])

Training Energy-Based Models

Training EBMs typically involves adjusting the energy function’s parameters so that observed data points have lower energy than samples generated by the model. Here’s an example using Contrastive Divergence:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchebm.core import BaseEnergyFunction
from torchebm.samplers import LangevinDynamics
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset

# Define a neural energy function
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).squeeze(-1)

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
energy_model = MLPEnergy(input_dim=2).to(device)

# Sampler for negative sampling
sampler = LangevinDynamics(
    energy_function=energy_model,
    step_size=0.01,
    device=device
)

# Loss function
cd_loss = ContrastiveDivergence(
    energy_function=energy_model,
    sampler=sampler,
    k_steps=10
)

# Training setup
optimizer = optim.Adam(energy_model.parameters(), lr=1e-3)
dataset = GaussianMixtureDataset(n_samples=5000, n_components=4)
dataloader = DataLoader(dataset.get_data(), batch_size=64, shuffle=True)

# Training loop
for epoch in range(100):
    epoch_loss = 0.0
    for batch_data in dataloader:
        batch_data = batch_data.to(device)
        
        optimizer.zero_grad()
        loss = cd_loss(batch_data)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: Loss = {epoch_loss/len(dataloader):.6f}")

Training Visualization

The library's visualization utilities help in understanding model training. The figures below show the evolution of an MLP-based energy function trained on a 2D Gaussian mixture:

[Figures: learned energy landscape at training epochs 10, 30, and 100]

These visualizations demonstrate how the model learns regions of low energy (high probability density, warmer colors) corresponding to the data distribution (white points), while assigning higher energy elsewhere. Red points are samples generated from the EBM at each training stage.

Example Energy Landscapes

TorchEBM includes several analytical energy functions for testing and benchmarking:

[Figures: Gaussian, Double Well, Rastrigin, and Rosenbrock energy landscapes]

API Reference Structure

The library is organized into several main modules:

torchebm/
β”œβ”€β”€ core/                  # Base classes and energy functions
β”‚   β”œβ”€β”€ BaseEnergyFunction
β”‚   β”œβ”€β”€ GaussianEnergy, DoubleWellEnergy, etc.
β”‚   β”œβ”€β”€ BaseTrainer, BaseOptimizer
β”‚   └── BaseScheduler variants
β”œβ”€β”€ samplers/              # MCMC sampling algorithms
β”‚   β”œβ”€β”€ LangevinDynamics
β”‚   └── HamiltonianMonteCarlo
β”œβ”€β”€ losses/                # Training objectives
β”‚   β”œβ”€β”€ ContrastiveDivergence variants
β”‚   └── ScoreMatching variants
β”œβ”€β”€ datasets/              # Synthetic data generators
β”‚   └── Various 2D datasets
β”œβ”€β”€ models/                # Neural network architectures
└── utils/                 # Visualization and utilities

Installation

Install TorchEBM using pip:

pip install torchebm

For the latest development version:

git clone https://github.com/soran-ghaderi/torchebm.git
cd torchebm
pip install -e .

Examples and Tutorials

The library includes comprehensive examples for:

  • Energy Functions: Working with analytical and neural energy functions
  • Datasets: Generating and using synthetic datasets
  • Samplers: Langevin Dynamics and Hamiltonian Monte Carlo tutorials
  • Training EBMs: Complete training workflows for learning Gaussian mixtures
  • Visualization: Creating energy landscape plots and training progression

Visit the examples section of the documentation for detailed tutorials and code examples.

System Requirements

  • Python: β‰₯ 3.8
  • PyTorch: β‰₯ 1.9.0
  • CUDA: β‰₯ 11.0 (optional, for GPU acceleration)

Contributing

TorchEBM is an open-source project and welcomes contributions; bug reports, feature requests, and pull requests via the GitHub repository are all appreciated.

Citation

If you use TorchEBM in your research, please cite:

@misc{torchebm_library_2025,
  author       = {Ghaderi, Soran and Contributors},
  title        = {{TorchEBM}: A PyTorch Library for Training Energy-Based Models},
  year         = {2025},
  url          = {https://github.com/soran-ghaderi/torchebm},
}

License

TorchEBM is available under the MIT License. See the LICENSE file for details.


Documentation: soran-ghaderi.github.io/torchebm

Maintainer: Soran Ghaderi