TorchEBM is built around several core components that form the foundation of the library. This guide provides in-depth information about these components and how they interact.
class BaseEnergyFunction(nn.Module):
    """Base class for all energy functions.

    An energy function maps points in the sample space to scalar energy
    values. Lower energy corresponds to higher probability density.
    Subclasses are expected to override :meth:`forward`.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute energy for input points.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def score(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the gradient of the energy with respect to the input.

        The value returned is ``dE/dx`` (the gradient of the summed batch
        energy), i.e. the quantity a gradient-descent style sampler subtracts
        from ``x``. No sign flip is applied.

        Side effect: ``requires_grad`` is switched on for ``x`` in place, so
        the caller's tensor requires grad after this call.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor with the same shape as ``x`` containing the gradient. It is
            produced with ``create_graph=True`` and can itself be
            differentiated.
        """
        # Autograd is re-enabled locally so the gradient can still be taken
        # when the caller runs inside ``torch.no_grad()``; without this the
        # energy would carry no graph and ``autograd.grad`` would raise.
        with torch.enable_grad():
            x = x.requires_grad_(True)
            energy = self.forward(x)
            # Summing over the batch gives a scalar whose gradient w.r.t. x
            # stacks the per-sample gradients.
            return torch.autograd.grad(energy.sum(), x, create_graph=True)[0]
