Training Energy-Based Models¶
This guide covers the fundamental techniques for training energy-based models (EBMs) using TorchEBM. We'll explore various training methods, loss functions, and optimization strategies to help you effectively train your models.
Overview¶
Training energy-based models involves estimating the parameters of an energy function such that the corresponding probability distribution matches a target data distribution. Unlike in traditional supervised learning, this is often an unsupervised task where the goal is to learn the underlying structure of the data.
The training process typically involves:
- Defining an energy function (parameterized by a neural network or analytical form)
- Choosing a training method and loss function
- Optimizing the energy function parameters
- Evaluating the model using sampling and visualization techniques
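As a rough sketch, these steps map onto TorchEBM as follows. The names follow the examples later in this guide (`train_step_cd` is defined in the next section), and `data_loader` is a placeholder for whatever batch iterator you use:

# Minimal training skeleton (illustrative; see the full example below)
energy_fn = MLPEnergyFunction(input_dim=2, hidden_dim=64)               # 1. energy function
sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01)  # 2. sampler for training
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=0.001)         # 3. optimizer

for batch in data_loader:                                              # 4. training loop
    loss = train_step_cd(batch)  # e.g. contrastive divergence, defined below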
Methods for Training EBMs¶
TorchEBM supports several methods for training EBMs, each with its own advantages and trade-offs:
Contrastive Divergence (CD)¶
Contrastive Divergence is one of the most widely used methods for training EBMs. It approximates the gradient of the log-likelihood by comparing data samples with model samples obtained through short MCMC runs: the negative log-likelihood gradient is the average of ∇θE over the data (positive phase) minus the average of ∇θE over model samples (negative phase), which is exactly what the loss below implements.
import torch
from torchebm.core import BaseEnergyFunction, MLPEnergyFunction
from torchebm.samplers.langevin_dynamics import LangevinDynamics

# Create the energy function
energy_fn = MLPEnergyFunction(input_dim=2, hidden_dim=64)

# Create a sampler for generating negative samples
sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01
)
# Optimizer
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=0.001)

# Contrastive divergence training step
def train_step_cd(data_batch, k_steps=10):
    optimizer.zero_grad()

    # Positive phase: energy of real data
    pos_energy = energy_fn(data_batch)

    # Negative phase: generate samples from the current model,
    # starting the chains from random noise
    neg_samples = torch.randn_like(data_batch)

    # Run k_steps of Langevin dynamics
    neg_samples = sampler.sample_chain(
        initial_points=neg_samples,
        n_steps=k_steps,
        return_final=True
    )

    # Energy of the generated samples; detach so gradients do not
    # flow back through the sampling procedure
    neg_energy = energy_fn(neg_samples.detach())

    # CD loss: lower the energy of real data, raise the energy of samples
    loss = pos_energy.mean() - neg_energy.mean()

    loss.backward()
    optimizer.step()

    return loss.item(), pos_energy.mean().item(), neg_energy.mean().item()
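A minimal usage sketch, assuming `dataset` is a tensor of training points:

# One epoch of CD training over shuffled mini-batches
perm = torch.randperm(dataset.shape[0])
for i in range(0, dataset.shape[0], 128):
    batch = dataset[perm[i:i + 128]]
    loss, pos_e, neg_e = train_step_cd(batch, k_steps=10)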
Persistent Contrastive Divergence (PCD)¶
PCD improves on standard CD by maintaining persistent chains of samples that are reused across training iterations:
# Initialize the persistent samples
persistent_samples = torch.randn(1000, 2)  # Shape: [n_persistent, data_dim]

# PCD training step
def train_step_pcd(data_batch, k_steps=10):
    global persistent_samples
    optimizer.zero_grad()

    # Positive phase: energy of real data
    pos_energy = energy_fn(data_batch)

    # Negative phase: continue sampling from the persistent chains,
    # starting from a random subset of the existing persistent samples
    indices = torch.randperm(persistent_samples.shape[0])[:data_batch.shape[0]]
    initial_samples = persistent_samples[indices].clone()

    # Run k_steps of Langevin dynamics
    neg_samples = sampler.sample_chain(
        initial_points=initial_samples,
        n_steps=k_steps,
        return_final=True
    )

    # Update the persistent samples
    persistent_samples[indices] = neg_samples.detach()

    # Energy of the generated samples
    neg_energy = energy_fn(neg_samples.detach())

    # Contrastive divergence loss
    loss = pos_energy.mean() - neg_energy.mean()

    loss.backward()
    optimizer.step()

    return loss.item(), pos_energy.mean().item(), neg_energy.mean().item()
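Persistent chains work because the model parameters change slowly between updates, so samples from the previous iteration remain approximately distributed according to the current model. As a result, PCD typically needs far fewer MCMC steps per update than restarting chains from noise, at the cost of storing the chain state across iterations.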
Score Matching¶
Score Matching avoids the need for MCMC sampling altogether by instead matching the gradient of the log-density (the score function):
def train_step_score_matching(data_batch, noise_scale=0.01):
    optimizer.zero_grad()

    # Ensure the data requires gradients
    data_batch.requires_grad_(True)

    # Compute energies
    energy = energy_fn(data_batch)

    # Gradient of the energy with respect to the inputs
    # (the model score is -grad_energy)
    grad_energy = torch.autograd.grad(
        outputs=energy.sum(),
        inputs=data_batch,
        create_graph=True
    )[0]

    # Squared-norm term of the score matching objective.
    # Note: the exact (implicit) score matching objective also contains a
    # Hessian-trace term, omitted here for simplicity (see the sketch below).
    loss = 0.5 * (grad_energy.pow(2)).sum(dim=1).mean()

    # Optional noise regularization to keep the energy smooth
    noise_data = data_batch + noise_scale * torch.randn_like(data_batch)
    noise_energy = energy_fn(noise_data)
    reg_loss = ((noise_energy - energy) ** 2).mean()

    total_loss = loss + 0.1 * reg_loss

    total_loss.backward()
    optimizer.step()

    return total_loss.item()
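For completeness, here is a sketch of Hyvärinen's exact (implicit) score matching objective, which adds the Hessian-trace term omitted above. The function name is hypothetical, and the trace costs one extra backward pass per input dimension, so this is only practical for low-dimensional data:

def train_step_exact_score_matching(data_batch):
    # Hyvarinen's implicit score matching objective:
    #   E[ 0.5 * ||grad_x E(x)||^2 - tr(Hessian_x E(x)) ]
    # where the model score is s(x) = -grad_x E(x).
    optimizer.zero_grad()
    x = data_batch.clone().requires_grad_(True)

    energy = energy_fn(x)
    grad = torch.autograd.grad(energy.sum(), x, create_graph=True)[0]

    # Squared-norm term
    sq_term = 0.5 * grad.pow(2).sum(dim=1)

    # Hessian trace: one backward pass per input dimension
    trace = torch.zeros(x.shape[0], device=x.device)
    for i in range(x.shape[1]):
        grad2 = torch.autograd.grad(grad[:, i].sum(), x, create_graph=True)[0][:, i]
        trace = trace + grad2

    loss = (sq_term - trace).mean()
    loss.backward()
    optimizer.step()
    return loss.item()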
Denoising Score Matching¶
Denoising Score Matching is a variant that adds noise to the data points and trains the model's score to match the known score of the noising distribution:
def train_step_denoising_score_matching(data_batch, sigma=0.1):
    optimizer.zero_grad()

    # Add Gaussian noise to the data
    noise = sigma * torch.randn_like(data_batch)
    noised_data = data_batch + noise
    noised_data.requires_grad_(True)

    # Energy of the noised data
    energy = energy_fn(noised_data)

    # Gradient of the energy with respect to the noised inputs
    grad_energy = torch.autograd.grad(
        outputs=energy.sum(),
        inputs=noised_data,
        create_graph=True
    )[0]

    # The model's score is the negative energy gradient
    predicted_score = -grad_energy

    # Ground-truth score is -noise / sigma^2 for Gaussian noise
    target_score = -noise / (sigma ** 2)

    # MSE between the predicted and target scores
    loss = 0.5 * ((predicted_score - target_score) ** 2).sum(dim=1).mean()

    loss.backward()
    optimizer.step()

    return loss.item()
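The noise scale sigma controls a bias-variance trade-off: larger values give a smoother, easier target but train the model on a blurred version of the data distribution, while smaller values make the target score large in magnitude and the training signal noisier. In practice it is common to try a few values of sigma, or to train over a range of noise scales, and compare the resulting samples.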
Full Training Loop Example¶
Here's a complete example showing how to train an EBM using contrastive divergence:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchebm.core import BaseEnergyFunction
from torchebm.samplers.langevin_dynamics import LangevinDynamics

# Define a neural network energy function
class MLPEnergyFunction(BaseEnergyFunction):
    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        return self.network(x).squeeze(-1)
# Create a synthetic dataset (a mixture of Gaussians)
def generate_data(n_samples=10000):
    # Mixture of 5 Gaussians
    centers = [
        (0, 0),
        (2, 2),
        (-2, 2),
        (2, -2),
        (-2, -2)
    ]
    std = 0.3

    # Equal number of samples per component
    n_per_component = n_samples // len(centers)

    data = []
    for center in centers:
        component_data = torch.randn(n_per_component, 2) * std
        component_data[:, 0] += center[0]
        component_data[:, 1] += center[1]
        data.append(component_data)

    # Combine all components and shuffle
    data = torch.cat(data, dim=0)
    indices = torch.randperm(data.shape[0])
    return data[indices]
# Generate the dataset
dataset = generate_data(10000)

# Create the energy function and sampler
energy_fn = MLPEnergyFunction(input_dim=2, hidden_dim=128)
sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01
)

# Optimizer
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=0.0001)

# Training hyperparameters
batch_size = 128
n_epochs = 200
k_steps = 10  # Number of MCMC steps for the negative samples
log_interval = 20

# Persistent samples for PCD
persistent_samples = torch.randn(1000, 2)
# Training loop
for epoch in range(n_epochs):
    # Shuffle the dataset
    indices = torch.randperm(dataset.shape[0])
    dataset = dataset[indices]

    total_loss = 0
    n_batches = 0

    for i in range(0, dataset.shape[0], batch_size):
        # Get the next batch (the dataset is already shuffled)
        batch = dataset[i:i + batch_size]

        # Training step using PCD
        optimizer.zero_grad()

        # Positive phase
        pos_energy = energy_fn(batch)

        # Negative phase with persistent samples
        # (use batch.shape[0] so the last, possibly smaller, batch still works)
        pcd_indices = torch.randperm(persistent_samples.shape[0])[:batch.shape[0]]
        neg_samples_init = persistent_samples[pcd_indices].clone()

        neg_samples = sampler.sample_chain(
            initial_points=neg_samples_init,
            n_steps=k_steps,
            return_final=True
        )

        # Update the persistent samples
        persistent_samples[pcd_indices] = neg_samples.detach()

        # Energy of the negative samples
        neg_energy = energy_fn(neg_samples.detach())

        # Contrastive divergence loss
        loss = pos_energy.mean() - neg_energy.mean()

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    # Log progress
    avg_loss = total_loss / n_batches
    if (epoch + 1) % log_interval == 0:
        print(f'Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}')

        # Visualize the current energy function on a grid
        x = torch.linspace(-4, 4, 100)
        y = torch.linspace(-4, 4, 100)
        X, Y = torch.meshgrid(x, y, indexing='ij')
        grid_points = torch.stack([X.flatten(), Y.flatten()], dim=1)

        with torch.no_grad():
            energies = energy_fn(grid_points).reshape(100, 100)

        plt.figure(figsize=(10, 8))
        plt.contourf(X.numpy(), Y.numpy(), energies.numpy(), 50, cmap='viridis')

        # Plot data points
        plt.scatter(dataset[:500, 0], dataset[:500, 1], s=1, color='red', alpha=0.5)

        # Plot samples from the model
        sampled = sampler.sample_chain(
            dim=2,
            n_samples=500,
            n_steps=500,
            burn_in=100
        )
        plt.scatter(sampled[:, 0], sampled[:, 1], s=1, color='white', alpha=0.5)

        plt.title(f'Energy Landscape (Epoch {epoch + 1})')
        plt.xlim(-4, 4)
        plt.ylim(-4, 4)
        plt.colorbar(label='Energy')
        plt.tight_layout()
        plt.show()
# Final evaluation
print("Training complete. Generating samples...")
# Generate final samples
final_samples = sampler.sample_chain(
    dim=2,
    n_samples=5000,
    n_steps=1000,
    burn_in=200
)
# Visualize final distribution
plt.figure(figsize=(12, 5))
# Plot data distribution
plt.subplot(1, 2, 1)
plt.hist2d(dataset[:, 0].numpy(), dataset[:, 1].numpy(), bins=50, cmap='Blues')
plt.title('Data Distribution')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar(label='Count')
# Plot model distribution
plt.subplot(1, 2, 2)
plt.hist2d(final_samples[:, 0].numpy(), final_samples[:, 1].numpy(), bins=50, cmap='Reds')
plt.title('Model Distribution')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar(label='Count')
plt.tight_layout()
plt.show()
Training with Regularization¶
Adding regularization can help stabilize training and prevent the energy function from assigning extremely low values to certain regions:
def train_step_with_regularization(data_batch, k_steps=10, l2_reg=0.001):
    optimizer.zero_grad()

    # Positive phase
    pos_energy = energy_fn(data_batch)

    # Negative phase
    neg_samples = torch.randn_like(data_batch)
    neg_samples = sampler.sample_chain(
        initial_points=neg_samples,
        n_steps=k_steps,
        return_final=True
    )
    neg_energy = energy_fn(neg_samples.detach())

    # Contrastive divergence loss
    cd_loss = pos_energy.mean() - neg_energy.mean()

    # L2 regularization on the parameters
    l2_norm = sum(p.pow(2).sum() for p in energy_fn.parameters())

    # Total loss
    loss = cd_loss + l2_reg * l2_norm

    loss.backward()
    optimizer.step()

    return loss.item()
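As an alternative (or complement) to parameter-norm regularization, a common trick in CD training is to penalize the energy magnitudes themselves, which more directly discourages the energies from drifting to extreme values. A minimal sketch, with a hypothetical function name and an `alpha` weight to tune:

def train_step_with_energy_regularization(data_batch, k_steps=10, alpha=0.1):
    # Variant: penalize the squared energies of both phases
    optimizer.zero_grad()

    pos_energy = energy_fn(data_batch)

    neg_samples = sampler.sample_chain(
        initial_points=torch.randn_like(data_batch),
        n_steps=k_steps,
        return_final=True
    )
    neg_energy = energy_fn(neg_samples.detach())

    cd_loss = pos_energy.mean() - neg_energy.mean()

    # Keep the energies in a bounded range
    energy_reg = pos_energy.pow(2).mean() + neg_energy.pow(2).mean()

    loss = cd_loss + alpha * energy_reg
    loss.backward()
    optimizer.step()
    return loss.item()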
Training with Gradient Clipping¶
Gradient clipping can prevent exploding gradients during training:
def train_step_with_gradient_clipping(data_batch, k_steps=10, max_norm=1.0):
    optimizer.zero_grad()

    # Positive phase
    pos_energy = energy_fn(data_batch)

    # Negative phase
    neg_samples = torch.randn_like(data_batch)
    neg_samples = sampler.sample_chain(
        initial_points=neg_samples,
        n_steps=k_steps,
        return_final=True
    )
    neg_energy = energy_fn(neg_samples.detach())

    # Contrastive divergence loss
    loss = pos_energy.mean() - neg_energy.mean()

    loss.backward()

    # Clip gradient norms before the optimizer step
    torch.nn.utils.clip_grad_norm_(energy_fn.parameters(), max_norm)

    optimizer.step()
    return loss.item()
Monitoring Training Progress¶
It's important to monitor various metrics during training to ensure the model is learning effectively:
def visualize_training_progress(energy_fn, data, sampler, epoch):
    # Generate samples from the current model
    samples = sampler.sample_chain(
        dim=data.shape[1],
        n_samples=1000,
        n_steps=500,
        burn_in=100
    )

    # Compute energy statistics
    with torch.no_grad():
        data_energy = energy_fn(data[:1000]).mean().item()
        sample_energy = energy_fn(samples).mean().item()

    print(f"Epoch {epoch}: Data Energy: {data_energy:.4f}, Sample Energy: {sample_energy:.4f}")

    # Create the visualization
    plt.figure(figsize=(15, 5))

    # Plot the data
    plt.subplot(1, 3, 1)
    plt.scatter(data[:1000, 0].numpy(), data[:1000, 1].numpy(), s=2, alpha=0.5)
    plt.title('Data')
    plt.xlim(-4, 4)
    plt.ylim(-4, 4)

    # Plot the samples
    plt.subplot(1, 3, 2)
    plt.scatter(samples[:, 0].numpy(), samples[:, 1].numpy(), s=2, alpha=0.5)
    plt.title('Samples')
    plt.xlim(-4, 4)
    plt.ylim(-4, 4)

    # Plot the energy landscape
    plt.subplot(1, 3, 3)
    x = torch.linspace(-4, 4, 100)
    y = torch.linspace(-4, 4, 100)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    grid_points = torch.stack([X.flatten(), Y.flatten()], dim=1)
    with torch.no_grad():
        energies = energy_fn(grid_points).reshape(100, 100)
    plt.contourf(X.numpy(), Y.numpy(), energies.numpy(), 50, cmap='viridis')
    plt.colorbar(label='Energy')
    plt.title('Energy Landscape')

    plt.tight_layout()
    plt.savefig(f"training_progress_epoch_{epoch}.png")
    plt.close()
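A typical way to hook this into the training loop, assuming the variables from the full example above:

# Call the monitoring helper every log_interval epochs
if (epoch + 1) % log_interval == 0:
    visualize_training_progress(energy_fn, dataset, sampler, epoch + 1)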
Debugging Tips¶
Training EBMs can be challenging. Here are some tips for debugging:
- Start with simpler energy functions: Begin with analytical energy functions before moving to neural networks
- Monitor energy values: If energy values become extremely large or small, the model may be collapsing
- Visualize samples: Regularly visualize samples during training to check if they match the data distribution
- Adjust learning rate: Try different learning rates; EBMs often require smaller learning rates
- Increase MCMC steps: More MCMC steps for negative samples can improve training stability
- Add noise regularization: Adding small Gaussian noise to the data samples can help prevent overfitting (see the sketch after this list)
- Use gradient clipping: Clip gradients to prevent instability during training
- Try different initializations: Initial parameter values can significantly impact training dynamics
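As a concrete example of the noise-regularization tip, a minimal sketch that perturbs the data before the positive phase; the helper name and noise level are illustrative assumptions to tune:

def noisy_positive_phase(data_batch, noise_std=0.01):
    # Perturb the data slightly so the energy function cannot overfit
    # to exact data coordinates
    noisy_batch = data_batch + noise_std * torch.randn_like(data_batch)
    return energy_fn(noisy_batch)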
Training on Different Data Types¶
The training approach may vary depending on the type of data:
Images¶
For image data, convolutional architectures are recommended:
class ImageEBM(BaseEnergyFunction):
    def __init__(self, input_channels=1, image_size=28):
        super().__init__()
        self.conv_net = nn.Sequential(
            nn.Conv2d(input_channels, 32, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, 2, 1),   # Downsample by 2
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),  # Downsample by 2
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1), # Downsample by 2
            nn.LeakyReLU(0.2)
        )

        # Final feature map size: three stride-2 convolutions reduce each
        # spatial dimension by roughly a factor of 8
        feature_size = 256 * (image_size // 8) * (image_size // 8)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feature_size, 1)
        )

    def forward(self, x):
        # Ensure proper dimensions for the convolutional layers
        if x.ndim == 3:
            x = x.unsqueeze(1)  # Add a channel dimension for grayscale images
        features = self.conv_net(x)
        energy = self.fc(features).squeeze(-1)
        return energy
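A quick shape check, using random tensors in place of real images:

# Usage sketch: energies for a batch of 16 grayscale 28x28 images
model = ImageEBM(input_channels=1, image_size=28)
x = torch.randn(16, 1, 28, 28)
print(model(x).shape)  # torch.Size([16])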
Sequential Data¶
For sequential data, recurrent or transformer architectures may be more appropriate:
class SequentialEBM(BaseEnergyFunction):
    def __init__(self, input_dim=1, hidden_dim=128, seq_len=50):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        # x shape: [batch_size, seq_len, input_dim]
        lstm_out, _ = self.lstm(x)

        # Use the final hidden state
        final_hidden = lstm_out[:, -1, :]

        # Compute the energy
        energy = self.fc(final_hidden).squeeze(-1)
        return energy
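And the corresponding shape check for sequences:

# Usage sketch: energies for a batch of 8 sequences of length 50
model = SequentialEBM(input_dim=1, hidden_dim=128)
x = torch.randn(8, 50, 1)  # [batch_size, seq_len, input_dim]
print(model(x).shape)  # torch.Size([8])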
Conclusion¶
Training energy-based models is a challenging but rewarding process. By leveraging the techniques outlined in this guide, you can effectively train EBMs using TorchEBM for a variety of applications. Remember to monitor training progress, visualize results, and adjust your approach based on the specific characteristics of your data and modeling objectives.
Whether you're using contrastive divergence, score matching, or other methods, the key is to ensure that your energy function accurately captures the underlying structure of your data distribution. With practice and experimentation, you can master the art of training energy-based models for complex tasks in unsupervised learning.