TorchEBM

PyTorch Toolkit for Generative Modeling

A high-performance PyTorch library that makes Energy-Based Models accessible and efficient for researchers and practitioners alike.


TorchEBM provides components for 🔬 sampling, 🧠 inference, and 📊 model training.


What is TorchEBM?

Energy-Based Models (EBMs) offer a powerful and flexible framework for generative modeling. Each data point is assigned a scalar energy that defines an unnormalized probability: lower energy corresponds to higher probability.
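
Concretely, an energy function \(E_\theta\) defines a density in the standard EBM form

\[
p_\theta(x) = \frac{e^{-E_\theta(x)}}{Z(\theta)}, \qquad Z(\theta) = \int e^{-E_\theta(x)}\,dx,
\]

where the normalizing constant \(Z(\theta)\) is generally intractable. This is why MCMC samplers, and losses that avoid computing \(Z(\theta)\) such as contrastive divergence and score matching, are central to working with EBMs.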

TorchEBM simplifies working with EBMs in PyTorch. It provides a suite of tools designed for researchers and practitioners, enabling efficient implementation and exploration of:

  • Defining complex energy functions: Easily create custom energy landscapes using PyTorch modules.
  • Training: Loss functions and procedures for EBM parameter estimation, including score matching and contrastive divergence variants (the score matching objective is sketched just after this list).
  • Sampling: Algorithms to draw samples from the learned distribution \( p(x) \).
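
For reference, the score matching objective mentioned above (in its standard Hyvärinen form) sidesteps the intractable normalizer entirely, since it depends only on derivatives of the energy:

\[
\mathcal{L}_{\mathrm{SM}}(\theta) = \mathbb{E}_{p_{\text{data}}}\!\left[ \tfrac{1}{2}\,\lVert \nabla_x E_\theta(x) \rVert^2 - \Delta_x E_\theta(x) \right] + \text{const}
\]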

Core Components

TorchEBM is structured around several key components:

  • Energy Functions


    Implement custom energy functions by subclassing BaseEnergyFunction. Includes predefined analytical functions (Gaussian, Double Well) and supports custom neural network architectures; a sketch follows this list.


  • Samplers


    MCMC samplers like Langevin Dynamics (LangevinDynamics), Hamiltonian Monte Carlo, and more are provided for generating samples from the energy distribution.


  • Loss Functions


    Comprehensive loss functions for EBM training, including Contrastive Divergence, Score Matching, and Noise Contrastive Estimation.


  • Datasets


    Helper functions to generate synthetic datasets (e.g., make_gaussian_mixture) for testing, debugging, and visualization.


  • Visualization


    Tools for visualizing energy landscapes, sampling processes, and training progression to better understand model behavior.


  • Accelerated Computing


    CUDA implementations of key algorithms for dramatically faster sampling and training on GPU hardware.

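A minimal sketch of a custom energy function. RingEnergy is a hypothetical example; the pattern mirrors the MLPEnergy example later on this page, assuming only that subclasses implement forward returning one scalar energy per sample:

import torch
from torchebm.core import BaseEnergyFunction

class RingEnergy(BaseEnergyFunction):
    """Toy energy: low along a circle of a given radius, higher elsewhere."""

    def __init__(self, radius: float = 2.0):
        super().__init__()
        self.radius = radius

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One scalar energy per sample: squared distance from the ring.
        return (x.norm(dim=-1) - self.radius) ** 2

energy_fn = RingEnergy()
print(energy_fn(torch.randn(8, 2)).shape)  # torch.Size([8])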


Quick Start

Install the library using pip:

pip install torchebm

Here's a minimal example of defining an energy function and a sampler:

  • Create and Sample from Energy Models


    import torch
    from torchebm.core import GaussianEnergy
    from torchebm.samplers import LangevinDynamics
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Define an analytical energy function (the next example uses a trainable one)
    energy_fn = GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2), device=device)
    
    # Define a sampler
    sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01, device=device)
    
    # Generate samples
    initial_points = torch.randn(500, 2, device=device)
    samples = sampler.sample(x=initial_points, n_steps=100)
    
    print(f"Output shape: {samples.shape}")
    # Output shape: torch.Size([500, 2])
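    # Sanity check (a quick sketch): analytical energies such as GaussianEnergy
    # are callable on a batch and return one scalar energy per sample.
    with torch.no_grad():
        print(f"Mean sample energy: {energy_fn(samples).mean().item():.4f}")

Under the hood, Langevin dynamics iterates a noisy gradient descent on the energy. One common discretization (TorchEBM's exact convention may differ) is

\[
x_{k+1} = x_k - \eta \, \nabla_x E(x_k) + \sqrt{2\eta}\, \xi_k, \qquad \xi_k \sim \mathcal{N}(0, I),
\]

where \(\eta\) is the step_size passed to the sampler above.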
    

Training and Visualization Example

Training EBMs typically involves adjusting the energy function's parameters so that observed data points have lower energy than samples generated by the model. Contrastive Divergence (CD) is a common approach.
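
In symbols, CD approximates the maximum-likelihood gradient as (standard formulation):

\[
\nabla_\theta \mathcal{L}_{\mathrm{CD}} \approx \mathbb{E}_{x^{+} \sim p_{\text{data}}}\!\left[\nabla_\theta E_\theta(x^{+})\right] - \mathbb{E}_{x^{-} \sim p_\theta}\!\left[\nabla_\theta E_\theta(x^{-})\right],
\]

where the negative samples \(x^{-}\) come from a short MCMC chain (here, Langevin dynamics) rather than a fully converged one.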

Here's an example of setting up training using ContrastiveDivergence and LangevinDynamics:

  • Train an EBM


    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader
    
    from torchebm.core import BaseEnergyFunction
    from torchebm.losses import ContrastiveDivergence
    from torchebm.datasets import GaussianMixtureDataset
    
    # A trainable EBM
    class MLPEnergy(BaseEnergyFunction):
        def __init__(self, input_dim, hidden_dim=64):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, 1),
            )
    
        def forward(self, x):
            return self.net(x).squeeze(-1)  # one scalar energy per sample
    
    energy_fn = MLPEnergy(input_dim=2).to(device)  # device from the previous example
    
    cd_loss_fn = ContrastiveDivergence(
        energy_function=energy_fn,
        sampler=sampler,  # from the previous example
        n_steps=10,       # MCMC steps for generating negative samples
    )
    
    optimizer = optim.Adam(energy_fn.parameters(), lr=0.001)
    
    mixture_dataset = GaussianMixtureDataset(
        n_samples=500, n_components=4, std=0.1, seed=123
    ).get_data()
    dataloader = DataLoader(mixture_dataset, batch_size=32, shuffle=True)
    
    # Training loop
    EPOCHS = 10
    for epoch in range(EPOCHS):
        epoch_loss = 0.0
        for batch_data in dataloader:
            batch_data = batch_data.to(device)
    
            optimizer.zero_grad()
            loss, neg_samples = cd_loss_fn(batch_data)  # loss and negative samples
            loss.backward()
            optimizer.step()
    
            epoch_loss += loss.item()
    
        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.6f}")
    

Visualizing the learned energy landscape during training can be insightful. The figures below show the evolution of an MLP-based energy function trained on a 2D Gaussian mixture:

Training Progression (Gaussian Mixture Example)

Epoch 10: early stage, the model starts identifying modes.
Epoch 20: energy landscape refinement.
Epoch 30: modes become more distinct.
Epoch 100: learned landscape matching the target distribution.

This visualization demonstrates how the model learns regions of low energy (high probability density, warmer colors) corresponding to the data distribution (white points), while assigning higher energy elsewhere. Red points are samples generated from the EBM at each training stage.

See the Full Training Example for the complete script.
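
To produce similar plots yourself, here is a minimal sketch (assuming the trained energy_fn and device from the training example above; grid ranges and resolution are illustrative):

import matplotlib.pyplot as plt
import torch

# Evaluate the energy on a 2D grid and plot exp(-E), the unnormalized density.
xs = torch.linspace(-3, 3, 200)
grid_x, grid_y = torch.meshgrid(xs, xs, indexing="ij")
grid = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).to(device)

with torch.no_grad():
    energies = energy_fn(grid).reshape(200, 200).cpu()

plt.contourf(grid_x, grid_y, torch.exp(-energies), levels=50, cmap="viridis")
plt.colorbar(label="unnormalized density exp(-E(x))")
plt.title("Learned energy landscape")
plt.show()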


Latest Release

TorchEBM is currently in early development. Check our GitHub repository for the latest updates and features.

Example Analytical Energy Landscapes

Toy Examples

These are some of TorchEBM's built-in analytical toy energy landscapes, useful for functionality and performance testing.

Gaussian Energy

\(E(x) = \frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\)

from torchebm.core import GaussianEnergy
import torch

energy_fn = GaussianEnergy(
    mean=torch.zeros(2),
    cov=torch.eye(2)
)

Double Well Energy

\(E(x) = h \sum_{i=1}^n \left[(x_i^2 - 1)^2\right]\)

from torchebm.core import DoubleWellEnergy

energy_fn = DoubleWellEnergy(
    barrier_height=2.0
)

Rastrigin Energy

\(E(x) = an + \sum_{i=1}^n \left[ x_i^2 - a\cos(2\pi x_i) \right]\)

from torchebm.core import RastriginEnergy

energy_fn = RastriginEnergy(
    a=10.0
)

Rosenbrock Energy

\(E(x) = \sum_{i=1}^{n-1} \left[ b(x_{i+1} - x_i^2)^2 + (a - x_i)^2 \right]\)

from torchebm.core import RosenbrockEnergy

energy_fn = RosenbrockEnergy(
    a=1.0, 
    b=100.0
)
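
These landscapes plug directly into the samplers. A quick sketch of the intended testing workflow (parameter values are illustrative):

import torch
from torchebm.core import DoubleWellEnergy
from torchebm.samplers import LangevinDynamics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sample from a known analytical landscape and inspect where the chains settle.
energy_fn = DoubleWellEnergy(barrier_height=2.0)
sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01, device=device)

samples = sampler.sample(x=torch.randn(1000, 2, device=device), n_steps=500)
# With a sufficient barrier height, draws should cluster near the wells x_i = ±1.
print(samples.mean(dim=0), samples.std(dim=0))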

Community & Contribution

TorchEBM is an open-source project developed with the research community in mind. Contributions, bug reports, and feature requests are welcome on the GitHub repository.


Citation

Please consider citing the TorchEBM repository if it contributes to your research:

@misc{torchebm_library_2025,
  author       = {Ghaderi, Soran and Contributors},
  title        = {TorchEBM: A PyTorch Library for Training Energy-Based Models},
  year         = {2025},
  url          = {https://github.com/soran-ghaderi/torchebm},
}

License

TorchEBM is available under the MIT License. See the LICENSE file for details.