Soran Ghaderi

Looking for PhD & Research Opportunities - MSc in AI at the University of Essex

TorchEBM



⚡ Energy-Based Modeling library for PyTorch, offering tools for 🔬 sampling, 🧠 inference, and 📊 learning in complex distributions.


About

TorchEBM is a CUDA-accelerated parallel library for Energy-Based Models (EBMs) built on PyTorch. It provides efficient implementations of sampling, inference, and learning algorithms for EBMs, with a focus on scalability and performance. This is an early version and is under development.

Installation

$ pip install torchebm

Quick Start

import torch
from torchebm import EnergyFunction, LangevinDynamics
import matplotlib.pyplot as plt

# Define your energy function as below. The gradient is computed automatically, so you don't need to implement it; it is included here only for the sake of the example.
class QuadraticEnergy(EnergyFunction):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * torch.sum(x**2, dim=-1)
    
    def gradient(self, x: torch.Tensor) -> torch.Tensor:
        return x

# Instantiate the energy function and the sampler
energy_fn = QuadraticEnergy()
sampler = LangevinDynamics(energy_fn, step_size=0.1, noise_scale=0.1)

# Generate samples
initial_state = torch.tensor([2.0, 2.0])
samples = sampler.sample_chain(initial_state, n_steps=1000, n_samples=500)

# A single trajectory
trajectory = sampler.sample(initial_state, n_steps=1000, return_trajectory=True)

# Demonstrate parallel sampling
n_chains = 10
initial_states = torch.randn(n_chains, 2) * 2
parallel_samples = sampler.sample_parallel(initial_states, n_steps=1000)
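
To make the Quick Start more concrete, here is a minimal sketch in plain PyTorch of what a Langevin sampler does with an energy function. It is not TorchEBM's implementation: the helper names (`quadratic_energy`, `autograd_gradient`, `langevin_step`) are hypothetical, and the way `step_size` and `noise_scale` enter the update is an assumption that simply mirrors the constructor arguments used above.

```python
import torch


def quadratic_energy(x: torch.Tensor) -> torch.Tensor:
    # Same energy as QuadraticEnergy.forward above: E(x) = 0.5 * ||x||^2
    return 0.5 * torch.sum(x**2, dim=-1)


def autograd_gradient(energy, x: torch.Tensor) -> torch.Tensor:
    # Differentiate the summed energy so every row of x gets its own gradient.
    x = x.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(energy(x).sum(), x)
    return grad


def langevin_step(energy, x, step_size=0.1, noise_scale=0.1):
    # One unadjusted Langevin update (assumed form):
    # a gradient-descent step on the energy plus Gaussian noise.
    grad = autograd_gradient(energy, x)
    return x - step_size * grad + noise_scale * torch.randn_like(x)


torch.manual_seed(0)
x = torch.randn(10, 2) * 2  # 10 parallel chains in 2 dimensions
for _ in range(1000):
    x = langevin_step(quadratic_energy, x)
```

For the quadratic energy, `autograd_gradient` returns `x` itself, which is why the hand-written `gradient` method in the Quick Start is optional.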

Example output

For the visualization code, see the examples directory.

torchebm.png

Features (so far)

  • Samplers
    • Langevin dynamics
    • Hamiltonian Monte Carlo
  • Losses
    • Contrastive Divergence
    • Persistent Contrastive Divergence
    • Parallel Tempering CD
    • Score matching
  • Mathematical energy functions
    • Double-Well Energy
    • Gaussian Energy
    • Harmonic Energy
    • Rosenbrock Energy
    • Ackley Energy
    • Rastrigin Energy
  • CUDA-accelerated implementations
  • Seamless integration with PyTorch
  • Standard and scalable architecture and API
  • Examples
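
As an illustration of the mathematical energy functions listed above, the following sketch writes four of them as plain PyTorch functions using their standard textbook forms. The function names, parameter names, and default values are assumptions for illustration and may differ from the classes TorchEBM ships.

```python
import math

import torch


def double_well(x, barrier_height=2.0):
    # E(x) = h * sum_i (x_i^2 - 1)^2, with minima at x_i = +1 and x_i = -1
    return barrier_height * torch.sum((x**2 - 1.0) ** 2, dim=-1)


def rosenbrock(x, a=1.0, b=100.0):
    # Sum over consecutive coordinate pairs; with a = 1 the minimum is at all ones
    head, tail = x[..., :-1], x[..., 1:]
    return torch.sum(b * (tail - head**2) ** 2 + (a - head) ** 2, dim=-1)


def rastrigin(x, A=10.0):
    # Highly multimodal; global minimum of 0 at the origin
    n = x.shape[-1]
    return A * n + torch.sum(x**2 - A * torch.cos(2 * math.pi * x), dim=-1)


def ackley(x, a=20.0, b=0.2, c=2 * math.pi):
    # Nearly flat outer region with a deep basin at the origin
    n = x.shape[-1]
    term1 = -a * torch.exp(-b * torch.sqrt(torch.sum(x**2, dim=-1) / n))
    term2 = -torch.exp(torch.sum(torch.cos(c * x), dim=-1) / n)
    return term1 + term2 + a + math.e


batch = torch.randn(5, 2)
energies = rastrigin(batch)  # one energy value per row, shape (5,)
```

Each function reduces over the last dimension, so a batch of shape `(n, d)` yields `n` energies, matching the convention of `QuadraticEnergy` in the Quick Start.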

Roadmap

A list of samplers we aim to implement. The list is subject to change as we progress and learn which samplers matter most.

  • Denoising Diffusion Probabilistic Models (DDPM): A foundational sampler that uses explicit probabilistic models to remove noise from images.
  • Denoising Diffusion Implicit Models (DDIM): An implicit model that offers faster sampling with minimal loss of quality compared to DDPM.
  • Generalized Gaussian Diffusion Models (GGDM): A family of flexible non-Markovian samplers for diffusion models that optimize sample quality through gradient descent.
  • Differentiable Diffusion Sampler Search (DDSS): A method for optimizing fast samplers for pre-trained diffusion models by differentiating through sample quality scores.
  • Euler Method: A first-order numerical approach for sampling in diffusion processes, providing a basic framework for iterative refinement.
  • Heun’s Method: A second-order numerical method that improves accuracy over Euler’s method by averaging the slope at the current point with the slope at an Euler-predicted next point.
  • PLMS (Pseudo Linear Multistep): A multistep method that balances speed and quality in the sampling process by reusing model evaluations from previous steps.
  • DPM (Diffusion Probabilistic Models): A family of probabilistic models that utilize diffusion processes for generating high-quality samples.
  • Metropolis-Hastings #7: Classic MCMC algorithm for sampling from probability distributions.
  • Gibbs Sampling: MCMC algorithm for obtaining a sequence of observations from a specified multivariate probability distribution.
  • No-U-Turn Sampler (NUTS): Extension of HMC that eliminates the need to set the number of leapfrog steps.
  • Slice Sampling: MCMC method that samples from a distribution by drawing uniformly from the region under the plot of its density function.
  • Reversible Jump MCMC: Variant of MCMC for sampling from spaces of varying dimensions.
  • Particle Filters (Sequential Monte Carlo): Method for implementing recursive Bayesian filters using Monte Carlo simulations.
  • Adaptive Metropolis: Metropolis-Hastings algorithm with an adaptive proposal distribution.
  • Parallel Tempering (Replica Exchange): Method to improve the mixing of MCMC by simulating multiple chains at different temperatures.
  • Stochastic Gradient Langevin Dynamics (SGLD): Combines stochastic gradient descent with Langevin dynamics for large-scale Bayesian learning.
  • Stein Variational Gradient Descent (SVGD): Deterministic sampling method that combines the advantages of MCMC and variational inference.
  • Metropolis-Adjusted Langevin Algorithm (MALA): Combines Langevin dynamics with a Metropolis-Hastings correction step.
  • Unadjusted Langevin Algorithm (ULA): Discretization of the Langevin diffusion without a Metropolis-Hastings step.
  • Bouncy Particle Sampler: Continuous-time MCMC method based on piecewise deterministic Markov processes.
  • Zigzag Sampler: Another continuous-time MCMC method using piecewise deterministic Markov processes.
  • Annealed Importance Sampling (AIS): Combines importance sampling and MCMC to efficiently sample from complex distributions.
  • Sequential Monte Carlo (SMC) Samplers: Generalizes particle filters to static problems.
  • Elliptical Slice Sampling: Auxiliary variable slice sampling method for models with multivariate Gaussian priors.
  • Affine Invariant Ensemble Sampler (Emcee): Ensemble sampling method that adapts to the covariance structure of the target distribution.
  • Riemann Manifold Hamiltonian Monte Carlo: HMC on a Riemann manifold that automatically adapts to the local geometry of the target distribution.
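
To give a sense of what one of the planned MCMC samplers involves, here is a minimal sketch of random-walk Metropolis-Hastings over an energy function, written in plain PyTorch. The function name `metropolis_hastings`, its signature, and the Gaussian proposal are assumptions for illustration and do not describe a planned TorchEBM API.

```python
import torch


def metropolis_hastings(energy, x0, n_steps=1000, proposal_std=0.5):
    # Random-walk Metropolis-Hastings targeting p(x) proportional to exp(-E(x)).
    # Each row of x0 is an independent chain.
    x = x0.clone()
    e_x = energy(x)
    n_accepted = torch.zeros(x.shape[0], device=x.device)
    for _ in range(n_steps):
        proposal = x + proposal_std * torch.randn_like(x)
        e_prop = energy(proposal)
        # The proposal is symmetric, so the acceptance probability is
        # min(1, exp(E(x) - E(x'))); compare in log space for stability.
        log_u = torch.rand(x.shape[0], device=x.device).log()
        accept = log_u < (e_x - e_prop)
        x = torch.where(accept.unsqueeze(-1), proposal, x)
        e_x = torch.where(accept, e_prop, e_x)
        n_accepted += accept.float()
    return x, n_accepted / n_steps


def gaussian_energy(x):
    # Standard normal target: E(x) = 0.5 * ||x||^2
    return 0.5 * torch.sum(x**2, dim=-1)


torch.manual_seed(0)
samples, acceptance_rate = metropolis_hastings(
    gaussian_energy, torch.zeros(2000, 2), n_steps=200, proposal_std=1.0
)
```

Unlike the Langevin update, this sampler needs only energy evaluations, not gradients, at the cost of slower mixing in high dimensions.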

We’ll update losses as well.
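
For readers unfamiliar with the losses already listed under Features, Contrastive Divergence can be sketched as follows in plain PyTorch. This is an illustration under stated assumptions, not TorchEBM's loss class: `MLPEnergy`, `langevin_negatives`, and `contrastive_divergence_loss` are hypothetical names, and negatives are drawn with a few Langevin steps started from the data batch.

```python
import torch
from torch import nn


class MLPEnergy(nn.Module):
    # Hypothetical learnable energy: a small MLP mapping x to a scalar per sample.
    def __init__(self, dim=2, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, 1)
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)


def langevin_negatives(energy, x, k=10, step_size=0.1, noise_scale=0.1):
    # Run k Langevin steps starting from the data to obtain negative samples.
    x = x.detach()
    for _ in range(k):
        x.requires_grad_(True)
        (grad,) = torch.autograd.grad(energy(x).sum(), x)
        x = (x - step_size * grad + noise_scale * torch.randn_like(x)).detach()
    return x


def contrastive_divergence_loss(energy, data, k=10):
    # Lower the energy of data and raise the energy of sampled negatives.
    negatives = langevin_negatives(energy, data, k=k)
    return energy(data).mean() - energy(negatives).mean()


torch.manual_seed(0)
energy = MLPEnergy()
optimizer = torch.optim.Adam(energy.parameters(), lr=1e-3)
data = torch.randn(128, 2)  # stand-in for a real training batch

loss = contrastive_divergence_loss(energy, data)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Persistent Contrastive Divergence differs mainly in where the chains start: instead of restarting from the data each step, it keeps a buffer of negatives across training iterations.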

Contributing

Contributions are welcome! Please check the issues page for current tasks or create a new issue to discuss proposed changes.

License

This project is licensed under the MIT License - see the LICENSE file for details.