class GaussianEnergy(BaseEnergyFunction):
    """Quadratic energy of a multivariate Gaussian.

    E(x) = 0.5 * (x - mean)^T * precision * (x - mean)

    where ``precision`` is the inverse of the covariance matrix.
    """

    def __init__(self, mean: torch.Tensor, cov: torch.Tensor):
        """Store the Gaussian parameters as module buffers.

        Args:
            mean: Mean vector of shape (dim,)
            cov: Covariance matrix of shape (dim, dim)
        """
        super().__init__()
        # Buffers (not parameters): they follow the module across devices and
        # appear in the state dict, but are not trained.
        self.register_buffer("mean", mean)
        self.register_buffer("cov", cov)
        inverse_cov = torch.inverse(cov)
        self.register_buffer("precision", inverse_cov)
        self._dim = mean.size(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the Gaussian energy for a batch of points.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values
        """
        offset = x - self.mean
        # Row i of `weighted` is precision @ offset[i].
        weighted = torch.matmul(self.precision, offset.T).T
        quadratic_form = torch.sum(offset * weighted, dim=1)
        return 0.5 * quadratic_form
class DoubleWellEnergy(BaseEnergyFunction):
    """Double well energy function.

    Per-coordinate double-well potential summed over dimensions:
    E(x) = a * sum_i (x_i^2 - b)^2
    """

    def __init__(self, a: float = 1.0, b: float = 2.0):
        """Store the shape parameters of the potential.

        Args:
            a: Scale parameter
            b: Parameter controlling the distance between wells
        """
        super().__init__()
        self.a, self.b = a, b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the double well energy for a batch of points.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values
        """
        well_offset = x ** 2 - self.b
        per_coordinate = well_offset ** 2
        return self.a * torch.sum(per_coordinate, dim=1)
class CompositeEnergy(BaseEnergyFunction):
    """Composite energy function.

    Combines multiple energy functions through a weighted sum.
    """

    def __init__(
        self,
        energy_functions: List[BaseEnergyFunction],
        weights: Optional[List[float]] = None,
    ):
        """Initialize composite energy function.

        Args:
            energy_functions: List of energy functions to combine
            weights: Optional weights, one per energy function. Defaults to
                1.0 for every function.

        Raises:
            ValueError: If ``weights`` is given and its length differs from
                the number of energy functions.
        """
        super().__init__()
        self.energy_functions = nn.ModuleList(energy_functions)
        n_terms = len(self.energy_functions)
        if weights is None:
            weights = [1.0] * n_terms
        elif len(weights) != n_terms:
            # zip() in forward() would otherwise silently drop the surplus
            # functions or weights.
            raise ValueError(
                f"Expected {n_terms} weights (one per energy function), "
                f"got {len(weights)}"
            )
        # Copy so later mutation of the caller's list cannot change this model.
        self.weights = list(weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute composite energy.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values. With no
            component functions this is a tensor of zeros.
        """
        terms = [w * f(x) for w, f in zip(self.weights, self.energy_functions)]
        if not terms:
            # sum([]) would yield the int 0 rather than a tensor.
            return x.new_zeros(x.shape[0])
        return sum(terms)
class Sampler(ABC):
    """Abstract interface shared by all samplers.

    A sampler draws samples from the distribution defined by an energy
    function.
    """

    def __init__(self, energy_function: BaseEnergyFunction):
        """Remember the energy function to sample from.

        Args:
            energy_function: Energy function to sample from
        """
        self.energy_function = energy_function

    @abstractmethod
    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """Draw samples from the energy-based distribution.

        Args:
            n_samples: Number of samples to generate
            **kwargs: Additional sampler-specific parameters

        Returns:
            Tensor of shape (n_samples, dim) containing samples
        """
        ...

    @abstractmethod
    def sample_chain(
        self, dim: int, n_steps: int, n_samples: int = 1, **kwargs
    ) -> torch.Tensor:
        """Draw samples by running a Markov chain.

        Args:
            dim: Dimensionality of samples
            n_steps: Number of steps in the chain
            n_samples: Number of parallel chains to run
            **kwargs: Additional sampler-specific parameters

        Returns:
            Tensor of shape (n_samples, dim) containing final samples
        """
        ...
class LangevinDynamics(Sampler):
    """Langevin dynamics sampler.

    Uses Langevin dynamics to sample from an energy-based distribution.
    """

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        step_size: float = 0.01,
        noise_scale: float = 1.0,
    ):
        """Initialize Langevin dynamics sampler.

        Args:
            energy_function: Energy function to sample from
            step_size: Step size for updates
            noise_scale: Scale of noise added at each step
        """
        super().__init__(energy_function)
        self.step_size = step_size
        self.noise_scale = noise_scale

    def sample_step(self, x: torch.Tensor) -> torch.Tensor:
        """Perform one step of Langevin dynamics.

        The update is ``x - step_size * grad_E(x) + noise`` where the noise is
        standard normal scaled by ``sqrt(2 * step_size * noise_scale)``.

        Args:
            x: Current samples of shape (n_samples, dim)

        Returns:
            Updated samples of shape (n_samples, dim), detached from the
            autograd graph.
        """
        # Cut the state loose from any previous history. The energy function's
        # score() differentiates w.r.t. x, so without this every step would
        # extend one ever-growing autograd graph across the whole chain.
        x = x.detach()

        # Compute score (gradient of energy)
        score = self.energy_function.score(x)

        # Update samples
        noise = torch.randn_like(x) * np.sqrt(2 * self.step_size * self.noise_scale)
        x_new = x - self.step_size * score + noise
        return x_new.detach()

    def sample_chain(
        self,
        dim: int,
        n_steps: int,
        n_samples: int = 1,
        initial_samples: Optional[torch.Tensor] = None,
        return_trajectory: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples using a Langevin dynamics chain.

        Args:
            dim: Dimensionality of samples (used when ``initial_samples`` is
                not given)
            n_steps: Number of steps in the chain
            n_samples: Number of parallel chains to run (used when
                ``initial_samples`` is not given)
            initial_samples: Optional initial samples; a copy is used as the
                starting state
            return_trajectory: Whether to return the full trajectory

        Returns:
            Tensor of shape (n_samples, dim) containing final samples, or a
            tuple of (samples, trajectory) if return_trajectory is True. The
            trajectory has one leading entry per chain state, i.e.
            ``n_steps + 1`` entries including the initial state.
        """
        # Initialize samples
        if initial_samples is None:
            x = torch.randn(n_samples, dim)
        else:
            x = initial_samples.clone()

        # Initialize trajectory if needed. It is sized from the actual chain
        # state so it matches the shape, dtype and device of initial_samples
        # instead of assuming (n_samples, dim) default-dtype CPU tensors.
        trajectory = None
        if return_trajectory:
            trajectory = x.new_zeros((n_steps + 1, *x.shape))
            trajectory[0] = x.detach()

        # Run chain
        for i in range(n_steps):
            x = self.sample_step(x)
            if trajectory is not None:
                trajectory[i + 1] = x

        if trajectory is not None:
            return x, trajectory
        return x

    def sample(
        self, n_samples: int, dim: int, n_steps: int = 100, **kwargs
    ) -> torch.Tensor:
        """Generate samples from the energy-based distribution.

        Args:
            n_samples: Number of samples to generate
            dim: Dimensionality of samples
            n_steps: Number of steps in the chain
            **kwargs: Additional parameters passed to sample_chain

        Returns:
            Whatever :meth:`sample_chain` returns for these arguments; by
            default a tensor of shape (n_samples, dim) containing samples.
        """
        return self.sample_chain(
            dim=dim, n_steps=n_steps, n_samples=n_samples, **kwargs
        )
class BaseLoss(ABC):
    """Abstract interface shared by all loss functions.

    A loss is a callable that scores an energy-based model against data.
    """

    @abstractmethod
    def __call__(
        self, model: nn.Module, data_samples: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Evaluate the loss for ``model``.

        Args:
            model: Energy-based model
            data_samples: Samples from the target distribution
            **kwargs: Additional loss-specific parameters

        Returns:
            Scalar loss value
        """
        ...
class ContrastiveDivergence(BaseLoss):
    """Contrastive divergence loss.

    Uses contrastive divergence to train energy-based models: the loss is the
    mean energy of the data minus the mean energy of samples drawn from the
    model.
    """

    def __init__(
        self, sampler: Sampler, k: int = 1, batch_size: Optional[int] = None
    ):
        """Initialize contrastive divergence loss.

        Args:
            sampler: Sampler to generate model samples
            k: Number of sampling steps (CD-k)
            batch_size: Optional number of model samples to draw; defaults to
                the size of the data batch
        """
        super().__init__()
        self.sampler = sampler
        self.k = k
        self.batch_size = batch_size

    def __call__(
        self, model: nn.Module, data_samples: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute contrastive divergence loss.

        Side effect: the sampler's ``energy_function`` attribute is replaced
        by ``model`` and stays that way after the call.

        Args:
            model: Energy-based model
            data_samples: Samples from the target distribution, indexed as
                (batch, dim)
            **kwargs: Additional parameters passed to the sampler's
                ``sample`` method

        Returns:
            Scalar loss value
        """
        # Get data statistics
        batch_size = self.batch_size or data_samples.size(0)
        dim = data_samples.size(1)

        # Set the model as the sampler's energy function
        self.sampler.energy_function = model

        # Generate model samples
        model_samples = self.sampler.sample(
            dim=dim, n_steps=self.k, n_samples=batch_size, **kwargs
        )
        # Treat the negative samples as constants: the loss should only
        # differentiate the energies, not backpropagate through the sampler.
        model_samples = model_samples.detach()

        # Compute energies
        data_energy = model(data_samples).mean()
        model_energy = model(model_samples).mean()

        # Compute loss
        return data_energy - model_energy
class EnergyModel(BaseEnergyFunction):
    """Energy function parameterized by a neural network."""

    def __init__(self, network: nn.Module):
        """Wrap ``network`` as an energy function.

        Args:
            network: Neural network that outputs scalar energy values
        """
        super().__init__()
        self.network = network

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and drop its trailing singleton dimension.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values
        """
        raw_output = self.network(x)
        return raw_output.squeeze(-1)
The following diagram illustrates how the core components interact:
graph TD
A[Energy Function] -->|Defines landscape| B[Sampler]
B -->|Generates samples| C[Training Process]
D[Loss Function] -->|Guides training| C
C -->|Updates| E[Energy Model]
E -->|Parameterizes| A
# Define energy functionenergy_fn=GaussianEnergy(mean=torch.zeros(2),cov=torch.eye(2))# Create samplersampler=LangevinDynamics(energy_function=energy_fn,step_size=0.01)# Generate samplessamples=sampler.sample(dim=2,n_steps=1000,n_samples=100)# Create and train a modelmodel=EnergyModel(network=MLP(input_dim=2,hidden_dims=[32,32],output_dim=1))loss_fn=ContrastiveDivergence(sampler=sampler,k=10)# Training loopoptimizer=torch.optim.Adam(model.parameters(),lr=0.001)forepochinrange(100):optimizer.zero_grad()loss=loss_fn(model,data_samples)loss.backward()optimizer.step()