This guide provides detailed information about the implementation of sampling algorithms in TorchEBM, including mathematical foundations, code structure, and optimization techniques.
from abc import ABC, abstractmethod
import torch
from typing import Optional, Union, Tuple
from torchebm.core import BaseEnergyFunction


class Sampler(ABC):
    """Abstract interface shared by every sampling algorithm.

    A concrete sampler keeps a reference to the energy function it samples
    from plus a target device, and must implement both :meth:`sample` and
    :meth:`sample_chain`.
    """

    def __init__(self, energy_function: BaseEnergyFunction):
        """Store the energy function and choose a default device.

        Args:
            energy_function: Energy function defining the distribution to
                draw samples from.
        """
        self.energy_function = energy_function
        # Prefer the GPU when one is visible, otherwise fall back to the CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    def to(self, device):
        """Record ``device`` as this sampler's device and return ``self``.

        Only the sampler's ``device`` attribute is updated; the energy
        function itself is not moved. Returning ``self`` allows chaining.
        """
        self.device = device
        return self

    @abstractmethod
    def sample(self, n_samples: int, **kwargs) -> torch.Tensor:
        """Draw samples from the energy-based distribution.

        Args:
            n_samples: How many samples to produce.
            **kwargs: Options specific to the concrete sampler.

        Returns:
            Tensor of shape ``(n_samples, dim)`` holding the samples.
        """
        ...

    @abstractmethod
    def sample_chain(
        self, dim: int, n_steps: int, n_samples: int = 1, **kwargs
    ) -> torch.Tensor:
        """Run Markov chains and return their final states.

        Args:
            dim: Dimensionality of each sample.
            n_steps: Number of transitions to apply to every chain.
            n_samples: Number of chains run in parallel.
            **kwargs: Options specific to the concrete sampler.

        Returns:
            Tensor of shape ``(n_samples, dim)`` holding the final states.
        """
        ...
import torch
import numpy as np
from typing import Optional, Union, Tuple
from torchebm.core import BaseEnergyFunction
from torchebm.samplers.base import Sampler


