# Sampling Algorithms
Sampling from energy-based models is a core task in TorchEBM. This guide explains the different sampling algorithms available and how to use them effectively.
## Overview of Sampling
In energy-based models, we need to sample from the probability distribution defined by the energy function:

$$p(x) = \frac{e^{-E(x)}}{Z}, \qquad Z = \int e^{-E(x)} \, dx$$

Since the normalizing constant $Z$ is typically intractable, we use Markov Chain Monte Carlo (MCMC) methods to generate samples without needing to compute $Z$.
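MCMC methods get away with this because they only ever need probability ratios, in which $Z$ cancels. For a Metropolis-style move from $x$ to $x'$, the acceptance ratio reduces to a difference of energies:

$$\frac{p(x')}{p(x)} = \frac{e^{-E(x')}/Z}{e^{-E(x)}/Z} = e^{E(x) - E(x')}$$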
## Available Samplers

### Langevin Dynamics
Langevin Dynamics is a gradient-based MCMC method that updates samples using the energy gradient plus Gaussian noise.
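Each step follows the unadjusted Langevin update

$$x_{t+1} = x_t - \eta \, \nabla_x E(x_t) + \sqrt{2\eta} \, \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}(0, I),$$

where $\eta$ is the step size and $\sqrt{2\eta}$ is the default `noise_scale` listed below. For example: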
```python
from torchebm.samplers.langevin_dynamics import LangevinDynamics
from torchebm.core import GaussianEnergy
import torch

# Create an energy function
energy_fn = GaussianEnergy(
    mean=torch.zeros(10),
    cov=torch.eye(10),
)

# Create a Langevin dynamics sampler
langevin_sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Generate samples
samples = langevin_sampler.sample_chain(
    dim=10,
    n_steps=1000,
    n_samples=100,
    return_trajectory=False,
)
```
#### Parameters
- `energy_function`: The energy function to sample from
- `step_size`: Step size for gradient updates (controls exploration vs. stability)
- `noise_scale`: Scale of the noise (defaults to `sqrt(2 * step_size)`)
- `device`: The device to perform sampling on (e.g., `"cuda"` or `"cpu"`)
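For instance, `noise_scale` can be set explicitly rather than left at its default. A minimal sketch, reusing `energy_fn` from above and assuming the constructor accepts `noise_scale` as listed:

```python
import math

# Sketch: make the documented default noise scale explicit.
explicit_sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01,
    noise_scale=math.sqrt(2 * 0.01),  # sqrt(2 * step_size), the documented default
    device="cuda" if torch.cuda.is_available() else "cpu",
)
```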
### Hamiltonian Monte Carlo (HMC)
HMC uses Hamiltonian dynamics to make more efficient proposals, leading to better exploration of the distribution.
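Each proposal introduces an auxiliary momentum $p$ and integrates Hamiltonian dynamics with the leapfrog scheme (unit mass assumed):

$$p_{t+1/2} = p_t - \tfrac{\eta}{2} \nabla_x E(x_t), \qquad x_{t+1} = x_t + \eta \, p_{t+1/2}, \qquad p_{t+1} = p_{t+1/2} - \tfrac{\eta}{2} \nabla_x E(x_{t+1}),$$

applied `n_leapfrog_steps` times per proposal and followed by a Metropolis accept/reject test. For example: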
```python
from torchebm.samplers.mcmc import HamiltonianMonteCarlo
from torchebm.core import DoubleWellEnergy
import torch

# Create an energy function
energy_fn = DoubleWellEnergy()

# Create an HMC sampler
hmc_sampler = HamiltonianMonteCarlo(
    energy_function=energy_fn,
    step_size=0.1,
    n_leapfrog_steps=10,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Generate samples
samples = hmc_sampler.sample_chain(
    dim=2,
    n_steps=500,
    n_samples=100,
    return_trajectory=False,
)
```
#### Parameters
- `energy_function`: The energy function to sample from
- `step_size`: Step size for leapfrog integration
- `n_leapfrog_steps`: Number of leapfrog steps per iteration
- `device`: The device to perform sampling on
## Advanced Sampling Usage

### Tracking Sampling Progress
You can track diagnostics during sampling by setting `return_diagnostics=True`:
```python
samples, diagnostics = sampler.sample_chain(
    dim=10,
    n_steps=1000,
    n_samples=100,
    return_trajectory=True,
    return_diagnostics=True,
)

# Diagnostics shape: [n_steps, n_diagnostics, n_samples, dim]
# Includes: Mean, Variance, Energy, Acceptance rate (for HMC)
```
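For example, to pull out a per-step energy trace (a sketch that assumes the diagnostics are ordered as in the comment above):

```python
# Assumes diagnostics are ordered [mean, variance, energy, acceptance rate].
energy_trace = diagnostics[:, 2, :, :]                # [n_steps, n_samples, dim]
mean_energy_per_step = energy_trace.mean(dim=(1, 2))  # [n_steps]
```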
### Custom Initialization
You can start the sampling chain from a specific point:
```python
# Custom initialization
x_init = torch.randn(100, 10)  # [n_samples, dim]

samples = sampler.sample_chain(
    x=x_init,
    n_steps=1000,
    return_trajectory=False,
)
```
### Burn-in and Thinning
To improve sample quality, you can apply burn-in and thinning manually:
```python
# Run a longer chain, keeping the full trajectory
samples, trajectory = sampler.sample_chain(
    dim=10,
    n_steps=2000,
    n_samples=100,
    return_trajectory=True,
)

# Discard the first 1000 steps (burn-in)
# Keep every 10th sample (thinning)
thinned_samples = trajectory[:, 1000::10, :]
```
## Choosing a Sampler
- Langevin Dynamics: Good for general-purpose sampling, especially in high dimensions
- Hamiltonian Monte Carlo: Better exploration of complex energy landscapes, but more expensive per step
## Sampler Performance Tips
- Adjust step size: Too large → unstable sampling; too small → slow mixing
- Use GPU acceleration: For large batches of samples or high dimensions
- Monitor acceptance rates: For HMC, aim for 60-90% acceptance rate
- Check sample quality: Correlation between successive samples should be low (see the autocorrelation sketch after this list)
- Burn-in: Discard initial samples before the chain reaches its stationary distribution
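One way to check the last two points is to measure the lag-$k$ autocorrelation of a chain. A minimal sketch, reusing `trajectory` from the burn-in example above (the helper below is illustrative, not part of TorchEBM):

```python
import torch

def lag_autocorrelation(chain: torch.Tensor, lag: int) -> torch.Tensor:
    """Sample autocorrelation at a given lag for a 1-D chain of values."""
    x = chain - chain.mean()
    return (x[:-lag] * x[lag:]).sum() / (x * x).sum()

# trajectory: [n_samples, n_steps, dim] from the burn-in example above
coord = trajectory[0, :, 0]             # first coordinate of the first chain
print(lag_autocorrelation(coord, 10))   # near 0 indicates good mixing
```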
## Implementing Custom Samplers
You can create custom samplers by subclassing `BaseSampler`:
```python
from torchebm.core import BaseSampler
import torch

class MyCustomSampler(BaseSampler):
    def __init__(self, energy_function, param1, param2, device="cpu"):
        super().__init__(energy_function, device)
        self.param1 = param1
        self.param2 = param2

    def step(self, x, step_idx=None):
        # Implement your sampling step here
        # x shape: [n_samples, dim]
        # Example: simple random walk scaled by param1
        noise = torch.randn_like(x) * self.param1
        x_new = x + noise

        # Return updated samples and any diagnostics you want to track
        return x_new, {
            "mean_abs_step": noise.abs().mean(),
            "step_std": noise.std(),
        }
```
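A minimal usage sketch, assuming `BaseSampler` provides the same `sample_chain` interface as the built-in samplers:

```python
from torchebm.core import GaussianEnergy

# Hypothetical usage of the custom sampler defined above.
energy_fn = GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2))
custom_sampler = MyCustomSampler(energy_fn, param1=0.1, param2=0.0)

samples = custom_sampler.sample_chain(
    dim=2,
    n_steps=500,
    n_samples=64,
    return_trajectory=False,
)
```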