Energy-based models (EBMs) are highly flexible, and one of their key advantages is that the energy function can be parameterized using neural networks. This guide explains how to create and use neural network-based energy functions in TorchEBM.
Neural networks provide a powerful way to represent complex energy landscapes that can't be easily defined analytically. By using neural networks as energy functions:
- You can capture complex, high-dimensional distributions
- The energy function can be learned from data
- You gain the expressivity of modern deep learning architectures
Here's a basic example of a neural network energy function, created by subclassing `BaseEnergyFunction` and implementing `forward` to return one scalar energy per sample:

```python
import torch
import torch.nn as nn

from torchebm.core import BaseEnergyFunction


class NeuralNetEnergyFunction(BaseEnergyFunction):
    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        # Define neural network architecture
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        # x has shape (batch_size, input_dim)
        # Output should have shape (batch_size,)
        return self.network(x).squeeze(-1)
```
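As a quick sanity check, the following sketch (assuming the class above) evaluates the energy on a random batch and computes the score, i.e. the negative gradient of the energy with respect to the input, via autograd:

```python
import torch

# Instantiate the energy function defined above for 2D inputs
energy_fn = NeuralNetEnergyFunction(input_dim=2, hidden_dim=128)

# A batch of 16 two-dimensional points
x = torch.randn(16, 2, requires_grad=True)
energy = energy_fn(x)  # shape: (16,)

# The score is the negative gradient of the energy w.r.t. the input
score = -torch.autograd.grad(energy.sum(), x)[0]  # shape: (16, 2)

print(energy.shape, score.shape)
```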
The following end-to-end example trains an MLP energy function on a 2D Gaussian mixture using Langevin dynamics and contrastive divergence, then draws samples from the trained model:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchebm.core import BaseEnergyFunction, CosineScheduler
from torchebm.samplers import LangevinDynamics
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import GaussianMixtureDataset

# Set random seeds for reproducibility
SEED = 42
torch.manual_seed(SEED)


# Create a simple MLP energy function
class MLPEnergyFunction(BaseEnergyFunction):
    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.model(x).squeeze(-1)


# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the dataset
dataset = GaussianMixtureDataset(
    n_samples=1000,
    n_components=5,  # 5 Gaussian components
    std=0.1,         # Standard deviation
    radius=1.5,      # Radius of the mixture
    device=device,
    seed=SEED,
)

# Create dataloader
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Create model
model = MLPEnergyFunction(input_dim=2, hidden_dim=64).to(device)

SAMPLER_NOISE_SCALE = CosineScheduler(
    initial_value=2e-1, final_value=1e-2, total_steps=50
)

# Create sampler
sampler = LangevinDynamics(
    energy_function=model,
    step_size=0.01,
    device=device,
    noise_scale=SAMPLER_NOISE_SCALE,
)

# Create loss function
loss_fn = ContrastiveDivergence(
    energy_function=model,
    sampler=sampler,
    k_steps=10,         # Number of MCMC steps
    persistent=False,   # Set to True for Persistent Contrastive Divergence
    device=device,
)

# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
n_epochs = 200
for epoch in range(n_epochs):
    epoch_loss = 0.0
    for batch in dataloader:
        # Zero gradients
        optimizer.zero_grad()

        # Compute loss (automatically handles positive and negative samples)
        loss, neg_samples = loss_fn(batch)

        # Backpropagation
        loss.backward()

        # Update parameters
        optimizer.step()

        epoch_loss += loss.item()

    # Print progress every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss / len(dataloader):.4f}")


# Generate samples from the trained model
def generate_samples(model, n_samples=500):
    # Create sampler
    sampler = LangevinDynamics(energy_function=model, step_size=0.005, device=device)

    # Initialize from random noise
    initial_samples = torch.randn(n_samples, 2).to(device)

    # Sample using MCMC
    with torch.no_grad():
        samples = sampler.sample(
            initial_state=initial_samples,
            dim=initial_samples.shape[-1],
            n_samples=n_samples,
            n_steps=1000,
        )

    return samples.cpu()


# Generate samples
samples = generate_samples(model)
print(f"Generated {len(samples)} samples from the energy-based model")
```
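To inspect what the model has learned, a sketch like the following (using matplotlib, which is not part of TorchEBM) evaluates the trained energy function on a 2D grid and overlays the generated samples; low-energy regions correspond to high-probability regions:

```python
import matplotlib.pyplot as plt
import torch

# Evaluate the trained energy function on a 2D grid (reuses `model`, `device`, `samples`)
xs = torch.linspace(-3, 3, 200)
ys = torch.linspace(-3, 3, 200)
grid_x, grid_y = torch.meshgrid(xs, ys, indexing="ij")
grid = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=1).to(device)

with torch.no_grad():
    energies = model(grid).reshape(200, 200).cpu()

# Plot the energy landscape with the generated samples on top
plt.contourf(grid_x.numpy(), grid_y.numpy(), energies.numpy(), levels=50, cmap="viridis")
plt.colorbar(label="Energy")
plt.scatter(samples[:, 0], samples[:, 1], s=5, c="red", alpha=0.5, label="Samples")
plt.legend()
plt.title("Learned energy landscape")
plt.show()
```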
Example: Convolutional Energy Function for Images
For image data, convolutional architectures are more appropriate:
```python
import torch
import torch.nn as nn

from torchebm.core import BaseEnergyFunction


class ConvolutionalEnergyFunction(BaseEnergyFunction):
    def __init__(self, channels=1, width=28, height=28):
        super().__init__()

        # Convolutional part
        self.conv_net = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=3, stride=1, padding=1),
            nn.SELU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),  # 14x14
            nn.SELU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.SELU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),  # 7x7
            nn.SELU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.SELU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),  # 4x4
            nn.SELU(),
        )

        # Calculate the size of the flattened features:
        # each stride-2 convolution with padding 1 halves the size (rounding up),
        # so three of them give ceil(width / 8) x ceil(height / 8)
        feature_size = 128 * ((width + 7) // 8) * ((height + 7) // 8)

        # Final energy output
        self.energy_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feature_size, 128),
            nn.SELU(),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        # Ensure x is batched and has a channel dimension
        if x.ndim == 3:  # Single image with channels
            x = x.unsqueeze(0)
        elif x.ndim == 2:  # Single grayscale image
            x = x.unsqueeze(0).unsqueeze(0)

        # Extract features and compute energy
        features = self.conv_net(x)
        energy = self.energy_head(features).squeeze(-1)
        return energy
```
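As a rough shape check (a hypothetical snippet, assuming the class above), the energy function can be applied to an MNIST-sized batch or to a single unbatched image:

```python
import torch

# Instantiate for 28x28 grayscale images
energy_fn = ConvolutionalEnergyFunction(channels=1, width=28, height=28)

# A batch of 8 random "images" with shape (batch, channels, height, width)
images = torch.randn(8, 1, 28, 28)
print(energy_fn(images).shape)  # torch.Size([8])

# A single unbatched grayscale image is also handled by forward()
single = torch.randn(28, 28)
print(energy_fn(single).shape)  # torch.Size([1])
```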
You can also mix analytical and learned components in a single energy function. The example below interpolates between a Gaussian energy and a neural network using a trainable mixing weight:

```python
import torch
import torch.nn as nn

from torchebm.core import BaseEnergyFunction, GaussianEnergy


class CompositionalEnergyFunction(BaseEnergyFunction):
    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()

        # Analytical component: Gaussian energy
        self.analytical_component = GaussianEnergy(
            mean=torch.zeros(input_dim), cov=torch.eye(input_dim)
        )

        # Neural network component
        self.neural_component = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, 1),
        )

        # Weight for combining components
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x):
        # Analytical energy
        analytical_energy = self.analytical_component(x)

        # Neural network energy
        neural_energy = self.neural_component(x).squeeze(-1)

        # Combine using the learned weight
        # Use sigmoid to keep alpha between 0 and 1
        alpha = torch.sigmoid(self.alpha)
        combined_energy = alpha * analytical_energy + (1 - alpha) * neural_energy
        return combined_energy
```
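A brief usage sketch (assuming the class above); the printed mixing weight is `sigmoid(alpha)`, which is updated jointly with the network during training:

```python
import torch

energy_fn = CompositionalEnergyFunction(input_dim=2, hidden_dim=64)

x = torch.randn(4, 2)
print(energy_fn(x))  # Combined energies, shape (4,)

# Effective weight currently placed on the analytical component
print(torch.sigmoid(energy_fn.alpha).item())
```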
Contrastive divergence is the most common way to train neural energy functions: the loss contrasts the energy of real data with the energy of negative samples drawn from the sampler. A single training step looks like this:

```python
loss_fn = ContrastiveDivergence(
    energy_function=model,
    sampler=sampler,
    k_steps=10,         # Number of MCMC steps
    persistent=False,   # Set to True for Persistent Contrastive Divergence
    device=device,
)


def train_step_contrastive_divergence(data_batch):
    # Zero gradients
    optimizer.zero_grad()

    # Compute loss (automatically handles positive and negative samples)
    loss, neg_samples = loss_fn(data_batch)

    # Backpropagation
    loss.backward()

    # Update parameters
    optimizer.step()

    return loss.item()
```
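For Persistent Contrastive Divergence, the sketch below only flips the `persistent` flag; PCD typically carries the negative-sample chains over between updates instead of re-initializing them for every batch:

```python
# Persistent Contrastive Divergence: negative chains persist across training steps
pcd_loss_fn = ContrastiveDivergence(
    energy_function=model,
    sampler=sampler,
    k_steps=10,
    persistent=True,  # Reuse negative samples from the previous update
    device=device,
)
```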
Score matching is an alternative objective that avoids MCMC sampling during training:

```python
from torchebm.losses import ScoreMatching

# Use score matching for training (reusing `model` and `optimizer` from above)
sm_loss_fn = ScoreMatching(
    energy_function=model,
    hessian_method="hutchinson",  # More efficient for higher dimensions
    hutchinson_samples=5,
    device=device,
)


def train_step_score_matching(data_batch):
    # Zero gradients
    optimizer.zero_grad()

    # Score matching needs no negative samples
    loss = sm_loss_fn(data_batch)

    # Backpropagation and parameter update
    loss.backward()
    optimizer.step()

    return loss.item()


# `data_batch` is a batch drawn from your dataloader
batch_loss = train_step_score_matching(data_batch)
```
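A minimal sketch of plugging this step into an epoch loop, reusing the `dataloader` from the earlier training example:

```python
# Train with score matching instead of contrastive divergence
n_epochs = 50
for epoch in range(n_epochs):
    epoch_loss = 0.0
    for batch in dataloader:
        epoch_loss += train_step_score_matching(batch)
    print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss / len(dataloader):.4f}")
```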
Neural network energy functions provide a powerful way to model complex distributions in energy-based models. By leveraging the flexibility of deep learning architectures, you can create expressive energy functions that capture intricate patterns in your data.
Remember to carefully design your architecture, choose appropriate training methods, and monitor the behavior of your energy function during training and sampling.