
PyTorch Library for Generative Modeling

A high-performance PyTorch library that makes Energy-Based Models accessible and efficient for researchers and practitioners alike.


A PyTorch library for energy-based modeling, with support for flow and diffusion methods.

Training an energy-based model to capture a target distribution.

What is TorchEBM 🍓?

Energy-based models assign a scalar energy to each input, implicitly defining a probability distribution where lower energy means higher probability. This formulation is remarkably general. MCMC sampling, score matching, contrastive divergence, and even flow/diffusion-based generation all operate within or connect naturally to the energy-based framework.
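The relationship between energy and probability, p(x) ∝ exp(−E(x)), can be made concrete with a toy quadratic energy (an illustrative standalone sketch, not TorchEBM's API):

```python
import torch

def energy(x: torch.Tensor) -> torch.Tensor:
    # Toy quadratic energy: minimized at the origin.
    return 0.5 * (x ** 2).sum(dim=-1)

x = torch.tensor([[0.0, 0.0], [2.0, 2.0]])
# p(x) is proportional to exp(-E(x)), so lower energy means
# higher (unnormalized) log-probability.
log_p_unnorm = -energy(x)
```

Here the point at the origin has lower energy, hence higher unnormalized log-probability, than the point at (2, 2).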

TorchEBM gives you composable PyTorch building blocks that span this entire landscape. You can define energy functions, train models with different learning objectives, and generate samples via MCMC, energy minimization, or learned continuous-time dynamics.


In Action

Eight Gaussians distribution.
Circles distribution.

Equilibrium matching with different interpolants transforming noise into structured distributions.


Core Components

  • Core


    Base classes, energy models (analytical potentials and custom neural networks), schedulers, and the device/dtype management layer shared across all components.

    API Reference

  • Samplers


    Draw samples from energy landscapes via MCMC methods, gradient-based optimization, or learned flow/diffusion dynamics (ODE/SDE).

    API Reference

  • Loss Functions


    Training objectives for energy-based and flow-based models, including contrastive divergence variants, score matching variants, and equilibrium matching.

    API Reference

  • Interpolants


    Define how noise and data are mixed along a continuous time path. Used in flow matching, diffusion, and related generative schemes.

    API Reference

  • Integrators


    Numerical solvers for SDEs, ODEs, and Hamiltonian systems. Pluggable into samplers and flow-based generation pipelines.

    API Reference

  • Models


    Neural network architectures for parameterizing energy functions and velocity fields, including vision transformers and guidance wrappers.

    API Reference

  • Datasets


    Synthetic 2D distributions for rapid prototyping and visual evaluation. All are PyTorch Dataset objects.

    API Reference

  • CUDA


    CUDA-accelerated kernels and mixed precision support for performance-critical sampling and training.

    API Reference
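The idea behind contrastive divergence, one of the training objectives listed above, is to push energy down on data samples and up on samples drawn from the model. A minimal sketch with a toy energy (illustrative names, not TorchEBM's API):

```python
import torch

def contrastive_divergence_loss(energy_fn, x_data, x_model):
    # Lower the energy of real data, raise the energy of model samples.
    return energy_fn(x_data).mean() - energy_fn(x_model).mean()

# Toy quadratic energy, minimized at the origin.
energy = lambda x: 0.5 * (x ** 2).sum(dim=-1)

x_data = torch.zeros(8, 2)    # "data" sitting at the mode
x_model = torch.randn(8, 2)   # "model samples" scattered around it
loss = contrastive_divergence_loss(energy, x_data, x_model)
```

Minimizing this loss (with model samples produced by an MCMC sampler in practice) shapes the energy landscape so that data regions become low-energy.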
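The interpolant concept above can be sketched as a simple linear path between noise and data, as used in flow matching (a hypothetical helper, not necessarily TorchEBM's exact API):

```python
import torch

def linear_interpolant(x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # x_t = (1 - t) * x0 + t * x1: pure noise at t=0, pure data at t=1.
    t = t.view(-1, *([1] * (x0.dim() - 1)))  # broadcast t over non-batch dims
    return (1 - t) * x0 + t * x1

noise = torch.randn(16, 2)
data = torch.zeros(16, 2)
t = torch.rand(16)            # one time per batch element
x_t = linear_interpolant(noise, data, t)
```

Other interpolants change the schedule along the path (e.g. trigonometric or variance-preserving mixes), but all share this noise-to-data structure.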


Energy Landscapes

Gaussian, Double Well, Rastrigin, Rosenbrock, Ackley, Harmonic.

Synthetic Datasets

Gaussian Mixture, Eight Gaussians, Two Moons, Swiss Roll, Checkerboard, Pinwheel, Circle, Grid.

Quick Start

pip install torchebm

Define an energy model, create a sampler, and draw samples in a few lines:

import torch
from torchebm.core import GaussianModel
from torchebm.samplers import LangevinDynamics

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Analytical 2D standard-Gaussian energy model
model = GaussianModel(mean=torch.zeros(2), cov=torch.eye(2), device=device)

# Run 100 Langevin steps from 500 random starting points
sampler = LangevinDynamics(model=model, step_size=0.01, device=device)
samples = sampler.sample(x=torch.randn(500, 2, device=device), n_steps=100)
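Under the hood, Langevin dynamics repeatedly applies the update x ← x − (ε/2)∇E(x) + √ε ξ with ξ ~ N(0, I), i.e. an Euler–Maruyama discretization of the Langevin SDE. A standalone sketch with a toy quadratic energy, independent of TorchEBM's API:

```python
import torch

def energy(x: torch.Tensor) -> torch.Tensor:
    # Quadratic energy of a standard Gaussian (up to a constant).
    return 0.5 * (x ** 2).sum(dim=-1)

def langevin_step(x: torch.Tensor, step_size: float = 0.01) -> torch.Tensor:
    x = x.detach().requires_grad_(True)
    grad = torch.autograd.grad(energy(x).sum(), x)[0]  # ∇E(x)
    noise = torch.randn_like(x)
    # Gradient descent on the energy plus injected Gaussian noise.
    return (x - 0.5 * step_size * grad + (step_size ** 0.5) * noise).detach()

x = torch.randn(500, 2) * 5.0   # start far from the mode
for _ in range(1000):
    x = langevin_step(x)
# x is now approximately distributed as N(0, I)
```

The noise term is what distinguishes sampling from plain energy minimization: without it the chain would collapse to the nearest mode instead of exploring the distribution.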

See the tutorials and examples for training loops, flow-based generation, and more.

Community & Contribution

TorchEBM is open source and developed with the research community in mind. Contributions and feedback are welcome.


Citation

If TorchEBM is useful in your research, please cite it:

@misc{torchebm_library_2025,
  author       = {Ghaderi, Soran and Contributors},
  title        = {TorchEBM: A PyTorch Library for Training Energy-Based Models},
  year         = {2025},
  url          = {https://github.com/soran-ghaderi/torchebm},
}

License

MIT License. See the LICENSE file for details.