The `torchebm` library provides a variety of 2D synthetic datasets through the `torchebm.datasets` module. These datasets are implemented as PyTorch `Dataset` subclasses for easy integration with `DataLoader`s. This walkthrough explores each dataset class with examples and visualizations.
```python
import torch
import numpy as np
import matplotlib.pyplot as plt

from torchebm.datasets import (
    GaussianMixtureDataset,
    EightGaussiansDataset,
    TwoMoonsDataset,
    SwissRollDataset,
    CircleDataset,
    CheckerboardDataset,
    PinwheelDataset,
    GridDataset,
)

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Helper function to visualize a dataset
def visualize_dataset(data, title, figsize=(5, 5)):
    plt.figure(figsize=figsize)
    plt.scatter(data[:, 0], data[:, 1], s=5, alpha=0.6)
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.axis('equal')
    plt.tight_layout()
    plt.show()
```
```python
# Generate 1000 samples from the 8 Gaussians distribution
eight_gauss_dataset = EightGaussiansDataset(n_samples=1000, std=0.02, scale=2.0, seed=42)
eight_gauss_data = eight_gauss_dataset.get_data()
visualize_dataset(eight_gauss_data, "Eight Gaussians")
```
```python
# Generate 1000 samples from the Two Moons distribution
moons_dataset = TwoMoonsDataset(n_samples=1000, noise=0.05, seed=42)
moons_data = moons_dataset.get_data()
visualize_dataset(moons_data, "Two Moons")
```
```python
# Generate 1000 samples from the Swiss Roll distribution
swiss_roll_dataset = SwissRollDataset(n_samples=1000, noise=0.05, arclength=3.0, seed=42)
swiss_roll_data = swiss_roll_dataset.get_data()
visualize_dataset(swiss_roll_data, "Swiss Roll")
```
```python
# Generate 1000 samples from a Circle distribution
circle_dataset = CircleDataset(n_samples=1000, noise=0.05, radius=1.0, seed=42)
circle_data = circle_dataset.get_data()
visualize_dataset(circle_data, "Circle")
```
```python
# Generate 1000 samples from a Checkerboard distribution
checkerboard_dataset = CheckerboardDataset(n_samples=1000, range_limit=4.0, noise=0.01, seed=42)
checkerboard_data = checkerboard_dataset.get_data()
visualize_dataset(checkerboard_data, "Checkerboard")
```
```python
# Generate 1000 samples from a Pinwheel distribution with 5 blades
pinwheel_dataset = PinwheelDataset(
    n_samples=1000,
    n_classes=5,
    noise=0.05,
    radial_scale=2.0,
    angular_scale=0.1,
    spiral_scale=5.0,
    seed=42,
)
pinwheel_data = pinwheel_dataset.get_data()
visualize_dataset(pinwheel_data, "Pinwheel (5 blades)")
```
```python
# Generate a 20x20 grid of points
grid_dataset = GridDataset(n_samples_per_dim=20, range_limit=1.0, noise=0.01, seed=42)
grid_data = grid_dataset.get_data()
visualize_dataset(grid_data, "2D Grid (20x20)")
```
```python
from torch.utils.data import DataLoader

# Create a dataset
dataset = GaussianMixtureDataset(n_samples=2000, n_components=8, std=0.1, seed=42)

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

# Iterate through batches
for batch in dataloader:
    # Each batch is a tensor of shape [batch_size, 2]
    print(f"Batch shape: {batch.shape}")
    # Process the batch...
    break  # Just showing the first batch
```
```python
# Create a figure with multiple datasets
plt.figure(figsize=(15, 10))

# Generate datasets
datasets = [
    (GaussianMixtureDataset(1000, 8, 0.05, seed=42).get_data(), "Gaussian Mixture"),
    (TwoMoonsDataset(1000, 0.05, seed=42).get_data(), "Two Moons"),
    (SwissRollDataset(1000, 0.05, seed=42).get_data(), "Swiss Roll"),
    (CircleDataset(1000, 0.05, seed=42).get_data(), "Circle"),
    (CheckerboardDataset(1000, 4.0, 0.01, seed=42).get_data(), "Checkerboard"),
    (PinwheelDataset(1000, 5, 0.05, seed=42).get_data(), "Pinwheel"),
]

# Plot each dataset
for i, (data, title) in enumerate(datasets):
    plt.subplot(2, 3, i + 1)
    plt.scatter(data[:, 0], data[:, 1], s=3, alpha=0.6)
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.axis('equal')

plt.tight_layout()
plt.show()
```
```python
# Generate data on GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpu_dataset = GaussianMixtureDataset(1000, 4, 0.1, device=device, seed=42)
gpu_data = gpu_dataset.get_data()
print(f"Data is on: {gpu_data.device}")
```
```python
# Imports
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchebm.core import BaseEnergyFunction
from torchebm.samplers import LangevinDynamics
from torchebm.losses import ContrastiveDivergence

# Define an energy function
class MLPEnergy(BaseEnergyFunction):
    def __init__(self, input_dim=2, hidden_dim=64):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        return self.network(x).squeeze(-1)

# Setup training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create dataset directly with device specification
dataset = TwoMoonsDataset(n_samples=3000, noise=0.05, seed=42, device=device)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, drop_last=True)

# Model components
energy_model = MLPEnergy(input_dim=2, hidden_dim=16).to(device)
sampler = LangevinDynamics(
    energy_function=energy_model, step_size=0.1, noise_scale=0.1, device=device
)
loss_fn = ContrastiveDivergence(
    energy_function=energy_model, sampler=sampler, n_steps=10
).to(device)

# Optimizer
optimizer = optim.Adam(energy_model.parameters(), lr=1e-3)

# Training loop (simplified)
for epoch in range(5):  # Just a few epochs for demonstration
    for data_batch in dataloader:
        optimizer.zero_grad()
        loss, _ = loss_fn(data_batch)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
```
This minimal training pipeline combines:

- **Dataset**: `TwoMoonsDataset` placed directly on the target device
- **Energy Function**: a simple MLP implementing `BaseEnergyFunction`
- **Sampler**: `LangevinDynamics` for generating negative samples
- **Loss**: `ContrastiveDivergence` for EBM training
- **Training Loop**: the standard PyTorch pattern with a `DataLoader`
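For context (this is standard EBM background, not torchebm-specific): contrastive divergence approximates the maximum-likelihood gradient by contrasting data samples $x^+$ with "negative" samples $x^-$ drawn by the sampler (here, the 10 Langevin steps set by `n_steps=10`):

$$
\nabla_\theta \mathcal{L} \approx \mathbb{E}_{x^+ \sim p_{\text{data}}}\big[\nabla_\theta E_\theta(x^+)\big] \;-\; \mathbb{E}_{x^- \sim q_\theta}\big[\nabla_\theta E_\theta(x^-)\big],
$$

so each optimizer step pushes the energy down on real data and up on the sampler's output, gradually shaping the energy landscape around the data distribution.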
In summary, the `torchebm.datasets` module offers:

- **Dataset Variety**: 8 distinct 2D distributions for different testing scenarios
- **PyTorch Integration**: built as `torch.utils.data.Dataset` subclasses
- **Device Support**: create datasets directly on CPU or GPU
- **Configurability**: extensive parameterization for all distributions
- **Reproducibility**: seed support for deterministic generation
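To illustrate the pattern these classes follow, here is a stand-alone sketch using only the Python standard library. `ToyRingDataset` is a hypothetical class written for this walkthrough, not part of torchebm (the real classes subclass `torch.utils.data.Dataset` and return tensors), but it shows the same map-style interface and seed-based reproducibility:

```python
import math
import random

class ToyRingDataset:
    """Minimal map-style dataset: seeded 2D points clustered on a ring.

    Illustrative only -- torchebm's classes subclass
    torch.utils.data.Dataset and return torch tensors; this sketch
    mimics the same interface with plain Python tuples.
    """

    def __init__(self, n_samples=1000, n_modes=8, std=0.05, radius=2.0, seed=None):
        rng = random.Random(seed)  # private RNG: same seed -> same data
        self.points = []
        for _ in range(n_samples):
            k = rng.randrange(n_modes)           # pick one of the modes
            angle = 2 * math.pi * k / n_modes
            self.points.append((
                radius * math.cos(angle) + rng.gauss(0, std),
                radius * math.sin(angle) + rng.gauss(0, std),
            ))

    def __len__(self):
        return len(self.points)

    def __getitem__(self, idx):
        return self.points[idx]

# Same seed -> identical samples, mirroring the seed= argument above
a = ToyRingDataset(n_samples=100, seed=42)
b = ToyRingDataset(n_samples=100, seed=42)
print(len(a), a.points == b.points)  # -> 100 True
```

Because the class implements `__len__` and `__getitem__`, it could be wrapped in a `DataLoader` exactly like the torchebm datasets above.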
These dataset classes provide diverse 2D distributions for testing energy-based models. Each distribution has different characteristics that can challenge different aspects of model learning:
| Dataset | Testing Focus |
|---|---|
| Gaussian Mixtures | Mode-seeking behavior |
| Two Moons | Non-linear decision boundaries |
| Swiss Roll & Circle | Manifold learning capabilities |
| Checkerboard | Multiple modes in regular patterns |
| Pinwheel | Complex spiral structure with varying density |
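To make "mode-seeking behavior" concrete: one simple diagnostic (a hypothetical helper sketched here with the standard library, not a torchebm API) is to count how many of the eight Gaussian means receive at least one nearby model sample. A model that collapses onto a subset of modes scores below 8:

```python
import math
import random

def mode_coverage(samples, modes, threshold=0.3):
    """Count modes with at least one sample within `threshold` of the mean."""
    covered = set()
    for x, y in samples:
        for i, (mx, my) in enumerate(modes):
            if math.hypot(x - mx, y - my) < threshold:
                covered.add(i)
    return len(covered)

# Eight means evenly spaced on a circle of radius 2 (matching the
# scale=2.0 used for EightGaussiansDataset above)
modes = [
    (2.0 * math.cos(2 * math.pi * k / 8), 2.0 * math.sin(2 * math.pi * k / 8))
    for k in range(8)
]

# Toy "model samples": noisy draws around every mode -> full coverage
rng = random.Random(0)
samples = [
    (mx + rng.gauss(0, 0.05), my + rng.gauss(0, 0.05))
    for mx, my in modes
    for _ in range(20)
]
print(mode_coverage(samples, modes))  # 8 when every mode is covered
```

The same idea transfers to tensors: compute pairwise distances between generated samples and the known means, then count means whose minimum distance falls under the threshold.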
The class-based implementation provides seamless integration with PyTorch's DataLoader system, making it easy to incorporate these datasets into your training pipeline.