class LangevinDynamics(Sampler):
    """Unadjusted Langevin dynamics sampler (no Metropolis correction).

    Each step applies::

        x <- x + step_size * drift(x) + sqrt(2 * step_size * noise_scale) * N(0, I)

    where ``drift(x) = -energy_function.score(x)``.
    """

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        step_size: float = 0.01,
        noise_scale: float = 1.0,
    ):
        """Initialize the Langevin dynamics sampler.

        Args:
            energy_function: Energy function to sample from.
            step_size: Step size of each update.
            noise_scale: Multiplier on the variance of the injected Gaussian
                noise (it appears inside the square root of the noise std).
        """
        super().__init__(energy_function)
        self.step_size = step_size
        self.noise_scale = noise_scale

    def sample_step(self, x: torch.Tensor) -> torch.Tensor:
        """Perform one Langevin update.

        Args:
            x: Current samples, shape ``(n_samples, dim)``.

        Returns:
            Updated samples with the same shape as ``x``.
        """
        # The negated `energy_function.score` is used as the drift, i.e. as the
        # gradient of the log-probability. (The HMC sampler in this guide uses
        # `score` the same way, as the energy gradient.)
        score = -self.energy_function.score(x)
        noise = torch.randn_like(x) * np.sqrt(2 * self.step_size * self.noise_scale)
        return x + self.step_size * score + noise

    def sample_chain(
        self,
        dim: int,
        n_steps: int,
        n_samples: int = 1,
        initial_samples: Optional[torch.Tensor] = None,
        return_trajectory: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``n_steps`` Langevin updates on a batch of chains.

        Args:
            dim: Dimensionality of samples (used only when ``initial_samples``
                is ``None``).
            n_steps: Number of updates per chain; must be non-negative.
            n_samples: Number of parallel chains (used only when
                ``initial_samples`` is ``None``).
            initial_samples: Optional starting states. When given, their shape
                determines the chain shape.
            return_trajectory: Whether to also return every intermediate state.

        Returns:
            The final samples, or ``(samples, trajectory)`` where ``trajectory``
            has shape ``(n_steps + 1, *samples.shape)`` and starts with the
            initial state.

        Raises:
            ValueError: If ``n_steps`` is negative.
        """
        if n_steps < 0:
            raise ValueError(f"n_steps must be non-negative, got {n_steps}")

        if initial_samples is None:
            x = torch.randn(n_samples, dim, device=self.device)
        else:
            x = initial_samples.clone().to(self.device)

        if not return_trajectory:
            for _ in range(n_steps):
                x = self.sample_step(x)
            return x

        # Size the buffer from the actual chain state, not from `n_samples`/`dim`,
        # so user-supplied initial samples of any batch size or dtype are stored exactly.
        trajectory = torch.empty(
            (n_steps + 1, *x.shape), dtype=x.dtype, device=x.device
        )
        trajectory[0] = x
        for i in range(n_steps):
            x = self.sample_step(x)
            trajectory[i + 1] = x
        return x, trajectory

    def sample(
        self, n_samples: int, dim: int, n_steps: int = 100, **kwargs
    ) -> torch.Tensor:
        """Generate samples by running :meth:`sample_chain` for ``n_steps`` steps."""
        return self.sample_chain(
            dim=dim, n_steps=n_steps, n_samples=n_samples, **kwargs
        )
class HamiltonianMonteCarlo(Sampler):
    """Hamiltonian Monte Carlo sampler with a Metropolis-Hastings correction.

    The kinetic energy is ``K(p) = 0.5 * p^T A p``, where ``A`` is
    ``mass_matrix`` (identity when ``None``). Because ``A`` multiplies the
    momentum in both ``K`` and the position update, it acts as the *inverse*
    mass matrix. Momenta are therefore drawn from ``N(0, A^{-1})`` so that their
    distribution matches ``exp(-K)``.
    """

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        step_size: float = 0.1,
        n_leapfrog_steps: int = 10,
        mass_matrix: Optional[torch.Tensor] = None,
    ):
        """Initialize the HMC sampler.

        Args:
            energy_function: Energy function to sample from.
            step_size: Step size of the leapfrog integrator.
            n_leapfrog_steps: Number of leapfrog steps per HMC transition.
            mass_matrix: Optional symmetric positive-definite ``(dim, dim)``
                matrix used in the kinetic energy and position update (acts as
                the inverse mass). Identity when ``None``.
        """
        super().__init__(energy_function)
        self.step_size = step_size
        self.n_leapfrog_steps = n_leapfrog_steps
        self.mass_matrix = mass_matrix

    def _sample_momentum(self, x: torch.Tensor) -> torch.Tensor:
        """Draw one momentum vector per sample from ``N(0, A^{-1})``."""
        z = torch.randn_like(x)
        if self.mass_matrix is None:
            return z
        a = self.mass_matrix.to(device=x.device, dtype=x.dtype)
        chol = torch.linalg.cholesky(a)  # a = L L^T
        # Solving L^T p_i = z_i gives Cov(p_i) = L^{-T} L^{-1} = a^{-1}.
        return torch.linalg.solve_triangular(chol.T, z.T, upper=True).T

    def _leapfrog_step(
        self, x: torch.Tensor, p: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform one leapfrog (half-kick, drift, half-kick) step.

        Args:
            x: Positions, shape ``(n_samples, dim)``.
            p: Momenta, shape ``(n_samples, dim)``.

        Returns:
            The updated ``(x, p)``.
        """
        # `score` is used as the energy gradient here: momentum moves against it.
        grad_x = self.energy_function.score(x)
        p = p - 0.5 * self.step_size * grad_x

        if self.mass_matrix is not None:
            x = x + self.step_size * torch.matmul(p, self.mass_matrix)
        else:
            x = x + self.step_size * p

        grad_x = self.energy_function.score(x)
        p = p - 0.5 * self.step_size * grad_x
        return x, p

    def _compute_hamiltonian(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        """Return ``E(x) + K(p)`` per sample, shape ``(n_samples,)``."""
        energy = self.energy_function(x)
        if self.mass_matrix is not None:
            kinetic = 0.5 * torch.sum(p * torch.matmul(p, self.mass_matrix), dim=1)
        else:
            kinetic = 0.5 * torch.sum(p * p, dim=1)
        return energy + kinetic

    def sample_step(self, x: torch.Tensor) -> torch.Tensor:
        """Perform one HMC transition.

        Args:
            x: Current samples, shape ``(n_samples, dim)``.

        Returns:
            Updated samples, shape ``(n_samples, dim)``. Rejected proposals keep
            their previous position.
        """
        p = self._sample_momentum(x)

        x_old, p_old = x.clone(), p.clone()
        h_old = self._compute_hamiltonian(x_old, p_old)

        x_new, p_new = x_old.clone(), p_old.clone()
        for _ in range(self.n_leapfrog_steps):
            x_new, p_new = self._leapfrog_step(x_new, p_new)

        # Accept with probability min(1, exp(h_old - h_new)). Comparing in log
        # space avoids overflow of exp() on large Hamiltonian decreases; NaN
        # differences still compare False and are rejected.
        h_new = self._compute_hamiltonian(x_new, p_new)
        log_u = torch.log(torch.rand_like(h_old))
        accept = log_u < (h_old - h_new)

        return torch.where(accept.unsqueeze(1), x_new, x_old)

    def sample_chain(
        self,
        dim: int,
        n_steps: int,
        n_samples: int = 1,
        initial_samples: Optional[torch.Tensor] = None,
        return_trajectory: bool = False,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``n_steps`` HMC transitions on a batch of chains.

        Args:
            dim: Dimensionality of samples (used only when ``initial_samples``
                is ``None``).
            n_steps: Number of HMC transitions; must be non-negative.
            n_samples: Number of parallel chains (used only when
                ``initial_samples`` is ``None``).
            initial_samples: Optional starting states.
            return_trajectory: Whether to also return every intermediate state.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The final samples, or ``(samples, trajectory)`` where ``trajectory``
            has shape ``(n_steps + 1, *samples.shape)``.

        Raises:
            ValueError: If ``n_steps`` is negative.
        """
        if n_steps < 0:
            raise ValueError(f"n_steps must be non-negative, got {n_steps}")

        if initial_samples is None:
            x = torch.randn(n_samples, dim, device=self.device)
        else:
            x = initial_samples.clone().to(self.device)

        trajectory = None
        if return_trajectory:
            trajectory = torch.empty(
                (n_steps + 1, *x.shape), dtype=x.dtype, device=x.device
            )
            trajectory[0] = x

        for i in range(n_steps):
            x = self.sample_step(x)
            if trajectory is not None:
                trajectory[i + 1] = x

        return (x, trajectory) if return_trajectory else x

    def sample(
        self, n_samples: int, dim: int, n_steps: int = 100, **kwargs
    ) -> torch.Tensor:
        """Generate samples by running :meth:`sample_chain` for ``n_steps`` steps."""
        return self.sample_chain(
            dim=dim, n_steps=n_steps, n_samples=n_samples, **kwargs
        )
class MetropolisHastings(Sampler):
    """Random-walk Metropolis-Hastings sampler with isotropic Gaussian proposals.

    Because the proposal is symmetric, a move ``x -> x'`` is accepted with
    probability ``min(1, exp(E(x) - E(x')))``.
    """

    def __init__(self, energy_function: BaseEnergyFunction, proposal_scale: float = 0.1):
        """Initialize the Metropolis-Hastings sampler.

        Args:
            energy_function: Energy function to sample from.
            proposal_scale: Standard deviation of the Gaussian proposal.
        """
        super().__init__(energy_function)
        self.proposal_scale = proposal_scale

    def sample_step(self, x: torch.Tensor) -> torch.Tensor:
        """Perform one Metropolis-Hastings transition.

        Args:
            x: Current samples, shape ``(n_samples, dim)``.

        Returns:
            Updated samples, shape ``(n_samples, dim)``. Rejected proposals keep
            their previous position.
        """
        energy_x = self.energy_function(x)
        proposal = x + self.proposal_scale * torch.randn_like(x)
        energy_proposal = self.energy_function(proposal)

        # Compare in log space so a large energy drop cannot overflow exp();
        # NaN energy differences compare False and are rejected.
        log_u = torch.log(torch.rand_like(energy_x))
        accept = log_u < (energy_x - energy_proposal)

        return torch.where(accept.unsqueeze(1), proposal, x)

    def sample_chain(
        self,
        dim: int,
        n_steps: int,
        n_samples: int = 1,
        initial_samples: Optional[torch.Tensor] = None,
        return_trajectory: bool = False,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``n_steps`` Metropolis-Hastings transitions on a batch of chains.

        Args:
            dim: Dimensionality of samples (used only when ``initial_samples``
                is ``None``).
            n_steps: Number of transitions; must be non-negative.
            n_samples: Number of parallel chains (used only when
                ``initial_samples`` is ``None``).
            initial_samples: Optional starting states.
            return_trajectory: Whether to also return every intermediate state.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The final samples, or ``(samples, trajectory)`` where ``trajectory``
            has shape ``(n_steps + 1, *samples.shape)``.

        Raises:
            ValueError: If ``n_steps`` is negative.
        """
        if n_steps < 0:
            raise ValueError(f"n_steps must be non-negative, got {n_steps}")

        if initial_samples is None:
            x = torch.randn(n_samples, dim, device=self.device)
        else:
            x = initial_samples.clone().to(self.device)

        trajectory = None
        if return_trajectory:
            trajectory = torch.empty(
                (n_steps + 1, *x.shape), dtype=x.dtype, device=x.device
            )
            trajectory[0] = x

        for i in range(n_steps):
            x = self.sample_step(x)
            if trajectory is not None:
                trajectory[i + 1] = x

        return (x, trajectory) if return_trajectory else x

    def sample(
        self, n_samples: int, dim: int, n_steps: int = 100, **kwargs
    ) -> torch.Tensor:
        """Generate samples by running :meth:`sample_chain` for ``n_steps`` steps.

        Implementing this abstract method of ``Sampler`` is what makes the
        class instantiable.
        """
        return self.sample_chain(
            dim=dim, n_steps=n_steps, n_samples=n_samples, **kwargs
        )
from torchebm.cuda import langevin_step_cuda


class CUDALangevinDynamics(LangevinDynamics):
    """Langevin dynamics sampler that routes steps to a CUDA-specific routine when it can."""

    def sample_step(self, x: torch.Tensor) -> torch.Tensor:
        """Advance ``x`` by one Langevin step.

        When CUDA is available and ``x`` already lives on a CUDA device, the
        step is delegated to ``langevin_step_cuda``. Otherwise this falls back
        to the generic ``LangevinDynamics.sample_step``.
        """
        on_gpu = torch.cuda.is_available() and x.is_cuda
        if on_gpu:
            return langevin_step_cuda(
                x, self.energy_function, self.step_size, self.noise_scale
            )
        return super().sample_step(x)
def batch_sample_chain(
    sampler: "Sampler",
    dim: int,
    n_steps: int,
    n_samples: int,
    batch_size: int = 1000,
) -> torch.Tensor:
    """Draw ``n_samples`` samples by running ``sampler`` in batches to bound memory.

    Each batch runs at most ``batch_size`` chains via ``sampler.sample``, and the
    batches are concatenated along the first dimension.

    Args:
        sampler: Sampler providing ``sample(dim=..., n_steps=..., n_samples=...)``.
        dim: Dimensionality of each sample.
        n_steps: Number of chain steps per batch.
        n_samples: Total number of samples to draw; may be zero.
        batch_size: Maximum number of chains run at once; must be positive.

    Returns:
        Tensor with ``n_samples`` rows. When ``n_samples`` is zero, an empty
        ``(0, dim)`` tensor on the sampler's device is returned.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``n_samples`` is negative.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")
    if n_samples == 0:
        # torch.cat([]) raises, so return an explicitly empty result instead.
        return torch.empty((0, dim), device=getattr(sampler, "device", None))

    batches = [
        sampler.sample(
            dim=dim,
            n_steps=n_steps,
            n_samples=min(batch_size, n_samples - start),
        )
        for start in range(0, n_samples, batch_size)
    ]
    return torch.cat(batches, dim=0)