Training an energy-based model means minimizing the discrepancy between the model distribution and the data distribution. Because the model's normalizing constant is intractable, the likelihood cannot be maximized directly, so TorchEBM provides several loss functions that work around this.
Contrastive Divergence (CD) is one of the most popular methods for training energy-based models. It uses MCMC sampling to generate negative examples from the current model.
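The loss implemented by CD-style methods is a surrogate whose gradient approximates the maximum-likelihood gradient (a standard formulation, not specific to TorchEBM's internals):

$$
\nabla_\theta \mathcal{L}_{\mathrm{CD}} \approx \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\nabla_\theta E_\theta(x)\big] \;-\; \mathbb{E}_{x^- \sim p_\theta^{(k)}}\big[\nabla_\theta E_\theta(x^-)\big],
$$

where $p_\theta^{(k)}$ is the distribution reached after $k$ MCMC steps initialized at the data (or at a persistent buffer, see below). Minimizing it pushes energy down on real data and up on the model's own samples.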
```python
import torch
import torch.nn as nn
import torch.optim as optim

from torchebm.core import BaseEnergyFunction
from torchebm.losses import ContrastiveDivergence
from torchebm.samplers import LangevinDynamics

# Define a custom energy function
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.network(x).squeeze(-1)

# Create energy model, sampler, and loss function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
energy_fn = MLPEnergy(input_dim=2, hidden_dim=64).to(device)

# Set up sampler for negative samples
sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.1,
    device=device,
)

# Create Contrastive Divergence loss
loss_fn = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,
    k_steps=10,        # Number of MCMC steps
    persistent=False,  # Standard CD (non-persistent)
)

# Define optimizer
optimizer = optim.Adam(energy_fn.parameters(), lr=0.001)

# During training:
data_batch = torch.randn(128, 2).to(device)  # Your real data batch

optimizer.zero_grad()
loss, negative_samples = loss_fn(data_batch)
loss.backward()
optimizer.step()
```
Persistent Contrastive Divergence (PCD) keeps a buffer of samples between training steps, so MCMC chains continue from their previous states instead of restarting at the data:

```python
# Create Persistent Contrastive Divergence loss
loss_fn = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,
    k_steps=10,
    persistent=True,     # Enable PCD
    buffer_size=1024,    # Size of the persistent buffer
    buffer_init='rand',  # How to initialize the buffer ('rand' or 'data')
)
```
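Conceptually, PCD replaces the data-initialized chains of standard CD with a replay buffer. Below is a minimal sketch of that idea only; it is not TorchEBM's internal implementation, and the `sampler.sample(x=..., n_steps=...)` signature is assumed for illustration:

```python
import torch

def pcd_negatives(sampler, buffer, batch_size, k=10):
    # Pick random chain states from the persistent buffer
    idx = torch.randint(0, buffer.shape[0], (batch_size,))
    init = buffer[idx]
    # Continue those chains for k MCMC steps (assumed sampler.sample signature)
    samples = sampler.sample(x=init, n_steps=k).detach()
    # Write the updated states back so chains persist across iterations
    buffer[idx] = samples
    return samples

buffer = torch.rand(1024, 2)  # corresponds to buffer_init='rand', buffer_size=1024
```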
The Langevin sampler used for negative sampling can also take schedulers in place of fixed values, annealing the step size and noise scale:

```python
from torchebm.core import CosineScheduler, ExponentialDecayScheduler, LinearScheduler

# Define schedulers for step size and noise scale
step_size_scheduler = CosineScheduler(start_value=3e-2, end_value=5e-3, n_steps=100)
noise_scheduler = CosineScheduler(start_value=3e-1, end_value=1e-2, n_steps=100)

# Create sampler with schedulers
sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=step_size_scheduler,
    noise_scale=noise_scheduler,
    device=device,
)

# Create CD loss with this sampler
loss_fn = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,
    k_steps=10,
    persistent=True,
)
```
Score Matching is another approach to training EBMs that avoids MCMC sampling altogether. Instead of contrasting samples, it matches the model's score function (the gradient of the log-density with respect to the input) to that of the data.
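In terms of the energy $E_\theta$ (whose negative gradient is the model score), the standard score matching objective can be written as:

$$
\mathcal{L}_{\mathrm{SM}}(\theta) = \mathbb{E}_{x \sim p_{\mathrm{data}}}\!\left[\tfrac{1}{2}\,\big\|\nabla_x E_\theta(x)\big\|^2 \;-\; \operatorname{tr}\!\big(\nabla_x^2 E_\theta(x)\big)\right],
$$

which requires the trace of the Hessian of the energy; how that trace is computed is what the `hessian_method` option below controls.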
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchebm.core import BaseEnergyFunction
from torchebm.losses import ScoreMatching
from torchebm.datasets import GaussianMixtureDataset

# Define a custom energy function
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)  # a scalar value

# Setup model and device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
energy_fn = MLPEnergy(input_dim=2).to(device)

# Create score matching loss
sm_loss_fn = ScoreMatching(
    energy_function=energy_fn,
    hessian_method="hutchinson",  # More efficient for higher dimensions
    hutchinson_samples=5,
    device=device,
)

# Setup optimizer
optimizer = optim.Adam(energy_fn.parameters(), lr=0.001)

# Setup data
dataset = GaussianMixtureDataset(n_samples=500, n_components=4, std=0.1, seed=123).get_data()
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Training loop
for epoch in range(10):
    epoch_loss = 0.0
    for batch_data in dataloader:
        batch_data = batch_data.to(device)

        optimizer.zero_grad()
        loss = sm_loss_fn(batch_data)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/10, Loss: {avg_loss:.6f}")
```
For higher-dimensional data, the Hessian trace term can be estimated stochastically by setting `hessian_method="hutchinson"`:

```python
sm_loss_fn = ScoreMatching(
    energy_function=energy_fn,
    hessian_method="hutchinson",  # Use Hutchinson's trick
    hutchinson_samples=5,         # Number of noise samples to use
    device=device,
)
```
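Hutchinson's trick replaces the exact Hessian trace with a Monte Carlo estimate, averaging `v^T H v` over random probe vectors `v`. The sketch below is illustrative only; `hutchinson_trace` is a hypothetical helper, not a TorchEBM function:

```python
import torch

def hutchinson_trace(energy_fn, x, n_samples=5):
    """Estimate tr(Hessian of the energy) per sample with Hutchinson's estimator."""
    x = x.detach().requires_grad_(True)
    # First-order gradient of the energy w.r.t. the inputs
    grad = torch.autograd.grad(energy_fn(x).sum(), x, create_graph=True)[0]
    trace_est = torch.zeros(x.shape[0], device=x.device)
    for _ in range(n_samples):
        v = torch.randn_like(x)  # Gaussian probe vector
        # Hessian-vector product H v via a second backward pass
        hvp = torch.autograd.grad((grad * v).sum(), x, create_graph=True)[0]
        trace_est = trace_est + (hvp * v).sum(dim=-1)  # v^T H v per sample
    return trace_est / n_samples
```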
Denoising Score Matching (DSM) perturbs the data with Gaussian noise and matches the score of the perturbed distribution, which avoids Hessian computations entirely:

```python
from torchebm.losses import DenoisingScoreMatching

dsm_loss_fn = DenoisingScoreMatching(
    energy_function=energy_fn,
    sigma=0.1,  # Noise level
    device=device,
)

# During training:
optimizer.zero_grad()
loss = dsm_loss_fn(data_batch)
loss.backward()
optimizer.step()
```
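With noise level $\sigma$, the standard DSM objective takes the form below (up to constants; the library's exact weighting may differ):

$$
\mathcal{L}_{\mathrm{DSM}}(\theta) = \mathbb{E}_{x \sim p_{\mathrm{data}}}\,\mathbb{E}_{\tilde{x} \sim \mathcal{N}(x,\,\sigma^2 I)}\!\left[\tfrac{1}{2}\,\Big\|\nabla_{\tilde{x}} E_\theta(\tilde{x}) + \frac{x - \tilde{x}}{\sigma^2}\Big\|^2\right],
$$

so only first derivatives of the energy with respect to the perturbed input are needed.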
Noise Contrastive Estimation (NCE) turns density estimation into a classification problem: the model learns to distinguish data samples from samples drawn from a known noise distribution:

```python
import torch.distributions as D

from torchebm.losses import NoiseContrastiveEstimation

# Define a noise distribution
noise_dist = D.Normal(0, 1)

# Create NCE loss
nce_loss_fn = NoiseContrastiveEstimation(
    energy_function=energy_fn,
    noise_distribution=noise_dist,
    noise_samples_per_data=10,
    device=device,
)

# During training:
optimizer.zero_grad()
loss = nce_loss_fn(data_batch)
loss.backward()
optimizer.step()
```
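As a reference point, the standard NCE objective with $\nu$ noise samples per data point and noise density $p_n$ is shown below (the library's exact formulation may differ slightly):

$$
\mathcal{L}_{\mathrm{NCE}}(\theta) = -\,\mathbb{E}_{x \sim p_{\mathrm{data}}}\!\left[\log \frac{p_\theta(x)}{p_\theta(x) + \nu\, p_n(x)}\right] \;-\; \nu\,\mathbb{E}_{x' \sim p_n}\!\left[\log \frac{\nu\, p_n(x')}{p_\theta(x') + \nu\, p_n(x')}\right],
$$

where $p_\theta(x) \propto \exp(-E_\theta(x))$ and the normalizing constant is typically treated as a learnable parameter or absorbed into the energy.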
The following end-to-end example trains an MLP energy function on the two moons dataset with (Persistent) Contrastive Divergence:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from torchebm.core import BaseEnergyFunction
from torchebm.samplers import LangevinDynamics
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import TwoMoonsDataset

# Define energy function
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.network(x).squeeze(-1)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
INPUT_DIM = 2
HIDDEN_DIM = 16
BATCH_SIZE = 256
EPOCHS = 100
LEARNING_RATE = 1e-3
CD_K = 10       # MCMC steps for Contrastive Divergence
USE_PCD = True  # Use Persistent Contrastive Divergence

# Setup data
dataset = TwoMoonsDataset(n_samples=3000, noise=0.05, seed=42, device=device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Create model, sampler, and loss function
energy_model = MLPEnergy(INPUT_DIM, HIDDEN_DIM).to(device)

sampler = LangevinDynamics(
    energy_function=energy_model,
    step_size=0.1,
    device=device,
)

loss_fn = ContrastiveDivergence(
    energy_function=energy_model,
    sampler=sampler,
    k_steps=CD_K,
    persistent=USE_PCD,
    buffer_size=BATCH_SIZE,
).to(device)

# Optimizer
optimizer = optim.Adam(energy_model.parameters(), lr=LEARNING_RATE)

# Training loop
losses = []
print("Starting training...")
for epoch in range(EPOCHS):
    energy_model.train()
    epoch_loss = 0.0

    for i, data_batch in enumerate(dataloader):
        optimizer.zero_grad()

        # Calculate Contrastive Divergence loss
        loss, negative_samples = loss_fn(data_batch)

        # Backpropagate and optimize
        loss.backward()

        # Optional: Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(energy_model.parameters(), max_norm=1.0)

        optimizer.step()
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(dataloader)
    losses.append(avg_epoch_loss)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Average Loss: {avg_epoch_loss:.4f}")

# Plot the training loss
plt.figure(figsize=(10, 6))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('docs/assets/images/loss_functions/cd_training_loss.png')
plt.show()
```