Energy-based models (EBMs) are highly flexible, and one of their key advantages is that the energy function can be parameterized by a neural network. This guide explains how to create and use neural network-based models in TorchEBM.
Neural networks provide a powerful way to represent complex energy landscapes that cannot be easily defined analytically. By using neural networks as models:

- You can capture complex, high-dimensional distributions
- The model can be learned from data
- You gain the expressivity of modern deep learning architectures
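Concretely, a network \(E_\theta\) only needs to map each input \(x\) to a single scalar energy; that scalar implicitly defines a probability density:

$$
p_\theta(x) = \frac{e^{-E_\theta(x)}}{Z(\theta)}, \qquad Z(\theta) = \int e^{-E_\theta(x)} \, dx
$$

Low-energy regions correspond to high-probability regions. The normalizing constant \(Z(\theta)\) is intractable for general networks, which is why the example below pairs the model with an MCMC sampler (Langevin dynamics) for drawing samples and a loss (contrastive divergence) that trains the model without ever computing \(Z(\theta)\). The full workflow, from defining a model to fitting it and sampling from it, looks like this: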
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchebm.core import BaseModel, CosineScheduler
from torchebm.datasets import GaussianMixtureDataset
from torchebm.losses import ContrastiveDivergence
from torchebm.samplers import LangevinDynamics

SEED = 42
torch.manual_seed(SEED)


class MLPModel(BaseModel):
    """An MLP that maps each input point to a scalar energy."""

    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SELU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        # Squeeze the trailing dimension so energies have shape (batch,).
        return self.model(x).squeeze(-1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic 2D data from a 5-component Gaussian mixture.
dataset = GaussianMixtureDataset(
    n_samples=1000,
    n_components=5,
    std=0.1,
    radius=1.5,
    device=device,
    seed=SEED,
)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

model = MLPModel(input_dim=2, hidden_dim=64).to(device)

# Anneal the Langevin noise scale from 0.2 down to 0.01 over 50 steps.
SAMPLER_NOISE_SCALE = CosineScheduler(
    initial_value=2e-1, final_value=1e-2, total_steps=50
)
sampler = LangevinDynamics(
    model=model,
    step_size=0.01,
    device=device,
    noise_scale=SAMPLER_NOISE_SCALE,
)

loss_fn = ContrastiveDivergence(
    model=model,
    sampler=sampler,
    k_steps=10,
    persistent=False,
    device=device,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

n_epochs = 200
for epoch in range(n_epochs):
    epoch_loss = 0.0
    for batch in dataloader:
        optimizer.zero_grad()
        loss, neg_samples = loss_fn(batch)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{n_epochs}, Loss: {epoch_loss/len(dataloader):.4f}")


def generate_samples(model, n_samples=500):
    # A fresh sampler with a smaller step size for cleaner final samples.
    sampler = LangevinDynamics(model=model, step_size=0.005, device=device)
    initial_samples = torch.randn(n_samples, 2).to(device)
    with torch.no_grad():
        samples = sampler.sample(
            initial_state=initial_samples,
            dim=initial_samples.shape[-1],
            n_samples=n_samples,
            n_steps=1000,
        )
    return samples.cpu()


samples = generate_samples(model)
print(f"Generated {len(samples)} samples from the energy-based model")
```
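To sanity-check training, it helps to look at the learned energy surface alongside the generated samples. The snippet below is a minimal sketch rather than part of TorchEBM's API: it assumes the `model`, `samples`, and `device` variables from the script above and uses matplotlib as an extra dependency.

```python
import matplotlib.pyplot as plt
import torch

# Evaluate the learned energy on a regular grid covering the data support.
xs = torch.linspace(-2.5, 2.5, 200)
ys = torch.linspace(-2.5, 2.5, 200)
grid_x, grid_y = torch.meshgrid(xs, ys, indexing="ij")
grid = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=-1).to(device)

with torch.no_grad():
    energies = model(grid).reshape(200, 200).cpu()

# Low-energy (dark) regions should line up with the mixture's modes.
plt.contourf(grid_x, grid_y, energies, levels=50, cmap="viridis")
plt.colorbar(label="Energy")
plt.scatter(samples[:, 0], samples[:, 1], s=5, c="red", alpha=0.5, label="Generated samples")
plt.legend()
plt.title("Learned energy landscape")
plt.show()
```

If the samples cluster away from the low-energy basins, the sampler's step size or number of steps usually needs tuning before the model itself is blamed.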
The training loop above delegates the negative-phase sampling to `ContrastiveDivergence`. Shown in isolation, the loss construction and a single training step look like this:

```python
loss_fn = ContrastiveDivergence(
    model=model,
    sampler=sampler,
    k_steps=10,  # Number of MCMC steps
    persistent=False,  # Set to True for Persistent Contrastive Divergence
    device=device,
)


def train_step_contrastive_divergence(data_batch):
    optimizer.zero_grad()
    loss, neg_samples = loss_fn(data_batch)
    loss.backward()
    optimizer.step()
    return loss.item()
```
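As the comment above notes, setting `persistent=True` switches to Persistent Contrastive Divergence (PCD), in which the negative chains are carried over between batches rather than restarted, typically yielding better negative samples for the same `k_steps`. A minimal sketch, assuming the `model` and `sampler` from above (the `pcd_loss_fn` name is just illustrative):

```python
# Same constructor as above; only the persistent flag changes.
pcd_loss_fn = ContrastiveDivergence(
    model=model,
    sampler=sampler,
    k_steps=10,
    persistent=True,  # carry negative chains across batches (PCD)
    device=device,
)
```

PCD tends to help when `k_steps` is too small for chains to mix from scratch, at the cost of somewhat stale negatives early in training.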
Neural networks provide a powerful way to represent complex distributions in energy-based models. By leveraging the flexibility of deep learning architectures, you can create expressive models that capture intricate patterns in your data.
Remember to carefully design your architecture, choose appropriate training methods, and monitor the behavior of your model during training and sampling.
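As a concrete monitoring example, comparing the average energy the model assigns to real data against the energy of its own negative samples is a cheap diagnostic; a gap that explodes or collapses to zero often signals an unstable run. A minimal sketch, assuming `model`, `batch`, and `neg_samples` from a step of the training loop above:

```python
# Cheap diagnostic: mean energy of data vs. model (negative) samples.
with torch.no_grad():
    data_energy = model(batch).mean().item()
    sample_energy = model(neg_samples).mean().item()
print(
    f"E[data]={data_energy:.3f}  "
    f"E[samples]={sample_energy:.3f}  "
    f"gap={data_energy - sample_energy:.3f}"
)
```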