This guide provides detailed information about the implementation of energy functions in TorchEBM, including mathematical foundations, code structure, and optimization techniques.
```python
# Shared imports for the snippets in this guide
from typing import Callable, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseEnergyFunction(nn.Module):
    """Base class for all energy functions."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute energy for input x."""
        raise NotImplementedError

    def score(self, x: torch.Tensor) -> torch.Tensor:
        """Compute score (negative gradient of energy) for input x."""
        x = x.requires_grad_(True)
        energy = self.forward(x)
        # The score is -dE/dx, matching the analytical score methods further below
        return -torch.autograd.grad(energy.sum(), x, create_graph=True)[0]
```
Key design decisions:
- **PyTorch `nn.Module` base**: allows energy functions to have learnable parameters and to use PyTorch's optimization tools.
- **Automatic differentiation**: uses PyTorch's autograd for computing the score function.
- **Batched computation**: all methods support batched inputs for efficiency.
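
For example, a subclass only needs to implement `forward`; the inherited `score` then falls back to autograd. A minimal sketch (the `QuadraticEnergy` class is purely illustrative and not part of TorchEBM):

```python
class QuadraticEnergy(BaseEnergyFunction):
    """Illustrative energy E(x) = 0.5 * ||x||^2 (hypothetical, not a library class)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * torch.sum(x ** 2, dim=-1)


energy_fn = QuadraticEnergy()
x = torch.randn(128, 2)       # batched input of shape (batch_size, dim)
energy = energy_fn(x)         # shape (128,)
score = energy_fn.score(x)    # shape (128, 2); for this energy, score(x) == -x
```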
```python
class GaussianEnergy(BaseEnergyFunction):
    """Gaussian energy function."""

    def __init__(self, mean: torch.Tensor, cov: torch.Tensor):
        """Initialize Gaussian energy function.

        Args:
            mean: Mean vector of shape (dim,)
            cov: Covariance matrix of shape (dim, dim)
        """
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("cov", cov)
        self.register_buffer("precision", torch.inverse(cov))
        self._dim = mean.size(0)

        # Compute log determinant for normalization (optional)
        self.register_buffer("log_det", torch.logdet(cov))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute Gaussian energy.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values
        """
        # Ensure x has the right shape
        if x.dim() == 1:
            x = x.unsqueeze(0)

        # Center the data
        centered = x - self.mean

        # Compute quadratic form efficiently
        return 0.5 * torch.sum(centered * torch.matmul(centered, self.precision), dim=1)

    def score(self, x: torch.Tensor) -> torch.Tensor:
        """Compute score function analytically.

        This is more efficient than using automatic differentiation.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size, dim) containing score values
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return -torch.matmul(x - self.mean, self.precision)
```
Implementation notes:
- We precompute the precision matrix (inverse covariance) for efficiency.
- A specialized `score` method uses the analytical formula rather than automatic differentiation.
- Input shape handling ensures both single samples and batches work correctly.
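
As a quick sanity check, the analytical score can be compared against a plain autograd gradient of the same energy. A sketch with arbitrary 2-D parameters:

```python
mean = torch.zeros(2)
cov = torch.tensor([[1.0, 0.3], [0.3, 2.0]])
gaussian = GaussianEnergy(mean=mean, cov=cov)

x = torch.randn(64, 2, requires_grad=True)
autograd_score = -torch.autograd.grad(gaussian(x).sum(), x)[0]  # -dE/dx via autograd
assert torch.allclose(gaussian.score(x), autograd_score, atol=1e-5)
```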
```python
class DoubleWellEnergy(BaseEnergyFunction):
    """Double well energy function."""

    def __init__(self, a: float = 1.0, b: float = 2.0):
        """Initialize double well energy function.

        Args:
            a: Scale parameter controlling depth of wells
            b: Parameter controlling the distance between wells
        """
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute double well energy.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values
        """
        # Compute (x^2 - b)^2 for each dimension, then sum
        return self.a * torch.sum((x ** 2 - self.b) ** 2, dim=1)
```
```python
class RosenbrockEnergy(BaseEnergyFunction):
    """Rosenbrock energy function."""

    def __init__(self, a: float = 1.0, b: float = 100.0):
        """Initialize Rosenbrock energy function.

        Args:
            a: Parameter of the (a - x_i)^2 term (usually 1)
            b: Scale of the (x_{i+1} - x_i^2)^2 term (usually 100)
        """
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute Rosenbrock energy.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)

        batch_size, dim = x.shape
        energy = torch.zeros(batch_size, device=x.device)
        for i in range(dim - 1):
            term1 = self.b * (x[:, i + 1] - x[:, i] ** 2) ** 2
            term2 = (self.a - x[:, i]) ** 2
            energy += term1 + term2
        return energy
```
```python
class CompositeEnergy(BaseEnergyFunction):
    """Composite energy function."""

    def __init__(
        self,
        energy_functions: List[BaseEnergyFunction],
        weights: Optional[List[float]] = None,
        operation: str = "sum",
    ):
        """Initialize composite energy function.

        Args:
            energy_functions: List of energy functions to combine
            weights: Optional weights for each energy function
            operation: How to combine energy functions ("sum", "product", "min", "max")
        """
        super().__init__()
        self.energy_functions = nn.ModuleList(energy_functions)

        if weights is None:
            weights = [1.0] * len(energy_functions)
        self.register_buffer("weights", torch.tensor(weights))

        if operation not in ["sum", "product", "min", "max"]:
            raise ValueError(f"Unknown operation: {operation}")
        self.operation = operation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute composite energy.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values
        """
        energies = [f(x) * w for f, w in zip(self.energy_functions, self.weights)]

        if self.operation == "sum":
            return torch.sum(torch.stack(energies), dim=0)
        elif self.operation == "product":
            return torch.prod(torch.stack(energies), dim=0)
        elif self.operation == "min":
            return torch.min(torch.stack(energies), dim=0)[0]
        elif self.operation == "max":
            return torch.max(torch.stack(energies), dim=0)[0]
```
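
A brief usage sketch combining two of the energies defined above (the weights and dimensionality here are arbitrary):

```python
composite = CompositeEnergy(
    energy_functions=[
        DoubleWellEnergy(a=1.0, b=2.0),
        GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2)),
    ],
    weights=[1.0, 0.5],
    operation="sum",
)

x = torch.randn(32, 2)
energy = composite(x)  # weighted sum of the two energies, shape (32,)
```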
```python
class MLPEnergy(BaseEnergyFunction):
    """Multi-layer perceptron energy function."""

    def __init__(self, input_dim: int, hidden_dims: List[int], activation: Callable = nn.SiLU):
        """Initialize MLP energy function.

        Args:
            input_dim: Input dimensionality
            hidden_dims: List of hidden layer dimensions
            activation: Activation function
        """
        super().__init__()

        # Build MLP layers
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(activation())
            prev_dim = hidden_dim

        # Final layer with scalar output
        layers.append(nn.Linear(prev_dim, 1))
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute energy using the MLP.

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values
        """
        return self.network(x).squeeze(-1)
```
```python
def efficient_grad(
    energy_fn: BaseEnergyFunction,
    x: torch.Tensor,
    create_graph: bool = False,
) -> torch.Tensor:
    """Compute gradient of energy function efficiently.

    Args:
        energy_fn: Energy function
        x: Input tensor of shape (batch_size, dim)
        create_graph: Whether to create gradient graph (for higher-order gradients)

    Returns:
        Gradient tensor of shape (batch_size, dim)
    """
    x.requires_grad_(True)
    with torch.enable_grad():
        energy = energy_fn(x)
        grad = torch.autograd.grad(energy.sum(), x, create_graph=create_graph)[0]
    return grad
```
```python
def cuda_score_function(energy_fn, x):
    """CUDA-optimized score function computation."""
    # Use energy_fn's custom CUDA implementation if available
    if hasattr(energy_fn, "cuda_score") and torch.cuda.is_available():
        return energy_fn.cuda_score(x)
    else:
        # Fall back to autograd
        return energy_fn.score(x)
```
```python
class GaussianEnergy(BaseEnergyFunction):
    # ... continuing the GaussianEnergy definition from above ...

    @classmethod
    def create_standard_gaussian(cls, dim: int) -> "GaussianEnergy":
        """Create a standard Gaussian energy function.

        Args:
            dim: Dimensionality

        Returns:
            GaussianEnergy with zero mean and identity covariance
        """
        return cls(mean=torch.zeros(dim), cov=torch.eye(dim))

    @classmethod
    def from_samples(cls, samples: torch.Tensor, regularization: float = 1e-4) -> "GaussianEnergy":
        """Create a Gaussian energy function from data samples.

        Args:
            samples: Data samples of shape (n_samples, dim)
            regularization: Small value added to diagonal for numerical stability

        Returns:
            GaussianEnergy fit to the samples
        """
        mean = samples.mean(dim=0)
        cov = torch.cov(samples.T) + regularization * torch.eye(samples.size(1))
        return cls(mean=mean, cov=cov)
```
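
Assuming these factory methods have been added to `GaussianEnergy`, usage might look like:

```python
# Standard Gaussian in 3 dimensions
standard = GaussianEnergy.create_standard_gaussian(3)

# Gaussian fit to (synthetic, illustrative) data samples
samples = torch.randn(500, 3) * 2.0 + 1.0
fitted = GaussianEnergy.from_samples(samples, regularization=1e-4)
```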
```python
class NumericallyStableEnergy(BaseEnergyFunction):
    """Energy function with numerical stability considerations."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute energy with numerical stability.

        Uses log-sum-exp trick for numerical stability.
        """
        # Example of numerical stability in computation; compute_terms is
        # expected to be provided by a subclass and to return a tensor of
        # shape (batch_size, n_terms).
        terms = self.compute_terms(x)
        max_term = torch.max(terms, dim=1, keepdim=True)[0]
        stable_energy = max_term.squeeze(1) + torch.log(
            torch.sum(torch.exp(terms - max_term), dim=1)
        )
        return stable_energy
```
```python
class MixtureEnergy(BaseEnergyFunction):
    """Mixture of energy functions."""

    def __init__(self, components: List[BaseEnergyFunction], weights: Optional[List[float]] = None):
        """Initialize mixture energy function.

        Args:
            components: List of component energy functions
            weights: Optional weights for each component
        """
        super().__init__()
        self.components = nn.ModuleList(components)

        if weights is None:
            weights = [1.0] * len(components)
        self.register_buffer("log_weights", torch.log(torch.tensor(weights)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute mixture energy using log-sum-exp for stability."""
        energies = torch.stack([f(x) for f in self.components], dim=1)

        # log(w_k) - E_k(x); exponentiating gives w_k * exp(-E_k(x))
        weighted_energies = self.log_weights - energies

        # Use log-sum-exp trick for numerical stability
        max_val = torch.max(weighted_energies, dim=1, keepdim=True)[0]
        stable_energy = -(
            max_val.squeeze(1)
            + torch.log(torch.sum(torch.exp(weighted_energies - max_val), dim=1))
        )
        return stable_energy
```
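
The `forward` pass above is the log-sum-exp form of the standard mixture identity: if each component defines an (unnormalized) Boltzmann density proportional to exp(-E_k(x)), then, up to per-component normalization constants,

$$
E_{\mathrm{mix}}(x) = -\log \sum_k w_k \exp\!\bigl(-E_k(x)\bigr)
                    = -\operatorname{logsumexp}_k\bigl(\log w_k - E_k(x)\bigr),
$$

which is exactly what the code computes after subtracting the per-sample maximum for numerical stability.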
```python
def test_energy_function(energy_fn: BaseEnergyFunction, dim: int, n_samples: int = 1000) -> dict:
    """Test an energy function for correctness and properties.

    Args:
        energy_fn: Energy function to test
        dim: Input dimensionality
        n_samples: Number of test samples

    Returns:
        Dictionary with test results
    """
    # Generate random samples (requires_grad so we can take gradients below)
    x = torch.randn(n_samples, dim, requires_grad=True)

    # Test energy computation
    energy = energy_fn(x)
    assert energy.shape == (n_samples,)

    # Test score computation
    score = energy_fn.score(x)
    assert score.shape == (n_samples, dim)

    # Test gradient consistency: score should equal the negative energy gradient
    manual_grad = torch.autograd.grad(energy_fn(x).sum(), x, create_graph=True)[0]
    assert torch.allclose(score, -manual_grad, atol=1e-5, rtol=1e-5)

    return {
        "energy_mean": energy.mean().item(),
        "energy_std": energy.std().item(),
        "score_mean": score.mean().item(),
        "score_std": score.std().item(),
    }
```
```python
class CustomEnergy(BaseEnergyFunction):
    """Custom energy function example."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Ensure correct input shape
        if x.dim() == 1:
            x = x.unsqueeze(0)

        # Compute energy using vectorized operations
        return self.scale * torch.sum(torch.sin(x) ** 2, dim=1)

    def score(self, x: torch.Tensor) -> torch.Tensor:
        # Analytical score: negative gradient of the energy
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return -2 * self.scale * torch.sin(x) * torch.cos(x)
```
```python
def debug_energy_function(energy_fn: BaseEnergyFunction, x: torch.Tensor) -> None:
    """Debug an energy function for common issues."""
    # Check for NaN/Inf in energy
    energy = energy_fn(x)
    if torch.isnan(energy).any() or torch.isinf(energy).any():
        print("Warning: Energy contains NaN or Inf values")

    # Check for NaN/Inf in score
    score = energy_fn.score(x)
    if torch.isnan(score).any() or torch.isinf(score).any():
        print("Warning: Score contains NaN or Inf values")

    # Check score magnitude
    score_norm = torch.norm(score, dim=1)
    if (score_norm > 1e3).any():
        print("Warning: Score has very large values")

    # Check energy range
    if energy.max() - energy.min() > 1e6:
        print("Warning: Energy has a very large range")
```
```python
class SphericalEnergy(BaseEnergyFunction):
    """Energy function defined on a unit sphere."""

    def __init__(self, base_energy: BaseEnergyFunction):
        """Initialize spherical energy function.

        Args:
            base_energy: Base energy function
        """
        super().__init__()
        self.base_energy = base_energy

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute energy on unit sphere.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values
        """
        # Project to unit sphere
        x_normalized = F.normalize(x, p=2, dim=1)
        return self.base_energy(x_normalized)
```
```python
class DensityModelEnergy(BaseEnergyFunction):
    """Energy function from a density model."""

    def __init__(self, density_model: Callable):
        """Initialize energy function from density model.

        Args:
            density_model: Model that computes log probability via a log_prob method
        """
        super().__init__()
        self.density_model = density_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute energy as negative log probability.

        Args:
            x: Input tensor of shape (batch_size, dim)

        Returns:
            Tensor of shape (batch_size,) containing energy values
        """
        log_prob = self.density_model.log_prob(x)
        return -log_prob
```