```python
# Create energy function and move to appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
energy_fn = GaussianEnergy(mean, cov).to(device)

# Create sampler with the same device
sampler = LangevinDynamics(energy_fn, device=device)

# Generate samples (automatically on the correct device)
samples, _ = sampler.sample(dim=2, n_steps=1000, n_samples=10000)
```
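To check how much device placement actually buys you, a quick wall-clock comparison is often enough. The sketch below reuses the `GaussianEnergy` / `LangevinDynamics` setup from the snippet above; the `time_sampling` helper is illustrative, and `torch.cuda.synchronize()` is called before reading the clock because CUDA kernels launch asynchronously.

```python
import time

import torch

def time_sampling(device_str, dim=2, n_steps=1000, n_samples=10000):
    """Illustrative helper: time one full sampling run on the given device."""
    device = torch.device(device_str)
    energy_fn = GaussianEnergy(mean, cov).to(device)
    sampler = LangevinDynamics(energy_fn, device=device)

    if device.type == "cuda":
        torch.cuda.synchronize()  # make sure setup work has finished
    start = time.perf_counter()
    samples, _ = sampler.sample(dim=dim, n_steps=n_steps, n_samples=n_samples)
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return time.perf_counter() - start

print(f"CPU: {time_sampling('cpu'):.3f}s")
if torch.cuda.is_available():
    print(f"GPU: {time_sampling('cuda'):.3f}s")
```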
```python
# Avoid creating new tensors in loops

# Bad: creates new tensors on every iteration
for step in range(n_steps):
    x = x - step_size * energy_fn.gradient(x) + noise_scale * torch.randn_like(x)

# Good: in-place operations reuse the existing buffer
for step in range(n_steps):
    grad = energy_fn.gradient(x)
    x.sub_(step_size * grad)
    x.add_(noise_scale * torch.randn_like(x))
```
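The same idea can be pushed a bit further by reusing a preallocated noise buffer instead of calling `torch.randn_like` inside the loop. This is a sketch only: it assumes `energy_fn.gradient` still returns a fresh tensor each step, so that one allocation per iteration remains.

```python
# Sketch: reuse a preallocated noise buffer so the loop itself allocates as little as possible
noise = torch.empty_like(x)
for step in range(n_steps):
    grad = energy_fn.gradient(x)      # still allocates one tensor per step
    noise.normal_()                   # refill the existing buffer in place with N(0, 1) samples
    x.sub_(grad, alpha=step_size)     # x -= step_size * grad, without a temporary
    x.add_(noise, alpha=noise_scale)  # x += noise_scale * noise, without a temporary
```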
```python
# Standard PyTorch implementation
def langevin_step_pytorch(x, energy_fn, step_size, noise_scale):
    grad = energy_fn.gradient(x)
    noise = torch.randn_like(x) * noise_scale
    return x - step_size * grad + noise

# Using custom CUDA kernel when available
from torchebm.cuda import langevin_step_cuda

def langevin_step(x, energy_fn, step_size, noise_scale):
    if x.is_cuda and torch.cuda.is_available():
        return langevin_step_cuda(x, energy_fn, step_size, noise_scale)
    else:
        return langevin_step_pytorch(x, energy_fn, step_size, noise_scale)
```
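Whether the custom kernel is worth dispatching to is ultimately an empirical question. The sketch below times both paths with CUDA events, which measure device time rather than host wall-clock; the `time_step` helper, batch size, and iteration counts are illustrative, and `energy_fn`, `step_size`, and `noise_scale` are assumed to come from the surrounding context.

```python
def time_step(step_fn, x, iters=100, warmup=10):
    """Illustrative helper: average GPU time per call, in milliseconds."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):  # warm-up so one-off setup costs are not measured
        step_fn(x, energy_fn, step_size, noise_scale)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        step_fn(x, energy_fn, step_size, noise_scale)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

x = torch.randn(10_000, 2, device="cuda")
print(f"PyTorch path: {time_step(langevin_step_pytorch, x):.3f} ms/step")
print(f"Dispatching path: {time_step(langevin_step, x):.3f} ms/step")
```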
```python
# Optimize step size for Langevin dynamics
# Rule of thumb: step_size ≈ O(d^(-1/3)) where d is dimension
step_size = min(0.01, 0.1 * dim ** (-1 / 3))

# Noise scale should be sqrt(2 * step_size) for standard Langevin
noise_scale = np.sqrt(2 * step_size)
```
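The `sqrt(2 * step_size)` factor is not arbitrary: it comes from discretizing the overdamped Langevin SDE. With step size $\epsilon$, the standard unadjusted Langevin update is

$$
x_{t+1} = x_t - \epsilon \, \nabla E(x_t) + \sqrt{2\epsilon}\, \xi_t, \qquad \xi_t \sim \mathcal{N}(0, I),
$$

so the per-step noise standard deviation must be exactly $\sqrt{2\epsilon}$ for the chain to (approximately) target $p(x) \propto e^{-E(x)}$.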
```python
# Optimize HMC parameters
# Leapfrog steps should scale with dimension
n_leapfrog_steps = max(5, int(np.sqrt(dim)))

# Step size should decrease with dimension
step_size = min(0.01, 0.05 * dim ** (-1 / 4))
```
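For a concrete feel of how these rules of thumb scale, the loop below simply evaluates the two formulas above for a few dimensions. The numbers are starting points, not tuned values; real problems usually need adjustment around them.

```python
import numpy as np

for dim in (2, 100, 10_000):
    n_leapfrog_steps = max(5, int(np.sqrt(dim)))
    step_size = min(0.01, 0.05 * dim ** (-1 / 4))
    print(f"dim={dim:>6}  leapfrog_steps={n_leapfrog_steps:>4}  step_size={step_size:.4f}")
```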
```python
# Distribution across GPUs using DataParallel
import torch.nn as nn

class ParallelSampler(nn.DataParallel):
    def __init__(self, sampler, device_ids=None):
        super().__init__(sampler, device_ids=device_ids)
        self.module = sampler

    def sample_chain(self, dim, n_steps, n_samples):
        # Distribute samples across GPUs
        return self.forward(dim, n_steps, n_samples)

# Create parallel sampler
devices = list(range(torch.cuda.device_count()))
parallel_sampler = ParallelSampler(sampler, device_ids=devices)

# Generate samples using all available GPUs
samples = parallel_sampler.sample_chain(dim=100, n_steps=1000, n_samples=100000)
```
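Because MCMC chains are independent, an alternative to `DataParallel` is to split the requested chains across GPUs by hand, with one worker thread per device, and concatenate the results on the CPU. This is a sketch only: `make_sampler` is a hypothetical factory that builds a sampler bound to a given device, and the `sample` signature is assumed to match the device-placement example at the top of this section.

```python
from concurrent.futures import ThreadPoolExecutor

import torch

def _run_shard(device_index, make_sampler, dim, n_steps, n_samples):
    device = torch.device(f"cuda:{device_index}")
    sampler = make_sampler(device)  # hypothetical factory, e.g. a LangevinDynamics bound to `device`
    samples, _ = sampler.sample(dim=dim, n_steps=n_steps, n_samples=n_samples)
    return samples.cpu()

def sample_multi_gpu(make_sampler, dim, n_steps, n_samples):
    n_gpus = torch.cuda.device_count()
    per_gpu = n_samples // n_gpus  # assumes n_samples is divisible by the GPU count
    with ThreadPoolExecutor(max_workers=n_gpus) as pool:
        futures = [
            pool.submit(_run_shard, i, make_sampler, dim, n_steps, per_gpu)
            for i in range(n_gpus)
        ]
        shards = [f.result() for f in futures]
    return torch.cat(shards, dim=0)
```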
Performance optimization in TorchEBM involves careful attention to vectorization, GPU acceleration, memory management, and algorithm-specific tuning. By following these guidelines, you can achieve significant speedups in your energy-based modeling workflows.