# Custom Neural Network Energy Functions
Energy-based models (EBMs) are extremely flexible, and one of their key advantages is that the energy function can be parameterized using neural networks. This guide explains how to create and use neural network-based energy functions in TorchEBM.
## Overview
Neural networks provide a powerful way to represent complex energy landscapes that can't be easily defined analytically. By using neural networks as energy functions:
- You can capture complex, high-dimensional distributions
- The energy function can be learned from data
- You gain the expressivity of modern deep learning architectures
## Basic Neural Network Energy Function
To create a neural network-based energy function in TorchEBM, subclass the `EnergyFunction` base class and implement the `forward` method:
```python
import torch
import torch.nn as nn
from torchebm.core import EnergyFunction


class NeuralNetEnergyFunction(EnergyFunction):
    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        # Define neural network architecture
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        # x has shape (batch_size, input_dim)
        # Output should have shape (batch_size,)
        return self.network(x).squeeze(-1)
```
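As a quick sanity check, you can confirm that the function returns one scalar per sample and is differentiable with respect to its inputs, which is what gradient-based samplers rely on. This is a minimal sketch using only the class defined above:

```python
# Instantiate the energy function for 2D inputs
energy_fn = NeuralNetEnergyFunction(input_dim=2)

# A batch of 16 two-dimensional points
x = torch.randn(16, 2, requires_grad=True)

# One energy value per sample: shape (16,)
energy = energy_fn(x)
assert energy.shape == (16,)

# Gradient of the energy with respect to the inputs: shape (16, 2)
(grad,) = torch.autograd.grad(energy.sum(), x)
assert grad.shape == x.shape
```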
## Design Considerations
When designing neural network energy functions, consider the following:
### Network Architecture
The choice of architecture depends on the data type and complexity:
- MLPs: Good for generic, low-dimensional data
- CNNs: Effective for images and data with spatial structure
- Transformers: Useful for sequential data or when attention mechanisms are beneficial
- Graph Neural Networks: For data with graph structure
### Output Requirements
Remember the following key points:
- The energy function should output a scalar value for each sample in the batch
- Lower energy values should correspond to higher probability density
- The neural network must be differentiable for gradient-based sampling methods to work
### Scale and Normalization
Energy values should be properly scaled to avoid numerical issues:
- Very large energy values can cause instability in sampling
- Energy functions that grow too quickly may cause sampling algorithms to fail; one mitigation is sketched below
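One simple mitigation is to constrain the network and add a weak quadratic term so the energy grows gently far from the data. The sketch below is illustrative: the spectral normalization and the `0.01` coefficient are design choices, not TorchEBM requirements.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

from torchebm.core import EnergyFunction


class BoundedEnergyFunction(EnergyFunction):
    def __init__(self, input_dim=2, hidden_dim=64, quad_coeff=0.01):
        super().__init__()
        # Spectral normalization limits each layer's Lipschitz constant,
        # which keeps the neural part of the energy from growing too quickly
        self.network = nn.Sequential(
            spectral_norm(nn.Linear(input_dim, hidden_dim)),
            nn.Softplus(),
            spectral_norm(nn.Linear(hidden_dim, hidden_dim)),
            nn.Softplus(),
            spectral_norm(nn.Linear(hidden_dim, 1)),
        )
        self.quad_coeff = quad_coeff

    def forward(self, x):
        # A weak quadratic term makes the energy confining far from the origin,
        # so samplers cannot drift off to infinity
        neural_energy = self.network(x).squeeze(-1)
        quadratic_energy = self.quad_coeff * (x ** 2).sum(dim=-1)
        return neural_energy + quadratic_energy
```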
## Example: MLP Energy Function for 2D Data
Here's a complete example with a simple MLP energy function:
```python
import torch
import torch.nn as nn
from torchebm.core import EnergyFunction
from torchebm.samplers.langevin_dynamics import LangevinDynamics
import matplotlib.pyplot as plt
import numpy as np


class MLPEnergyFunction(EnergyFunction):
    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()
        # Define neural network architecture
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1)
        )
        # Initialize with small weights
        for m in self.network.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight, gain=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Ensure x is batched
        if x.ndim == 1:
            x = x.unsqueeze(0)
        # Forward pass through network
        return self.network(x).squeeze(-1)


# Create the energy function
energy_fn = MLPEnergyFunction()


# Define the target energy landscape we want the network to learn:
# a "four peaks" mixture with a weak quadratic confinement term
def target_energy(x, y):
    return -2.0 * torch.exp(-0.2 * ((x - 2)**2 + (y - 2)**2)) \
           - 3.0 * torch.exp(-0.2 * ((x + 2)**2 + (y - 2)**2)) \
           - 1.0 * torch.exp(-0.3 * ((x - 2)**2 + (y + 2)**2)) \
           - 4.0 * torch.exp(-0.2 * ((x + 2)**2 + (y + 2)**2)) \
           + 0.1 * (x**2 + y**2)


# Generate training data from the target distribution
def generate_training_data(n_samples=10000):
    # Evaluate the target energy on a uniform grid
    x = torch.linspace(-4, 4, 100)
    y = torch.linspace(-4, 4, 100)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    positions = torch.stack([X.flatten(), Y.flatten()], dim=1)

    # Calculate target energy
    energies = target_energy(positions[:, 0], positions[:, 1])

    # Convert energies to probabilities (unnormalized)
    probs = torch.exp(-energies)

    # Normalize to create a distribution
    probs = probs / probs.sum()

    # Sample grid points according to their probability
    indices = torch.multinomial(probs, n_samples, replacement=True)

    # Return sampled positions
    return positions[indices]


# Generate training data
train_data = generate_training_data(10000)

# Set up optimizer
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=0.001)

# Training loop
n_epochs = 1000
batch_size = 128

for epoch in range(n_epochs):
    # Random noise samples serve as negatives (a simple stand-in for model samples)
    noise_samples = torch.randn_like(train_data)

    # Shuffle data
    indices = torch.randperm(train_data.shape[0])

    # Mini-batch training
    for start_idx in range(0, train_data.shape[0], batch_size):
        end_idx = min(start_idx + batch_size, train_data.shape[0])
        batch_indices = indices[start_idx:end_idx]

        data_batch = train_data[batch_indices]
        noise_batch = noise_samples[batch_indices]

        # Zero gradients
        optimizer.zero_grad()

        # Calculate energy for data and noise samples
        data_energy = energy_fn(data_batch)
        noise_energy = energy_fn(noise_batch)

        # Contrastive loss: push data energy down, noise energy up
        loss = data_energy.mean() - noise_energy.mean()

        # Backpropagation
        loss.backward()
        optimizer.step()

    # Print progress
    if (epoch + 1) % 100 == 0:
        print(f'Epoch {epoch+1}/{n_epochs}, Loss: {loss.item():.4f}')


# Visualize learned energy function
def visualize_energy_function(energy_fn, title="Learned Energy Function"):
    x = torch.linspace(-4, 4, 100)
    y = torch.linspace(-4, 4, 100)
    X, Y = torch.meshgrid(x, y, indexing='ij')
    positions = torch.stack([X.flatten(), Y.flatten()], dim=1)

    # Calculate energies
    with torch.no_grad():
        energies = energy_fn(positions).reshape(100, 100)

    # Plot
    plt.figure(figsize=(10, 8))
    plt.contourf(X.numpy(), Y.numpy(), energies.numpy(), 50, cmap='viridis')
    plt.colorbar(label='Energy')
    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.tight_layout()
    plt.show()


# Visualize the learned energy function
visualize_energy_function(energy_fn)

# Sample from the learned energy function using Langevin dynamics
sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01
)

samples = sampler.sample_chain(
    dim=2,
    n_steps=1000,
    n_samples=2000,
    burn_in=200
)

# Visualize samples
plt.figure(figsize=(10, 8))
plt.scatter(samples[:, 0].numpy(), samples[:, 1].numpy(), s=1, alpha=0.5)
plt.xlim(-4, 4)
plt.ylim(-4, 4)
plt.title('Samples from Learned Energy Function')
plt.xlabel('x')
plt.ylabel('y')
plt.tight_layout()
plt.show()
```
## Example: Convolutional Energy Function for Images
For image data, convolutional architectures are more appropriate:
```python
import torch
import torch.nn as nn
from torchebm.core import EnergyFunction


class ConvolutionalEnergyFunction(EnergyFunction):
    def __init__(self, channels=1, width=28, height=28):
        super().__init__()
        # Convolutional feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),    # 28x28 -> 14x14
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),    # 14x14 -> 7x7
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),  # 7x7 -> 4x4
            nn.LeakyReLU(0.2),
        )

        # Size of the flattened features: three stride-2 convolutions reduce each
        # spatial dimension by a factor of 8, rounding up (28 -> 14 -> 7 -> 4)
        feature_size = 128 * ((width + 7) // 8) * ((height + 7) // 8)

        # Final energy output
        self.energy_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feature_size, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # Ensure x is batched and has a channel dimension
        if x.ndim == 3:    # Single image with channels
            x = x.unsqueeze(0)
        elif x.ndim == 2:  # Single grayscale image
            x = x.unsqueeze(0).unsqueeze(0)

        # Extract features and compute energy
        features = self.feature_extractor(x)
        energy = self.energy_head(features).squeeze(-1)
        return energy
```
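As with the MLP example, the output is one energy per image. A minimal usage sketch, with random tensors standing in for real image data:

```python
# Instantiate for single-channel 28x28 images
energy_fn = ConvolutionalEnergyFunction(channels=1, width=28, height=28)

# A batch of 8 placeholder grayscale images in (batch, channels, height, width) layout
images = torch.rand(8, 1, 28, 28)

energies = energy_fn(images)
print(energies.shape)  # torch.Size([8])
```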
## Advanced Pattern: Hybrid Energy Functions
You can combine analytical energy functions with neural networks to get the best of both worlds:
```python
import torch
import torch.nn as nn
from torchebm.core import EnergyFunction, GaussianEnergy


class HybridEnergyFunction(EnergyFunction):
    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()
        # Analytical component: Gaussian energy
        self.analytical_component = GaussianEnergy(
            mean=torch.zeros(input_dim),
            cov=torch.eye(input_dim)
        )

        # Neural network component
        self.neural_component = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1)
        )

        # Learnable weight for combining the components
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        # Analytical energy
        analytical_energy = self.analytical_component(x)

        # Neural network energy
        neural_energy = self.neural_component(x).squeeze(-1)

        # Combine using the learned weight; the sigmoid keeps alpha between 0 and 1
        alpha = torch.sigmoid(self.alpha)
        combined_energy = alpha * analytical_energy + (1 - alpha) * neural_energy

        return combined_energy
```
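Because `alpha` is an ordinary parameter, it is trained along with the network, and you can inspect how the model balances the two components at any time. A small usage sketch based on the class above:

```python
energy_fn = HybridEnergyFunction(input_dim=2)

x = torch.randn(32, 2)
energy = energy_fn(x)  # shape (32,)

# Effective mixing weight after the sigmoid, always between 0 and 1
mixing_weight = torch.sigmoid(energy_fn.alpha).item()
print(f"Analytical weight: {mixing_weight:.2f}, neural weight: {1 - mixing_weight:.2f}")
```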
## Training Strategies
Training neural network energy functions requires special techniques:
### Contrastive Divergence
A common approach is contrastive divergence, which minimizes the energy of data samples while maximizing the energy of samples from the model:
```python
def train_step_contrastive_divergence(energy_fn, optimizer, data_batch, sampler, n_sampling_steps=10):
    # Zero gradients
    optimizer.zero_grad()

    # Data energy
    data_energy = energy_fn(data_batch)

    # Generate negative samples (model samples)
    with torch.no_grad():
        # Start from random noise
        model_samples = torch.randn_like(data_batch)

        # Run MCMC for a few steps
        model_samples = sampler.sample_chain(
            initial_points=model_samples,
            n_steps=n_sampling_steps,
            return_final=True
        )

    # Model energy
    model_energy = energy_fn(model_samples)

    # Loss: make data energy lower, model energy higher
    loss = data_energy.mean() - model_energy.mean()

    # Backpropagation
    loss.backward()
    optimizer.step()

    return loss.item()
```
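A possible outer loop that wires this step to a dataset and the Langevin sampler from the earlier example is sketched below; the `LangevinDynamics` and `sample_chain` arguments mirror the ones used above, and the dataset and hyperparameters are placeholders:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from torchebm.samplers.langevin_dynamics import LangevinDynamics

# Placeholder dataset of 2D points; substitute your own data
dataset = TensorDataset(torch.randn(5000, 2))
loader = DataLoader(dataset, batch_size=128, shuffle=True)

energy_fn = MLPEnergyFunction(input_dim=2)
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=1e-3)
sampler = LangevinDynamics(energy_function=energy_fn, step_size=0.01)

for epoch in range(10):
    for (data_batch,) in loader:
        loss = train_step_contrastive_divergence(
            energy_fn, optimizer, data_batch, sampler, n_sampling_steps=10
        )
    print(f"Epoch {epoch + 1}, last batch loss: {loss:.4f}")
```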
### Score Matching
Score matching is another approach that avoids the need for MCMC sampling. The full objective also involves the trace of the Hessian of the energy, which is expensive to compute; the simplified variant below uses only the gradient-norm term plus a noise-based regularizer:
```python
def score_matching_loss(energy_fn, data_batch, noise_scale=0.01):
    # Enable gradient computation with respect to the inputs
    data_batch.requires_grad_(True)

    # Compute energy
    energy = energy_fn(data_batch)

    # Compute gradients of the energy w.r.t. the inputs (the negative score)
    grad_energy = torch.autograd.grad(
        outputs=energy.sum(),
        inputs=data_batch,
        create_graph=True,
        retain_graph=True
    )[0]

    # Gradient-norm term of the score matching objective
    loss = 0.5 * (grad_energy ** 2).sum(dim=1).mean()

    # Regularization term: penalize large energy changes under small input noise
    noise_data = data_batch + noise_scale * torch.randn_like(data_batch)
    noise_energy = energy_fn(noise_data)
    reg_loss = ((noise_energy - energy) ** 2).mean()

    return loss + 0.1 * reg_loss
```
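Training with this loss is then a standard optimization loop with no sampler in the inner loop. A minimal sketch reusing `MLPEnergyFunction` from the earlier example, with a random placeholder batch in place of real data:

```python
energy_fn = MLPEnergyFunction(input_dim=2)
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=1e-3)

for step in range(1000):
    # Placeholder batch; substitute batches of real training data
    data_batch = torch.randn(128, 2)

    optimizer.zero_grad()
    loss = score_matching_loss(energy_fn, data_batch, noise_scale=0.01)
    loss.backward()
    optimizer.step()

    if (step + 1) % 100 == 0:
        print(f"Step {step + 1}, loss: {loss.item():.4f}")
```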
## Tips for Neural Network Energy Functions
- Start Simple: Begin with a simple architecture and gradually increase complexity
- Regularization: Use weight decay or spectral normalization to prevent extreme energy values
- Gradient Clipping: Apply gradient clipping during training to prevent instability (see the sketch after this list)
- Initialization: Careful initialization of weights can help convergence
- Monitoring: Track energy values during training to ensure they stay in a reasonable range
- Batch Normalization: Use with caution as it can affect the shape of the energy landscape
- Residual Connections: Can help with gradient flow in deeper networks
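The regularization and gradient clipping tips translate directly into standard PyTorch calls. The values below (weight decay `1e-4`, maximum gradient norm `1.0`) and the random placeholder batches are illustrative choices, not TorchEBM recommendations:

```python
energy_fn = MLPEnergyFunction(input_dim=2)

# Weight decay discourages extreme parameter values and, indirectly, extreme energies
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=1e-3, weight_decay=1e-4)

for step in range(1000):
    data_batch = torch.randn(128, 2)  # placeholder; substitute real data

    optimizer.zero_grad()
    data_energy = energy_fn(data_batch)
    noise_energy = energy_fn(torch.randn_like(data_batch))
    loss = data_energy.mean() - noise_energy.mean()
    loss.backward()

    # Clip the gradient norm before the update to prevent instability
    torch.nn.utils.clip_grad_norm_(energy_fn.parameters(), max_norm=1.0)
    optimizer.step()

    # Monitoring: energy values should stay in a reasonable range
    if (step + 1) % 100 == 0:
        print(f"data energy {data_energy.mean().item():.2f}, "
              f"noise energy {noise_energy.mean().item():.2f}")
```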
## Conclusion
Neural network energy functions provide a powerful way to model complex distributions in energy-based models. By leveraging the flexibility of deep learning architectures, you can create expressive energy functions that capture intricate patterns in your data.
Remember to carefully design your architecture, choose appropriate training methods, and monitor the behavior of your energy function during training and sampling.