This guide walks through the complete process of training an energy-based model using TorchEBM. We will train a simple MLP-based model to learn a 2D Gaussian mixture distribution, a classic "hello world" for EBMs that is easy to visualize.
Training an EBM involves minimizing the KL divergence between the data distribution \( p_{data}(x) \) and the model distribution \( p_{\theta}(x) \propto \exp(-E_{\theta}(x)) \), where \( E_{\theta}(x) \) is the energy assigned by the network. This is equivalent to maximizing the log-likelihood of the data, whose gradient is:

\[
\nabla_{\theta}\, \mathbb{E}_{x \sim p_{data}}\big[\log p_{\theta}(x)\big]
= -\,\mathbb{E}_{x \sim p_{data}}\big[\nabla_{\theta} E_{\theta}(x)\big]
+ \mathbb{E}_{x \sim p_{\theta}}\big[\nabla_{\theta} E_{\theta}(x)\big]
\]
The first term pushes the energy down for real data ("positive samples"), and the second term pushes the energy up for data generated by the model ("negative samples").
Since we cannot sample directly from \( p_{\theta} \), we use an MCMC procedure like Langevin Dynamics to generate the negative samples. Contrastive Divergence (CD) is an algorithm that approximates this gradient by running the MCMC chain for only a few steps, initialized from the real data.
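Concretely, each Langevin Dynamics step nudges a sample downhill on the energy surface while injecting Gaussian noise. In textbook form (TorchEBM's `step_size` and `noise_scale` parameters control the corresponding scale factors, though the library's exact parameterization may differ slightly):

\[
x_{t+1} = x_{t} - \frac{\eta}{2}\,\nabla_{x} E_{\theta}(x_{t}) + \sqrt{\eta}\,\epsilon_{t},
\qquad \epsilon_{t} \sim \mathcal{N}(0, I)
\]

In CD-\(k\), the expectation over \( p_{\theta} \) in the gradient above is replaced by an average over samples obtained after \(k\) such steps, with the chains initialized at the current data batch.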
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# NOTE: the import paths below are assumed; adjust them to your TorchEBM version.
from torchebm.core import BaseModel
from torchebm.datasets import GaussianMixtureDataset
from torchebm.losses import ContrastiveDivergence
from torchebm.samplers import LangevinDynamics


# Define the energy model using a simple MLP
class MLPModel(BaseModel):
    def __init__(self, input_dim: int, hidden_dim: int = 128):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Return a scalar energy per sample
        return self.network(x).squeeze(-1)


# Set up device and dataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = GaussianMixtureDataset(
    n_samples=2048,
    n_components=8,
    std=0.1,
    radius=1.5,
    device=device,
    seed=42,
)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
```
```python
# Model
model = MLPModel(input_dim=2).to(device)

# Sampler
sampler = LangevinDynamics(
    model=model,
    step_size=0.1,
    noise_scale=0.1,
)

# Loss function
loss_fn = ContrastiveDivergence(
    model=model,
    sampler=sampler,
    n_steps=10,  # k in CD-k
)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)
```
The training loop is a standard PyTorch loop. In each step, we pass a batch of real data to `loss_fn`, which performs the following steps internally (a minimal loop sketch follows the list):

1. Calculates the energy of the real data (positive phase).
2. Initializes MCMC chains from the real data batch.
3. Runs the sampler for `n_steps` to generate negative samples.
4. Calculates the energy of the negative samples (negative phase).
5. Computes the CD loss and returns it.
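Here is a minimal sketch of such a loop, assuming `loss_fn(batch)` returns the scalar CD loss (some TorchEBM versions may also return the negative samples alongside it):

```python
n_epochs = 100
model.train()

for epoch in range(n_epochs):
    epoch_loss = 0.0
    for batch in dataloader:
        batch = batch.to(device)

        optimizer.zero_grad()
        loss = loss_fn(batch)  # positive phase, MCMC sampling, negative phase, CD loss
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}: avg loss = {epoch_loss / len(dataloader):.4f}")
```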
Throughout training, it's crucial to visualize the learned energy landscape and the samples generated by the model. This helps diagnose issues and understand how the model is learning.
```python
import matplotlib.pyplot as plt
import numpy as np


# Helper function to plot the energy landscape and samples
@torch.no_grad()
def visualize_training(model, real_data, sampler, epoch):
    plt.figure(figsize=(8, 8))

    # Create a grid over which to evaluate the energy landscape
    plot_range = 2.5
    grid_size = 100
    x_coords = torch.linspace(-plot_range, plot_range, grid_size, device=device)
    y_coords = torch.linspace(-plot_range, plot_range, grid_size, device=device)
    xv, yv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    grid = torch.stack([xv.flatten(), yv.flatten()], dim=1)

    # Get energy values and convert to an (unnormalized) probability density for visualization
    energy_values = model(grid).cpu().numpy().reshape(grid_size, grid_size)
    prob_density = np.exp(-energy_values)

    # Plot the landscape
    plt.contourf(xv.cpu().numpy(), yv.cpu().numpy(), prob_density, levels=50, cmap="viridis")

    # Generate model samples for visualization
    initial_noise = torch.randn(500, 2, device=device)
    model_samples = sampler.sample(x=initial_noise, n_steps=200).cpu().numpy()

    # Plot real and model samples
    plt.scatter(real_data[:, 0], real_data[:, 1], s=10, alpha=0.5, label="Real Data", c="white")
    plt.scatter(model_samples[:, 0], model_samples[:, 1], s=10, alpha=0.5, label="Model Samples", c="red")
    plt.title(f"Epoch {epoch}")
    plt.legend()
    plt.show()


# Visualize after training
model.eval()
visualize_training(model, dataset.get_data().cpu().numpy(), sampler, 100)
```
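To watch the landscape evolve during training (as in the figures below), you can also call the helper periodically from the loop sketched earlier; the cadence here is just an illustrative choice:

```python
# Inside the epoch loop, e.g. every 20 epochs (illustrative cadence)
if (epoch + 1) % 20 == 0:
    model.eval()
    visualize_training(model, dataset.get_data().cpu().numpy(), sampler, epoch + 1)
    model.train()
```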
The visualization on the left shows the model early in training, where the energy landscape is still diffuse. On the right, after 100 epochs, the model has learned to assign low energy (high probability, bright regions) to the areas where the data lives, and the model samples (red dots) closely match the real data distribution (white dots).