Training Energy-Based Models

This guide covers the fundamental techniques for training energy-based models (EBMs) using TorchEBM. We'll explore various training methods, loss functions, and optimization strategies to help you effectively train your models.

Overview

Training energy-based models involves estimating the parameters of an energy function such that the corresponding probability distribution matches a target data distribution. Unlike in traditional supervised learning, this is often an unsupervised task where the goal is to learn the underlying structure of the data.
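
In symbols, the relationship between the energy function and the distribution it defines can be written as follows. The notation is an assumption for illustration and does not come from TorchEBM: $E_\theta$ is the energy function, $Z_\theta$ its normalizing constant, and $p_{\text{data}}$ the target data distribution.

```latex
p_\theta(x) = \frac{\exp\!\left(-E_\theta(x)\right)}{Z_\theta},
\qquad
Z_\theta = \int \exp\!\left(-E_\theta(x)\right)\,dx

\nabla_\theta\, \mathbb{E}_{x \sim p_{\text{data}}}\!\left[-\log p_\theta(x)\right]
= \mathbb{E}_{x \sim p_{\text{data}}}\!\left[\nabla_\theta E_\theta(x)\right]
- \mathbb{E}_{x \sim p_\theta}\!\left[\nabla_\theta E_\theta(x)\right]
```

The second expectation is taken under the model itself, and $Z_\theta$ is usually intractable. This is why the training methods in this guide either approximate model samples with MCMC or avoid the normalizing constant altogether.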

The training process typically involves:

  1. Defining an energy function (a neural network or an analytical form)
  2. Choosing a training method and loss function
  3. Optimizing the energy function parameters
  4. Evaluating the model using sampling and visualization techniques

Defining an Energy Function

In TorchEBM, you can create custom energy functions by subclassing BaseEnergyFunction:

import torch
import torch.nn as nn
from torchebm.core import BaseEnergyFunction

class MLPEnergy(BaseEnergyFunction):
    """A simple MLP to act as the energy function."""

    def __init__(self, input_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, 1),
            nn.Tanh(),  # Optional: bounds the energy to (-1, 1), which can help stabilize training
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x).squeeze(-1)
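
A quick sanity check for any energy network is that it maps a batch of shape (N, D) to N scalar energies and that autograd can differentiate it with respect to its inputs, since gradient-based samplers rely on that. The sketch below uses a plain nn.Module stand-in (the hypothetical name TinyEnergy) instead of BaseEnergyFunction, so it runs with PyTorch alone:

```python
import torch
import torch.nn as nn


class TinyEnergy(nn.Module):
    """Hypothetical stand-in for MLPEnergy that needs only PyTorch."""

    def __init__(self, input_dim: int, hidden_dim: int = 16):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One scalar energy per sample: (N, D) -> (N,)
        return self.network(x).squeeze(-1)


torch.manual_seed(0)
energy = TinyEnergy(input_dim=2)
x = torch.randn(8, 2, requires_grad=True)

e = energy(x)
# Each energy depends only on its own row of x, so the gradient of the
# summed energies contains the per-sample input gradients.
(grad_x,) = torch.autograd.grad(e.sum(), x)
score = -grad_x  # score of the model density: grad_x log p(x) = -grad_x E(x)

print(e.shape, score.shape)
```

The same check should carry over to a BaseEnergyFunction subclass, assuming it behaves like a regular nn.Module when called.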

Training with Contrastive Divergence

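Contrastive Divergence (CD) is one of the most common methods for training EBMs. The example below trains with CD using TorchEBM. It reuses the MLPEnergy class defined above and calls the plot_energy_and_samples helper from the next section, so define that helper before running the training loop: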

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import os

from torchebm.core import BaseEnergyFunction, CosineScheduler
from torchebm.samplers import LangevinDynamics
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import TwoMoonsDataset

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Create the output directory that the plt.savefig calls below write into
os.makedirs("docs/assets/images/training", exist_ok=True)

# Hyperparameters
INPUT_DIM = 2
HIDDEN_DIM = 16
BATCH_SIZE = 256
EPOCHS = 200
LEARNING_RATE = 1e-3

# Use dynamic schedulers for sampler parameters
SAMPLER_STEP_SIZE = CosineScheduler(start_value=3e-2, end_value=5e-3, n_steps=100)
SAMPLER_NOISE_SCALE = CosineScheduler(start_value=3e-1, end_value=1e-2, n_steps=100)

CD_K = 10  # Number of MCMC steps for Contrastive Divergence
USE_PCD = True  # Whether to use Persistent Contrastive Divergence
VISUALIZE_EVERY = 20

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load dataset
dataset = TwoMoonsDataset(n_samples=3000, noise=0.05, seed=42, device=device)
real_data_for_plotting = dataset.get_data()
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# Setup model components
energy_model = MLPEnergy(INPUT_DIM, HIDDEN_DIM).to(device)
sampler = LangevinDynamics(
    energy_function=energy_model,
    step_size=SAMPLER_STEP_SIZE,
    noise_scale=SAMPLER_NOISE_SCALE,
    device=device,
)
loss_fn = ContrastiveDivergence(
    energy_function=energy_model,
    sampler=sampler,
    k_steps=CD_K,
    persistent=USE_PCD,
    buffer_size=BATCH_SIZE,
).to(device)

# Optimizer
optimizer = optim.Adam(energy_model.parameters(), lr=LEARNING_RATE)

# Training loop
losses = []
print("Starting training...")
for epoch in range(EPOCHS):
    energy_model.train()
    epoch_loss = 0.0

    for i, data_batch in enumerate(dataloader):
        # Zero gradients
        optimizer.zero_grad()

        # Calculate Contrastive Divergence loss
        loss, negative_samples = loss_fn(data_batch)

        # Backpropagate and optimize
        loss.backward()

        # Optional: Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(energy_model.parameters(), max_norm=1.0)

        # Update parameters
        optimizer.step()

        epoch_loss += loss.item()

    # Calculate average loss for this epoch
    avg_epoch_loss = epoch_loss / len(dataloader)
    losses.append(avg_epoch_loss)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Average Loss: {avg_epoch_loss:.4f}")

    # Visualize progress
    if (epoch + 1) % VISUALIZE_EVERY == 0 or epoch == 0:
        print("Generating visualization...")
        plot_energy_and_samples(
            energy_fn=energy_model,
            real_samples=real_data_for_plotting,
            sampler=sampler,
            epoch=epoch + 1,
            device=device,
            plot_range=2.5,
            k_sampling=200,
        )

# Plot the training loss
plt.figure(figsize=(10, 6))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.savefig('docs/assets/images/training/cd_training_loss.png')
plt.show()
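
For intuition about what a CD loss computes, here is a minimal sketch in plain PyTorch. It is not TorchEBM's implementation. The helper names langevin_sample and cd_loss are hypothetical, and the update rule with separate step_size and noise_scale arguments only mirrors the parameter names used above; the library's exact update may differ.

```python
import torch
import torch.nn as nn


def langevin_sample(energy, x, n_steps=10, step_size=0.01, noise_scale=0.1):
    """Unadjusted Langevin updates: x <- x - step_size * grad E(x) + noise_scale * eps."""
    x = x.detach().clone()
    for _ in range(n_steps):
        x.requires_grad_(True)
        (grad_x,) = torch.autograd.grad(energy(x).sum(), x)
        # Detach so the chain carries no autograd history between steps.
        x = (x - step_size * grad_x + noise_scale * torch.randn_like(x)).detach()
    return x


def cd_loss(energy, data, n_steps=10):
    """CD-k surrogate loss: mean E(data) - mean E(negative samples)."""
    negatives = langevin_sample(energy, data, n_steps=n_steps)  # chains start at the data
    loss = energy(data).mean() - energy(negatives).mean()
    return loss, negatives


torch.manual_seed(0)
energy = nn.Sequential(nn.Linear(2, 16), nn.SELU(), nn.Linear(16, 1))
data = torch.randn(32, 2)

loss, negatives = cd_loss(energy, data)
loss.backward()  # negatives are detached, so gradients reach only the parameters
print(loss.item(), negatives.shape)
```

Minimizing this loss lowers the energy of data points and raises the energy of the sampler's negatives. A persistent variant would start each chain from a buffer of earlier negatives instead of from the current batch.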

Visualization During Training

It's important to visualize the model's progress during training. Here's a helper function to plot the energy landscape and samples:

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchebm.core import BaseEnergyFunction
from torchebm.samplers import LangevinDynamics

@torch.no_grad()
def plot_energy_and_samples(
    energy_fn: BaseEnergyFunction,
    real_samples: torch.Tensor,
    sampler: LangevinDynamics,
    epoch: int,
    device: torch.device,
    grid_size: int = 100,
    plot_range: float = 3.0,
    k_sampling: int = 100,
):
    """Plots the energy surface, real data, and model samples."""
    plt.figure(figsize=(8, 8))

    # Create grid for energy surface plot
    x_coords = torch.linspace(-plot_range, plot_range, grid_size, device=device)
    y_coords = torch.linspace(-plot_range, plot_range, grid_size, device=device)
    xv, yv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    grid = torch.stack([xv.flatten(), yv.flatten()], dim=1)

    # Calculate energy on the grid
    energy_fn.eval()
    energy_values = energy_fn(grid).cpu().numpy().reshape(grid_size, grid_size)

    # Plot energy surface (using probability density for better visualization)
    log_prob_values = -energy_values
    log_prob_values = log_prob_values - np.max(log_prob_values)
    prob_density = np.exp(log_prob_values)

    plt.contourf(
        xv.cpu().numpy(),
        yv.cpu().numpy(),
        prob_density,
        levels=50,
        cmap="viridis",
    )
    plt.colorbar(label="exp(-Energy) (unnormalized density)")

    # Generate samples from the current model for visualization
    vis_start_noise = torch.randn(
        500, real_samples.shape[1], device=device
    )
    model_samples_tensor = sampler.sample(x=vis_start_noise, n_steps=k_sampling)
    model_samples = model_samples_tensor.cpu().numpy()

    # Plot real and model samples
    real_samples_np = real_samples.cpu().numpy()
    plt.scatter(
        real_samples_np[:, 0],
        real_samples_np[:, 1],
        s=10,
        alpha=0.5,
        label="Real Data",
        c="white",
        edgecolors="k",
        linewidths=0.5,
    )
    plt.scatter(
        model_samples[:, 0],
        model_samples[:, 1],
        s=10,
        alpha=0.5,
        label="Model Samples",
        c="red",
        edgecolors="darkred",
        linewidths=0.5,
    )

    plt.xlim(-plot_range, plot_range)
    plt.ylim(-plot_range, plot_range)
    plt.title(f"Epoch {epoch}")
    plt.xlabel("X1")
    plt.ylabel("X2")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(f"docs/assets/images/training/ebm_training_epoch_{epoch}.png")
    plt.close()

Training with Score Matching

An alternative to Contrastive Divergence is Score Matching, which doesn't require MCMC sampling:

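import os

import matplotlib.pyplot as plt  # needed for the loss plot at the end
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchebm.core import BaseEnergyFunction
from torchebm.losses import ScoreMatching
from torchebm.datasets import GaussianMixtureDataset

# Make sure the directory used by plt.savefig below exists
os.makedirs("docs/assets/images/training", exist_ok=True)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")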

# Setup model, loss, and optimizer
# (MLPEnergy is the class from "Defining an Energy Function" above)
energy_fn = MLPEnergy(input_dim=2).to(device)
sm_loss_fn = ScoreMatching(
    energy_function=energy_fn,
    hessian_method="hutchinson",  # More efficient for higher dimensions
    hutchinson_samples=5,
    device=device,
)
optimizer = optim.Adam(energy_fn.parameters(), lr=0.001)

# Setup data
dataset = GaussianMixtureDataset(
    n_samples=500, n_components=4, std=0.1, seed=123
).get_data()
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Training Loop
losses = []
for epoch in range(50):
    epoch_loss = 0.0
    for batch_data in dataloader:
        batch_data = batch_data.to(device)

        optimizer.zero_grad()
        loss = sm_loss_fn(batch_data)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/50, Loss: {avg_loss:.6f}")

# Plot the training loss
plt.figure(figsize=(10, 6))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Score Matching Training Loss')
plt.grid(True, alpha=0.3)
plt.savefig('docs/assets/images/training/sm_training_loss.png')
plt.show()
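
To make the score matching objective concrete, the sketch below implements the Hyvärinen form in plain PyTorch, with either an exact Hessian trace or a Hutchinson estimate. It is an illustration under stated assumptions, not TorchEBM's ScoreMatching; score_matching_loss, quadratic_energy, and n_probes are hypothetical names.

```python
import torch


def score_matching_loss(energy, x, n_probes=0):
    """Mean over the batch of 0.5 * ||grad E(x)||^2 - tr(Hessian E(x)).

    n_probes == 0 computes the exact trace; n_probes > 0 uses a Hutchinson estimate.
    """
    x = x.detach().clone().requires_grad_(True)
    (grad_x,) = torch.autograd.grad(energy(x).sum(), x, create_graph=True)
    sq_norm = 0.5 * grad_x.pow(2).sum(dim=1)

    trace = torch.zeros(x.shape[0], device=x.device)
    if n_probes == 0:
        # Exact trace: one extra backward pass per input dimension.
        for d in range(x.shape[1]):
            (row,) = torch.autograd.grad(grad_x[:, d].sum(), x, create_graph=True)
            trace = trace + row[:, d]
    else:
        # Hutchinson: E_v[v^T H v] = tr(H) for Rademacher probes v in {-1, +1}^D.
        for _ in range(n_probes):
            v = torch.randint(0, 2, x.shape, device=x.device).to(x.dtype) * 2 - 1
            (hv,) = torch.autograd.grad((grad_x * v).sum(), x, create_graph=True)
            trace = trace + (hv * v).sum(dim=1) / n_probes

    return (sq_norm - trace).mean()


def quadratic_energy(x):
    # E(x) = 0.5 * ||x||^2, so grad E = x and tr(Hessian) = D.
    return 0.5 * x.pow(2).sum(dim=1)


torch.manual_seed(0)
x = torch.randn(64, 3)
exact = score_matching_loss(quadratic_energy, x)
approx = score_matching_loss(quadratic_energy, x, n_probes=5)
expected = (0.5 * x.pow(2).sum(dim=1) - 3).mean()
print(exact.item(), approx.item(), expected.item())
```

For this quadratic energy the Hessian is the identity, so the Hutchinson estimate matches the exact trace. For a neural network it is a noisy but unbiased estimate that costs one extra backward pass per probe instead of one per input dimension.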

Comparing Training Methods

Here's how the major training methods for EBMs compare:

| Method | Pros | Cons | Best For |
| --- | --- | --- | --- |
| Contrastive Divergence (CD) | Simple to implement; computationally efficient; works well for simple distributions | May not converge to true gradient; limited mode exploration with short MCMC runs; can lead to poor samples | Restricted Boltzmann Machines, simpler energy-based models |
| Persistent CD (PCD) | Better mode exploration than CD; more accurate gradient estimation; improved sample quality | Requires maintaining persistent chains; can be unstable with high learning rates; chains can get stuck in metastable states | Deep Boltzmann Machines, models with complex energy landscapes |
| Score Matching | Avoids MCMC sampling; consistent estimator; stable optimization | Requires computing Hessian diagonals; high computational cost in high dimensions; need for second derivatives | Continuous data, models with tractable derivatives |
| Denoising Score Matching | Avoids explicit Hessian computation; more efficient than standard score matching; works well for high-dimensional data | Performance depends on noise distribution; trade-off between noise level and estimation accuracy; may smooth out important details | Image modeling, high-dimensional continuous distributions |
| Sliced Score Matching | Linear computational complexity; no Hessian computation needed; scales well to high dimensions | Approximation depends on number of projections; less accurate with too few random projections; still requires gradient computation | High-dimensional problems where other score matching variants are too expensive |
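
Denoising score matching appears in the comparison but has no example in this guide, so here is a minimal sketch of its loss in plain PyTorch. It assumes Gaussian noise with a single standard deviation sigma. The names dsm_loss and quadratic_energy are hypothetical, and this is not TorchEBM's implementation.

```python
import torch


def dsm_loss(energy, x, sigma=0.1):
    """Denoising score matching with Gaussian noise of standard deviation sigma."""
    noise = torch.randn_like(x)
    x_noisy = (x + sigma * noise).detach().requires_grad_(True)
    # Model score at the noisy point: -grad E(x_noisy). create_graph=True keeps
    # the result differentiable with respect to the energy's parameters.
    (grad_x,) = torch.autograd.grad(energy(x_noisy).sum(), x_noisy, create_graph=True)
    model_score = -grad_x
    # Score of the Gaussian perturbation kernel: -(x_noisy - x) / sigma^2 = -noise / sigma.
    target_score = -noise / sigma
    return 0.5 * (model_score - target_score).pow(2).sum(dim=1).mean()


def quadratic_energy(x):
    # Energy of a standard normal (up to a constant): E(x) = 0.5 * ||x||^2.
    return 0.5 * x.pow(2).sum(dim=1)


torch.manual_seed(0)
x = torch.zeros(128, 2)  # all data at the origin

# With sigma = 1 the noisy data is exactly standard normal, which
# quadratic_energy models perfectly, so the loss vanishes.
matched = dsm_loss(quadratic_energy, x, sigma=1.0)
# With sigma = 0.5 the same energy is too wide for the noisy data.
mismatched = dsm_loss(quadratic_energy, x, sigma=0.5)
print(matched.item(), mismatched.item())
```

Only first derivatives of the energy are needed, which is the efficiency advantage listed in the comparison. The price is that the model fits the noise-smoothed data distribution rather than the data distribution itself.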

Advanced Training Techniques

Gradient Clipping

Gradient clipping often stabilizes EBM training by bounding the size of each parameter update. Apply it between the backward pass and the optimizer step:

# After loss.backward()
torch.nn.utils.clip_grad_norm_(energy_model.parameters(), max_norm=1.0)
optimizer.step()
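
To see the effect of clipping in isolation, the following self-contained sketch uses a toy linear model (an assumption for illustration, unrelated to the energy model above) with a deliberately large loss, and compares the global gradient norm before and after the call:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x = torch.randn(16, 4)

# Scale the loss so the raw gradient norm is well above the threshold.
loss = 1000.0 * model(x).pow(2).mean()
loss.backward()

# clip_grad_norm_ rescales gradients in place and returns the total norm
# measured *before* clipping.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

print(f"before: {norm_before.item():.2f}, after: {norm_after.item():.2f}")
```

Logging the returned pre-clipping norm during training is a cheap way to notice when gradients start to blow up.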

Regularization Techniques

Adding regularization can help stabilize training:

# L2 regularization
weight_decay = 1e-4
optimizer = optim.Adam(energy_model.parameters(), lr=LEARNING_RATE, weight_decay=weight_decay)

# Spectral normalization for stability
from torch.nn.utils import spectral_norm

class RegularizedMLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.network = nn.Sequential(
            spectral_norm(nn.Linear(input_dim, hidden_dim)),
            nn.ReLU(),
            spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
            nn.ReLU(),
            spectral_norm(nn.Linear(hidden_dim, 1))
        )

    def forward(self, x):
        return self.network(x).squeeze(-1)

Tips for Successful Training

  1. Start Simple: Begin with a simple energy function and dataset, then increase complexity
  2. Monitor Energy Values: Watch for energy collapse (very negative values), which indicates instability
  3. Adjust Sampling Parameters: Tune MCMC step size and noise scale for effective exploration
  4. Use Persistent CD: For complex distributions, persistent CD often yields better results
  5. Visualize Frequently: Regularly check the energy landscape and samples to track progress
  6. Gradient Clipping: Use gradient clipping to prevent exploding gradients
  7. Parameter Scheduling: Use schedulers for learning rate, step size, and noise scale
  8. Batch Normalization: Consider adding batch normalization in your energy network
  9. Ensemble Methods: Train multiple models and ensemble their predictions for better results
  10. Patience: EBM training can be challenging, so be prepared to experiment with hyperparameters
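
As an illustration of the parameter scheduling tip, the sketch below implements the standard cosine annealing formula in pure Python. The function name cosine_schedule is hypothetical, and TorchEBM's CosineScheduler may be implemented differently. The start value, end value, and number of steps are the sampler step-size settings from the CD example.

```python
import math


def cosine_schedule(step: int, start_value: float, end_value: float, n_steps: int) -> float:
    """Cosine annealing from start_value (step 0) to end_value (step >= n_steps)."""
    progress = min(max(step, 0), n_steps) / n_steps
    return end_value + 0.5 * (start_value - end_value) * (1.0 + math.cos(math.pi * progress))


# Sample the schedule every 20 steps, including one point past n_steps.
values = [cosine_schedule(t, 3e-2, 5e-3, 100) for t in range(0, 121, 20)]
print([f"{v:.4f}" for v in values])
```

The schedule changes slowly near both ends and fastest in the middle, so the sampler takes large exploratory steps early on and settles into fine-grained steps once the schedule has run its course.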