This guide covers the fundamental techniques for training energy-based models (EBMs) using TorchEBM. We'll explore various training methods, loss functions, and optimization strategies to help you effectively train your models.
Training energy-based models involves estimating the parameters of a model such that the corresponding probability distribution matches a target data distribution. Unlike in traditional supervised learning, this is often an unsupervised task where the goal is to learn the underlying structure of the data.
The training process typically involves:
Defining a model (parameterized by a neural network or analytical form)
Choosing a training method and loss function
Optimizing the model parameters
Evaluating the model using sampling and visualization techniques
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import os

from torchebm.core import BaseModel, CosineScheduler
from torchebm.samplers import LangevinDynamics
from torchebm.losses import ContrastiveDivergence
from torchebm.datasets import TwoMoonsDataset

# Seed everything for reproducible runs.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

os.makedirs("training_plots", exist_ok=True)

# --- Hyperparameters ---
INPUT_DIM = 2
HIDDEN_DIM = 16
BATCH_SIZE = 256
EPOCHS = 200
LEARNING_RATE = 1e-3
# Cosine schedules anneal the Langevin step size and noise over 100 steps,
# starting with coarse exploration and ending with fine-grained refinement.
SAMPLER_STEP_SIZE = CosineScheduler(start_value=3e-2, end_value=5e-3, n_steps=100)
SAMPLER_NOISE_SCALE = CosineScheduler(start_value=3e-1, end_value=1e-2, n_steps=100)
CD_K = 10  # number of MCMC steps per CD update
USE_PCD = True  # persistent CD: reuse chain state across batches via a buffer
VISUALIZE_EVERY = 20  # epochs between diagnostic plots

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Two-moons toy data; keep a copy of the raw points for plotting.
dataset = TwoMoonsDataset(n_samples=3000, noise=0.05, seed=42, device=device)
real_data_for_plotting = dataset.get_data()
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,  # PCD buffer size matches BATCH_SIZE, so keep batches full
)

# NOTE(review): MLPModel and plot_energy_and_samples are assumed to be defined
# earlier in this guide — confirm the snippet is run after those definitions.
model = MLPModel(INPUT_DIM, HIDDEN_DIM).to(device)

# Langevin dynamics generates the negative samples for contrastive divergence.
sampler = LangevinDynamics(
    model=model,
    step_size=SAMPLER_STEP_SIZE,
    noise_scale=SAMPLER_NOISE_SCALE,
    device=device,
)

loss_fn = ContrastiveDivergence(
    model=model,
    sampler=sampler,
    k_steps=CD_K,
    persistent=USE_PCD,
    buffer_size=BATCH_SIZE,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

losses = []
print("Starting training...")
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    for i, data_batch in enumerate(dataloader):
        optimizer.zero_grad()
        # The loss returns both the CD loss and the negative (model) samples.
        loss, negative_samples = loss_fn(data_batch)
        loss.backward()
        # Gradient clipping stabilizes EBM training, which can spike sharply.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        epoch_loss += loss.item()

    avg_epoch_loss = epoch_loss / len(dataloader)
    losses.append(avg_epoch_loss)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Average Loss: {avg_epoch_loss:.4f}")

    # Periodically visualize the learned energy landscape and model samples.
    if (epoch + 1) % VISUALIZE_EVERY == 0 or epoch == 0:
        print("Generating visualization...")  # FIX: string was split across lines
        plot_energy_and_samples(
            model=model,
            real_samples=real_data_for_plotting,
            sampler=sampler,
            epoch=epoch + 1,
            device=device,
            plot_range=2.5,
            k_sampling=200,
        )

# Plot the training loss
# NOTE(review): the figure is saved under docs/assets/... while the directory
# created above is "training_plots" — verify the intended output location.
plt.figure(figsize=(10, 6))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True, alpha=0.3)
plt.savefig('docs/assets/images/training/cd_training_loss.png')
plt.show()
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt  # FIX: plt was used below but never imported
from torch.utils.data import DataLoader

from torchebm.core import BaseModel
from torchebm.losses import ScoreMatching
from torchebm.datasets import GaussianMixtureDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): MLPModel is assumed to be defined earlier in this guide.
model = MLPModel(input_dim=2).to(device)

# Score matching avoids MCMC sampling entirely; the Hutchinson estimator
# approximates the Hessian trace with 5 random probe vectors per batch.
sm_loss_fn = ScoreMatching(
    model=model,
    hessian_method="hutchinson",
    hutchinson_samples=5,
    device=device,
)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# A 4-component Gaussian mixture as the target distribution.
dataset = GaussianMixtureDataset(
    n_samples=500, n_components=4, std=0.1, seed=123
).get_data()
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

losses = []
for epoch in range(50):
    epoch_loss = 0.0
    for batch_data in dataloader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        loss = sm_loss_fn(batch_data)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/50, Loss: {avg_loss:.6f}")

# Plot the training loss
plt.figure(figsize=(10, 6))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Score Matching Training Loss')
plt.grid(True, alpha=0.3)
plt.savefig('docs/assets/images/training/sm_training_loss.png')
plt.show()