TorchEBM

⚡ Energy-Based Modeling library for PyTorch, offering tools for 🔬 sampling, 🧠 inference, and 📊 learning in complex distributions.

  • Getting Started: Start using TorchEBM in minutes with our quick installation and setup guide.

  • Introduction: Learn about Energy-Based Models and how TorchEBM can help you work with them.

  • Guides: Explore in-depth guides for energy functions, samplers, and more.

  • Examples: Practical examples to help you apply TorchEBM to your projects.

  • Blog: Stay updated with the latest news, tutorials, and insights about TorchEBM.

Quick Installation

pip install torchebm

Example Analytical Energy Landscapes

Toy Examples

These are some of TorchEBM's built-in analytical toy energy landscapes, intended for functionality and performance testing.

Gaussian Energy

\(E(x) = \frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu)\)

from torchebm.core import GaussianEnergy
import torch

# standard 2D Gaussian: zero mean, identity covariance
energy_fn = GaussianEnergy(
    mean=torch.zeros(2),
    cov=torch.eye(2)
)
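
An energy function is just a mapping from points to scalar energies, so it can be probed directly. Below is a minimal sketch, assuming (as is common for such APIs, though not shown in the snippet above) that the instance is callable on a batch of points and returns one energy value per row:

import torch
from torchebm.core import GaussianEnergy

energy_fn = GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2))

# assumption: calling the energy function on a (batch_size, dim) tensor
# returns a (batch_size,) tensor of per-point energies
x = torch.randn(8, 2)   # batch of 8 two-dimensional points
energies = energy_fn(x)
print(energies)         # low values near the mean, higher farther away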

Double Well Energy

\(E(x) = a(x^2 - b)^2\)

from torchebm.core import DoubleWellEnergy

# two wells at x = ±sqrt(b); a scales the barrier between them
energy_fn = DoubleWellEnergy(
    a=1.0,
    b=2.0
)
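
Gradient-based samplers such as Langevin dynamics need \(\nabla_x E(x)\). Because these landscapes are ordinary PyTorch computations, autograd can supply it; the sketch below is hedged on the assumption that the energy function returns a differentiable per-point energy:

import torch
from torchebm.core import DoubleWellEnergy

energy_fn = DoubleWellEnergy(a=1.0, b=2.0)

# assumption: energy_fn(x) is differentiable with respect to x
x = torch.tensor([[1.5]], requires_grad=True)
energy = energy_fn(x).sum()
(grad,) = torch.autograd.grad(energy, x)
print(grad)  # gradient of E at x = 1.5; it vanishes at the wells x = ±sqrt(2)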

Rastrigin Energy

\(E(x) = An + \sum_{i=1}^n \left[ x_i^2 - A\cos(2\pi x_i) \right]\)

from torchebm.core import RastriginEnergy

# A controls the amplitude of the oscillatory term (A = 10 is conventional)
energy_fn = RastriginEnergy(
    A=10.0
)

Rosenbrock Energy

\(E(x) = \sum_{i=1}^{n-1} \left[ b(x_{i+1} - x_i^2)^2 + (a - x_i)^2 \right]\)

from torchebm.core import RosenbrockEnergy

# standard parameters; the global minimum sits at (1, 1) in 2D
energy_fn = RosenbrockEnergy(
    a=1.0,
    b=100.0
)
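
Since each landscape maps points to scalar energies, it can be inspected on a dense grid, which is handy for the functionality testing these toys are meant for. A sketch using matplotlib (not part of TorchEBM), again assuming a (batch, 2) to (batch,) call signature:

import torch
import matplotlib.pyplot as plt
from torchebm.core import RosenbrockEnergy

energy_fn = RosenbrockEnergy(a=1.0, b=100.0)

# evaluate the energy on a 200 x 200 grid over the region of interest
xs = torch.linspace(-2.0, 2.0, 200)
ys = torch.linspace(-1.0, 3.0, 200)
gx, gy = torch.meshgrid(xs, ys, indexing="xy")
points = torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=1)
energies = energy_fn(points).reshape(200, 200)

# log scale makes the narrow curved valley visible
plt.contourf(gx.numpy(), gy.numpy(), torch.log1p(energies).numpy(), levels=50)
plt.title("Rosenbrock energy (log scale)")
plt.show()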

Quick Example

  • Create and Sample from Energy Models


    import torch
    from torchebm.core import GaussianEnergy
    from torchebm.samplers.langevin_dynamics import LangevinDynamics
    
    # Create an energy function
    energy_fn = GaussianEnergy(
        mean=torch.zeros(2),
        cov=torch.eye(2)
    )
    
    # Create a sampler
    sampler = LangevinDynamics(
        energy_function=energy_fn,
        step_size=0.01
    )
    
    # Generate samples: run 1000 parallel chains for 100 steps in 2 dimensions
    samples = sampler.sample_chain(
        dim=2, n_steps=100, n_samples=1000
    )
    

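Because the target above is a standard 2D Gaussian, the output can be sanity-checked against the known moments. A hedged follow-up to the example, assuming sample_chain returns the final states as an (n_samples, dim) tensor:

# continues the example above; assumes samples has shape (n_samples, dim)
print(samples.mean(dim=0))  # should approach [0., 0.] as the chains mix
print(samples.std(dim=0))   # should approach [1., 1.] for the unit Gaussian
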
Latest Release

TorchEBM is currently in early development. Check our GitHub repository for the latest updates and features.

Features & Roadmap

Our goal is to create a comprehensive library for energy-based modeling in PyTorch.

Status indicators:

  • ✅ - Completed
  • 🚧 - Work in progress
  • ⚠ - Needs improvement
  • ✨ - Planned feature

Core Infrastructure

  • CUDA-accelerated implementations ✅
  • Seamless integration with PyTorch ✅
  • Energy function base class ✅
  • Sampler base class ✅

Energy Functions

  • Gaussian ✅
  • Double well ✅
  • Rastrigin ✅
  • Rosenbrock ✅
  • Ackley ✅

Samplers

  • Langevin Dynamics ✅
  • Hamiltonian Monte Carlo (HMC) 🚧
  • Metropolis-Hastings ⚠
  • Denoising Diffusion Probabilistic Models (DDPM) ✨
  • Denoising Diffusion Implicit Models (DDIM) ✨
  • Generalized Gaussian Diffusion Models (GGDM) ✨
  • Differentiable Diffusion Sampler Search (DDSS) ✨
  • Euler Method ✨
  • Heun's Method ✨
  • PLMS (Pseudo Linear Multistep) ✨
  • DPM (Diffusion Probabilistic Models) ✨
  • Gibbs Sampling ✨
  • No-U-Turn Sampler (NUTS) ✨
  • Slice Sampling ✨
  • Reversible Jump MCMC ✨
  • Particle Filters (Sequential Monte Carlo) ✨
  • Adaptive Metropolis ✨
  • Parallel Tempering (Replica Exchange) ✨
  • Stochastic Gradient Langevin Dynamics (SGLD) ✨
  • Stein Variational Gradient Descent (SVGD) ✨
  • Metropolis-Adjusted Langevin Algorithm (MALA) ✨
  • Unadjusted Langevin Algorithm (ULA) ✨
  • Bouncy Particle Sampler ✨
  • Zigzag Sampler ✨
  • Annealed Importance Sampling (AIS) ✨
  • Sequential Monte Carlo (SMC) Samplers ✨
  • Elliptical Slice Sampling ✨

Loss Functions & Training

  • Contrastive Divergence Methods 🚧
  • Contrastive Divergence (CD-k) 🚧
  • Persistent Contrastive Divergence (PCD) ✨
  • Fast Persistent Contrastive Divergence (FPCD) ✨
  • Parallel Tempering Contrastive Divergence (PTCD) ✨
  • Score Matching Techniques ✨
  • Standard Score Matching ✨
  • Denoising Score Matching ✨
  • Sliced Score Matching ✨
  • Maximum Likelihood Estimation (MLE) ✨
  • Margin Loss ✨
  • Noise Contrastive Estimation (NCE) ✨
  • Ratio Matching ✨
  • Minimum Probability Flow ✨
  • Adversarial Training Loss ✨
  • Kullback-Leibler (KL) Divergence Loss ✨
  • Fisher Divergence ✨
  • Hinge Embedding Loss ✨
  • Cross-Entropy Loss (for discrete outputs) ✨
  • Mean Squared Error (MSE) Loss (for continuous outputs) ✨
  • Improved Contrastive Divergence Loss ✨

Utilities

  • Testing Framework 🚧
  • Visualization Tools 🚧
  • Performance Benchmarking ✨
  • Neural Network Integration ✨
  • Hyperparameter Optimization ✨
  • Distribution Diagnostics ✨

License

TorchEBM is released under the MIT License, a permissive license that allows reuse with few restrictions.

Contributing

We welcome contributions! If you're interested in improving TorchEBM or adding new features, please check our contributing guidelines.

Our project follows specific commit message conventions to maintain a clear project history and generate meaningful changelogs.