This guide covers the implementation of loss functions in TorchEBM, including their mathematical foundations, code structure, and optimization techniques.
Energy-based models can be trained using various loss functions, each with different properties. The primary goal is to shape the energy landscape such that observed data has low energy while other regions have high energy.
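For maximum likelihood training, this intuition can be made precise by the gradient of the negative log-likelihood, which splits into a positive phase evaluated on data and a negative phase evaluated on model samples (this is the standard EBM decomposition; $Z_\theta$ denotes the intractable partition function):

$$
\nabla_\theta\, \mathbb{E}_{x \sim p_{\text{data}}}\!\left[-\log p_\theta(x)\right]
= \mathbb{E}_{x \sim p_{\text{data}}}\!\left[\nabla_\theta E_\theta(x)\right]
- \mathbb{E}_{x \sim p_\theta}\!\left[\nabla_\theta E_\theta(x)\right],
\qquad
p_\theta(x) = \frac{e^{-E_\theta(x)}}{Z_\theta}.
$$

The loss implementations below estimate the second expectation using negative samples drawn from a sampler or a noise distribution.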
```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple

import torch

from torchebm.core import BaseEnergyFunction


class BaseLoss(ABC):
    """Base class for all loss functions."""

    def __init__(self, energy_function: BaseEnergyFunction):
        """Initialize loss with an energy function.

        Args:
            energy_function: The energy function to train
        """
        self.energy_function = energy_function

    @abstractmethod
    def __call__(
        self,
        pos_samples: torch.Tensor,
        neg_samples: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the loss.

        Args:
            pos_samples: Positive samples from the data distribution
            neg_samples: Negative samples from the model distribution
            **kwargs: Additional loss-specific parameters

        Returns:
            Tuple of (loss value, dictionary of metrics)
        """
        pass
```
```python
from typing import Dict, Optional, Tuple

import torch

from torchebm.core import BaseEnergyFunction
from torchebm.losses.base import BaseLoss


class MLELoss(BaseLoss):
    """Maximum Likelihood Estimation loss."""

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        alpha: float = 1.0,
        regularization: Optional[str] = None,
        reg_strength: float = 0.0,
    ):
        """Initialize MLE loss.

        Args:
            energy_function: Energy function to train
            alpha: Weight for the negative phase
            regularization: Type of regularization ('l1', 'l2', or None)
            reg_strength: Strength of regularization
        """
        super().__init__(energy_function)
        self.alpha = alpha
        self.regularization = regularization
        self.reg_strength = reg_strength

    def __call__(
        self,
        pos_samples: torch.Tensor,
        neg_samples: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the MLE loss.

        Args:
            pos_samples: Positive samples from the data distribution
            neg_samples: Negative samples from the model distribution

        Returns:
            Tuple of (loss value, dictionary of metrics)
        """
        # Compute energies
        pos_energy = self.energy_function(pos_samples)
        neg_energy = self.energy_function(neg_samples)

        # Compute loss components
        pos_term = pos_energy.mean()
        neg_term = neg_energy.mean()

        # Full loss
        loss = pos_term - self.alpha * neg_term

        # Add regularization if specified
        reg_loss = torch.tensor(0.0, device=pos_energy.device)
        if self.regularization is not None and self.reg_strength > 0:
            if self.regularization == 'l2':
                for param in self.energy_function.parameters():
                    reg_loss += torch.sum(param ** 2)
            elif self.regularization == 'l1':
                for param in self.energy_function.parameters():
                    reg_loss += torch.sum(torch.abs(param))
            loss = loss + self.reg_strength * reg_loss

        # Metrics to track
        metrics = {
            'pos_energy': pos_term.detach(),
            'neg_energy': neg_term.detach(),
            'energy_gap': (neg_term - pos_term).detach(),
            'loss': loss.detach(),
            'reg_loss': reg_loss.detach(),
        }

        return loss, metrics
```
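As a rough usage sketch (not part of the class above), the loss plugs into a standard PyTorch training loop; `energy_fn`, `dataloader`, and the negative-sample source are illustrative placeholders:

```python
import torch

# Illustrative setup: energy_fn is assumed to be any BaseEnergyFunction whose
# parameters are exposed via .parameters(); dataloader yields data batches.
loss_fn = MLELoss(energy_fn, alpha=1.0, regularization='l2', reg_strength=1e-4)
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=1e-3)

for pos_batch in dataloader:
    # Negative samples must come from elsewhere (a sampler, a replay buffer,
    # or a noise distribution); Gaussian noise here is purely a placeholder.
    neg_batch = torch.randn_like(pos_batch)

    loss, metrics = loss_fn(pos_batch, neg_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```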
Contrastive Divergence (CD) approximates the MLE gradient with a specific sampling scheme: negative samples are obtained by initializing an MCMC chain at the positive samples and running it for only a few steps (CD-k):
```python
from typing import Dict, Optional, Tuple

import torch

from torchebm.core import BaseEnergyFunction
from torchebm.samplers import Sampler, LangevinDynamics
from torchebm.losses.base import BaseLoss


class ContrastiveDivergenceLoss(BaseLoss):
    """Contrastive Divergence loss."""

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        sampler: Optional[Sampler] = None,
        n_steps: int = 10,
        alpha: float = 1.0,
    ):
        """Initialize CD loss.

        Args:
            energy_function: Energy function to train
            sampler: Sampler for generating negative samples
            n_steps: Number of sampling steps for negative samples
            alpha: Weight for the negative phase
        """
        super().__init__(energy_function)
        self.sampler = sampler or LangevinDynamics(energy_function)
        self.n_steps = n_steps
        self.alpha = alpha

    def __call__(
        self,
        pos_samples: torch.Tensor,
        neg_samples: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the CD loss.

        Args:
            pos_samples: Positive samples from the data distribution
            neg_samples: Optional negative samples (if None, will be generated)

        Returns:
            Tuple of (loss value, dictionary of metrics)
        """
        # Generate negative samples if not provided
        if neg_samples is None:
            with torch.no_grad():
                neg_samples = self.sampler.sample(
                    pos_samples.shape[1],
                    self.n_steps,
                    n_samples=pos_samples.shape[0],
                    initial_samples=pos_samples.detach(),
                )

        # Compute energies
        pos_energy = self.energy_function(pos_samples)
        neg_energy = self.energy_function(neg_samples)

        # Compute loss components
        pos_term = pos_energy.mean()
        neg_term = neg_energy.mean()

        # Full loss
        loss = pos_term - self.alpha * neg_term

        # Metrics to track
        metrics = {
            'pos_energy': pos_term.detach(),
            'neg_energy': neg_term.detach(),
            'energy_gap': (neg_term - pos_term).detach(),
            'loss': loss.detach(),
        }

        return loss, metrics
```
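A minimal usage sketch, assuming `energy_fn` and `pos_batch` are placeholders defined as in the earlier example; when `neg_samples` is omitted, negatives are generated internally by the default `LangevinDynamics` sampler starting from the data batch:

```python
# CD-k with k = 10 Langevin steps; the sampler argument is omitted, so the
# default LangevinDynamics(energy_fn) chain is used.
cd_loss = ContrastiveDivergenceLoss(energy_fn, n_steps=10)
loss, metrics = cd_loss(pos_batch)

loss.backward()  # gradients flow only through the energy terms, not the sampler
```

A common design variation is persistent CD, where the chain is initialized from the previous batch's negative samples rather than from the data; the optional `neg_samples` argument makes that straightforward to build on top of this class.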
```python
from typing import Dict, Tuple

import torch
import torch.nn.functional as F

from torchebm.core import BaseEnergyFunction
from torchebm.losses.base import BaseLoss


class NCELoss(BaseLoss):
    """Noise Contrastive Estimation loss."""

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        log_partition: float = 0.0,
        learn_partition: bool = True,
    ):
        """Initialize NCE loss.

        Args:
            energy_function: Energy function to train
            log_partition: Initial value of log partition function
            learn_partition: Whether to learn the partition function
        """
        super().__init__(energy_function)
        if learn_partition:
            # Learnable estimate of the log partition function. Note: add this
            # parameter to the optimizer explicitly, since BaseLoss is not an
            # nn.Module and does not register parameters automatically.
            self.log_z = torch.nn.Parameter(
                torch.tensor([log_partition], dtype=torch.float32)
            )
        else:
            # Fixed log partition estimate (plain tensor, no gradient).
            self.log_z = torch.tensor([log_partition], dtype=torch.float32)
        self.learn_partition = learn_partition

    def __call__(
        self,
        pos_samples: torch.Tensor,
        neg_samples: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the NCE loss.

        Args:
            pos_samples: Positive samples from the data distribution
            neg_samples: Negative samples from the noise distribution

        Returns:
            Tuple of (loss value, dictionary of metrics)
        """
        # Compute energies
        pos_energy = self.energy_function(pos_samples)
        neg_energy = self.energy_function(neg_samples)

        # Compute logits
        pos_logits = -pos_energy - self.log_z
        neg_logits = -neg_energy - self.log_z

        # Binary classification loss
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_logits, torch.ones_like(pos_logits)
        )
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_logits, torch.zeros_like(neg_logits)
        )

        # Full loss
        loss = pos_loss + neg_loss

        # Metrics to track
        metrics = {
            'pos_loss': pos_loss.detach(),
            'neg_loss': neg_loss.detach(),
            'loss': loss.detach(),
            'log_z': self.log_z.detach(),
            'pos_energy': pos_energy.mean().detach(),
            'neg_energy': neg_energy.mean().detach(),
        }

        return loss, metrics
```
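To see what the classifier above is optimizing, note that the logits are the unnormalized model log-density; with the learned (or fixed) log-partition estimate $\log Z$, the code treats

$$
\log \tilde{p}_\theta(x) = -E_\theta(x) - \log Z, \qquad
P(\text{data} \mid x) = \sigma\!\left(-E_\theta(x) - \log Z\right),
$$

and trains it with binary cross-entropy against labels 1 for data and 0 for noise. This is a simplified variant: the full NCE posterior is $\sigma\!\left(\log \tilde{p}_\theta(x) - \log p_n(x)\right)$, where $p_n$ is the noise density, so the implementation above implicitly assumes the noise log-density is constant (or folded into $\log Z$).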
```python
from typing import Dict, Optional, Tuple

import torch

from torchebm.core import BaseEnergyFunction
from torchebm.losses.base import BaseLoss


class ScoreMatchingLoss(BaseLoss):
    """Score Matching loss."""

    def __init__(self, energy_function: BaseEnergyFunction, implicit: bool = True):
        """Initialize Score Matching loss.

        Args:
            energy_function: Energy function to train
            implicit: Whether to use the implicit (denoising-style) variant
        """
        super().__init__(energy_function)
        self.implicit = implicit

    def _compute_explicit_score_matching(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the explicit score matching loss.

        This requires computing both the score and the Hessian trace.

        Args:
            x: Input samples of shape (n_samples, dim)

        Returns:
            Loss value
        """
        x.requires_grad_(True)

        # Compute energy
        energy = self.energy_function(x)

        # Model score: s(x) = grad log p(x) = -grad E(x)
        score = -torch.autograd.grad(energy.sum(), x, create_graph=True)[0]

        # Trace of the score Jacobian (sum of diagonal second derivatives)
        trace = 0.0
        for i in range(x.shape[1]):
            grad_score_i = torch.autograd.grad(
                score[:, i].sum(), x, create_graph=True
            )[0]
            trace += grad_score_i[:, i]

        # Squared norm of the score
        score_norm = torch.sum(score ** 2, dim=1)

        # Hyvarinen objective: tr(grad s) + 0.5 * ||s||^2
        loss = trace + 0.5 * score_norm

        return loss.mean()

    def _compute_implicit_score_matching(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the implicit (denoising-style) score matching loss.

        This avoids computing the Hessian trace.

        Args:
            x: Input samples of shape (n_samples, dim)

        Returns:
            Loss value
        """
        # Add Gaussian noise to the inputs
        sigma = 0.01
        x_noise = x + torch.randn_like(x) * sigma
        x_noise.requires_grad_(True)

        # Model score at the noisy points: s(x~) = -grad E(x~)
        energy = self.energy_function(x_noise)
        score = -torch.autograd.grad(energy.sum(), x_noise, create_graph=True)[0]

        # Target score of the Gaussian perturbation kernel: (x - x~) / sigma^2
        target = (x - x_noise) / (sigma ** 2)

        # Squared difference between model score and target score
        loss = 0.5 * torch.sum((score - target) ** 2, dim=1)

        return loss.mean()

    def __call__(
        self,
        pos_samples: torch.Tensor,
        neg_samples: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the Score Matching loss.

        Args:
            pos_samples: Positive samples from the data distribution
            neg_samples: Not used in Score Matching

        Returns:
            Tuple of (loss value, dictionary of metrics)
        """
        # Compute loss based on method
        if self.implicit:
            loss = self._compute_implicit_score_matching(pos_samples)
        else:
            loss = self._compute_explicit_score_matching(pos_samples)

        # Metrics to track
        metrics = {'loss': loss.detach()}

        return loss, metrics
```
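For reference, the explicit objective implemented above is Hyvärinen's score matching loss written in terms of the energy; the model score is $s_\theta(x) = \nabla_x \log p_\theta(x) = -\nabla_x E_\theta(x)$, so minimizing

$$
J_{\mathrm{SM}}(\theta)
= \mathbb{E}_{p_{\text{data}}}\!\left[
\operatorname{tr}\!\left(\nabla_x s_\theta(x)\right)
+ \tfrac{1}{2}\,\big\lVert s_\theta(x) \big\rVert^2
\right]
= \mathbb{E}_{p_{\text{data}}}\!\left[
-\operatorname{tr}\!\left(\nabla_x^2 E_\theta(x)\right)
+ \tfrac{1}{2}\,\big\lVert \nabla_x E_\theta(x) \big\rVert^2
\right]
$$

requires per-dimension second derivatives, which is why the Hessian trace loop above scales linearly with the input dimension.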
```python
from typing import Dict, List, Optional, Tuple, Union

import torch

from torchebm.core import BaseEnergyFunction
from torchebm.losses.base import BaseLoss


class DenoisingScoreMatchingLoss(BaseLoss):
    """Denoising Score Matching loss."""

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        sigma: Union[float, List[float]] = 0.01,
    ):
        """Initialize DSM loss.

        Args:
            energy_function: Energy function to train
            sigma: Noise level(s) for denoising
        """
        super().__init__(energy_function)
        if isinstance(sigma, (int, float)):
            self.sigma = [float(sigma)]
        else:
            self.sigma = sigma

    def __call__(
        self,
        pos_samples: torch.Tensor,
        neg_samples: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the DSM loss.

        Args:
            pos_samples: Positive samples from the data distribution
            neg_samples: Not used in DSM

        Returns:
            Tuple of (loss value, dictionary of metrics)
        """
        total_loss = 0.0
        metrics = {}

        for sigma in self.sigma:
            # Add noise to inputs
            noise = torch.randn_like(pos_samples) * sigma
            x_noisy = pos_samples + noise

            # Gradient of the energy at the noisy points. The model score is
            # -grad E; since the target below is also negated, the signs cancel
            # and the objective matches 0.5 * ||s_model - s_target||^2.
            x_noisy.requires_grad_(True)
            energy = self.energy_function(x_noisy)
            score_model = torch.autograd.grad(
                energy.sum(), x_noisy, create_graph=True
            )[0]

            # Target score (gradient of the log density of the noise model).
            # For Gaussian noise this is -(x_noisy - pos_samples) / sigma^2
            score_target = -noise / (sigma ** 2)

            # Compute loss for this noise level
            loss_sigma = 0.5 * torch.sum(
                (score_model + score_target) ** 2, dim=1
            ).mean()

            total_loss += loss_sigma
            metrics[f'loss_sigma_{sigma}'] = loss_sigma.detach()

        # Average loss over all noise levels
        avg_loss = total_loss / len(self.sigma)
        metrics['loss'] = avg_loss.detach()

        return avg_loss, metrics
```
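The per-noise-level objective above corresponds to matching the model score against the score of the Gaussian perturbation kernel $q_\sigma(\tilde{x} \mid x) = \mathcal{N}(\tilde{x}; x, \sigma^2 I)$:

$$
J_{\mathrm{DSM}}(\theta) =
\mathbb{E}_{x \sim p_{\text{data}}}\,
\mathbb{E}_{\tilde{x} \sim q_\sigma(\cdot \mid x)}
\left[ \tfrac{1}{2}\,\big\lVert
s_\theta(\tilde{x}) - \nabla_{\tilde{x}} \log q_\sigma(\tilde{x} \mid x)
\big\rVert^2 \right],
\qquad
\nabla_{\tilde{x}} \log q_\sigma(\tilde{x} \mid x) = \frac{x - \tilde{x}}{\sigma^2},
$$

which avoids second derivatives entirely. Passing a list of noise levels averages this objective over several values of $\sigma$, which helps when a single noise scale either oversmooths or undersmooths the data distribution.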
```python
from typing import Dict, Optional, Tuple

import torch

from torchebm.core import BaseEnergyFunction
from torchebm.losses.base import BaseLoss


class SlicedScoreMatchingLoss(BaseLoss):
    """Sliced Score Matching loss."""

    def __init__(self, energy_function: BaseEnergyFunction, n_projections: int = 1):
        """Initialize SSM loss.

        Args:
            energy_function: Energy function to train
            n_projections: Number of random projections
        """
        super().__init__(energy_function)
        self.n_projections = n_projections

    def __call__(
        self,
        pos_samples: torch.Tensor,
        neg_samples: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the SSM loss.

        Args:
            pos_samples: Positive samples from the data distribution
            neg_samples: Not used in SSM

        Returns:
            Tuple of (loss value, dictionary of metrics)
        """
        x = pos_samples.detach().requires_grad_(True)

        # Compute energy
        energy = self.energy_function(x)

        # Model score: s(x) = grad log p(x) = -grad E(x)
        score = -torch.autograd.grad(energy.sum(), x, create_graph=True)[0]

        total_loss = 0.0
        for _ in range(self.n_projections):
            # Generate random projection directions on the unit sphere
            v = torch.randn_like(x)
            v = v / torch.norm(v, p=2, dim=1, keepdim=True)

            # Directional derivative v^T s
            Jv = torch.sum(score * v, dim=1)

            # Gradient of the directional derivative: grad(v^T s)
            J2v = torch.autograd.grad(Jv.sum(), x, create_graph=True)[0]

            # Sliced score matching terms: v^T (grad s) v + 0.5 * ||s||^2
            loss_1 = torch.sum(J2v * v, dim=1)
            loss_2 = 0.5 * torch.sum(score ** 2, dim=1)

            # Full loss for this projection
            loss = loss_1 + loss_2
            total_loss += loss.mean()

        # Average loss over projections
        avg_loss = total_loss / self.n_projections

        # Metrics to track
        metrics = {'loss': avg_loss.detach()}

        return avg_loss, metrics
```
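The projected objective implemented above is the variance-reduced form of sliced score matching, which replaces the full Hessian trace with random directional second derivatives:

$$
J_{\mathrm{SSM}}(\theta) =
\mathbb{E}_{v}\,\mathbb{E}_{x \sim p_{\text{data}}}
\left[ v^{\top} \nabla_x s_\theta(x)\, v
+ \tfrac{1}{2}\,\big\lVert s_\theta(x) \big\rVert^2 \right],
$$

where $v$ is a random unit vector. Each projection requires only one extra backward pass, so the cost is independent of the input dimension; increasing `n_projections` trades compute for a lower-variance estimate.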
```python
def batched_hessian_trace(energy_function, x, batch_size=16):
    """Compute the trace of the Hessian in batches to save memory."""
    x.requires_grad_(True)
    trace = torch.zeros(x.size(0), device=x.device)

    # Compute energy and score
    energy = energy_function(x)
    score = torch.autograd.grad(energy.sum(), x, create_graph=True)[0]

    # Compute trace of Hessian in batches of dimensions
    for i in range(0, x.size(1), batch_size):
        end_i = min(i + batch_size, x.size(1))
        sub_dims = list(range(i, end_i))
        for j in sub_dims:
            # Compute diagonal elements of the Hessian
            grad_score_j = torch.autograd.grad(
                score[:, j].sum(), x, create_graph=True
            )[0]
            trace += grad_score_j[:, j]

    return trace
```
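A small, hypothetical usage sketch; the shapes and `energy_fn` are illustrative:

```python
# Per-sample Hessian trace for a batch of 128 samples in 64 dimensions.
x = torch.randn(128, 64)
trace = batched_hessian_trace(energy_fn, x, batch_size=16)
print(trace.shape)  # torch.Size([128])
```

Note that this still requires one backward pass per input dimension, so the cost grows linearly with dimensionality; for high-dimensional data the sliced or denoising variants above are usually cheaper.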
```python
def create_loss(
    loss_type: str,
    energy_function: BaseEnergyFunction,
    **kwargs,
) -> BaseLoss:
    """Create a loss function instance.

    Args:
        loss_type: Type of loss function ('mle', 'cd', 'nce', 'sm', 'dsm', 'ssm')
        energy_function: Energy function to train
        **kwargs: Loss-specific parameters

    Returns:
        BaseLoss instance
    """
    loss_type_key = loss_type.lower()
    if loss_type_key == 'mle':
        return MLELoss(energy_function, **kwargs)
    elif loss_type_key == 'cd':
        return ContrastiveDivergenceLoss(energy_function, **kwargs)
    elif loss_type_key == 'nce':
        return NCELoss(energy_function, **kwargs)
    elif loss_type_key == 'sm':
        return ScoreMatchingLoss(energy_function, **kwargs)
    elif loss_type_key == 'dsm':
        return DenoisingScoreMatchingLoss(energy_function, **kwargs)
    elif loss_type_key == 'ssm':
        return SlicedScoreMatchingLoss(energy_function, **kwargs)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")
```
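A brief usage sketch of the factory; the keyword arguments shown are those defined by the corresponding classes above, and `energy_fn` is a placeholder:

```python
# Keyword arguments are forwarded unchanged to the selected loss class.
cd_loss = create_loss('cd', energy_fn, n_steps=5, alpha=1.0)
dsm_loss = create_loss('dsm', energy_fn, sigma=[0.01, 0.1, 1.0])
```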
```python
def validate_loss_gradients(
    loss_fn: BaseLoss,
    dim: int = 2,
    n_samples: int = 10,
    seed: int = 42,
) -> bool:
    """Validate that a loss function produces valid gradients.

    Args:
        loss_fn: Loss function to test
        dim: Dimensionality of test samples
        n_samples: Number of test samples
        seed: Random seed

    Returns:
        True if validation passes, False otherwise
    """
    torch.manual_seed(seed)

    # Generate test samples
    pos_samples = torch.randn(n_samples, dim)
    neg_samples = torch.randn(n_samples, dim)

    # Ensure parameters require grad
    for param in loss_fn.energy_function.parameters():
        param.requires_grad_(True)

    # Compute loss
    loss, _ = loss_fn(pos_samples, neg_samples)

    # Check that the loss is a scalar
    if not isinstance(loss, torch.Tensor) or loss.numel() != 1:
        print(f"Loss is not a scalar: {loss}")
        return False

    # Check that the loss produces gradients
    try:
        loss.backward()
        has_grad = all(
            p.grad is not None for p in loss_fn.energy_function.parameters()
        )
        if not has_grad:
            print("Some parameters did not receive gradients")
            return False
    except Exception as e:
        print(f"Error during backward pass: {e}")
        return False

    return True
```
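For example, a quick sanity check before training might look like this (using the MLE loss and a placeholder `energy_fn`):

```python
# Fails fast if the loss is not a scalar or if some parameters get no gradient.
assert validate_loss_gradients(MLELoss(energy_fn), dim=2, n_samples=10)
```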
```python
class CustomLoss(BaseLoss):
    """Custom loss example."""

    def __init__(self, energy_function, alpha=1.0, beta=0.5):
        super().__init__(energy_function)
        self.alpha = alpha
        self.beta = beta

    def __call__(self, pos_samples, neg_samples, **kwargs):
        # Compute energies
        pos_energy = self.energy_function(pos_samples)
        neg_energy = self.energy_function(neg_samples)

        # Custom loss logic
        loss = (pos_energy.mean() - self.alpha * neg_energy.mean()) + \
            self.beta * torch.abs(pos_energy.mean() - neg_energy.mean())

        # Return loss and metrics
        metrics = {
            'pos_energy': pos_energy.mean().detach(),
            'neg_energy': neg_energy.mean().detach(),
            'loss': loss.detach(),
        }
        return loss, metrics
```
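Because it follows the `BaseLoss` interface, the custom loss drops into the same training loop as the built-in losses; a short hypothetical sketch with placeholder batches:

```python
custom_loss = CustomLoss(energy_fn, alpha=1.0, beta=0.5)
loss, metrics = custom_loss(pos_batch, neg_batch)
loss.backward()
```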