TorchEBM leverages CUDA to accelerate performance-critical operations. This guide explains the CUDA optimization strategies used in the library and shows how to implement new CUDA kernels.
CUDA kernels are packaged as a PyTorch C++/CUDA extension and registered in `setup.py`:

```python
# In setup.py
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

setup(
    name="torchebm",
    ext_modules=[
        CUDAExtension(
            "torchebm.cuda.kernels",
            sources=[
                "torchebm/cuda/bindings.cpp",
                "torchebm/cuda/kernels/score_function.cu",
                "torchebm/cuda/kernels/langevin_step.cu",
                "torchebm/cuda/kernels/energy_kernels.cu",
            ],
            include_dirs=["torchebm/cuda/kernels/include"],
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```
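During development it is often faster to JIT-compile the kernels than to rebuild the whole package. A minimal sketch using PyTorch's `torch.utils.cpp_extension.load`, with the source list mirroring the `setup.py` above:

```python
from torch.utils.cpp_extension import load

# JIT-compile the extension into an importable module for quick iteration
kernels = load(
    name="torchebm_cuda_kernels",
    sources=[
        "torchebm/cuda/bindings.cpp",
        "torchebm/cuda/kernels/score_function.cu",
        "torchebm/cuda/kernels/langevin_step.cu",
        "torchebm/cuda/kernels/energy_kernels.cu",
    ],
    extra_include_paths=["torchebm/cuda/kernels/include"],
    extra_cuda_cflags=["-O3"],
    verbose=True,
)
```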
The score function dispatches to a hand-written CUDA kernel when one is available and falls back to autograd otherwise:

```python
def cuda_score(energy_fn, x, create_graph=False):
    """CUDA-optimized score function computation.

    Args:
        energy_fn: Energy function
        x: Input tensor of shape (batch_size, dim)
        create_graph: Whether to create gradient graph

    Returns:
        Score tensor of shape (batch_size, dim)
    """
    # Check if the energy function has a custom CUDA implementation
    if hasattr(energy_fn, "cuda_score_impl") and torch.cuda.is_available():
        return energy_fn.cuda_score_impl(x, create_graph)

    # Use the dedicated kernel for common energy functions
    if isinstance(energy_fn, GaussianEnergy) and torch.cuda.is_available():
        return _gaussian_score_cuda(energy_fn, x)

    # Fall back to autograd
    return score_function(energy_fn, x, create_graph)
```
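A usage sketch; the import path and the `GaussianEnergy` constructor arguments are assumptions, so check them against the actual API:

```python
import torch
from torchebm.core import GaussianEnergy  # import path is an assumption

# Hypothetical 2-D Gaussian energy living on the GPU
energy_fn = GaussianEnergy(
    mean=torch.zeros(2, device="cuda"),
    cov=torch.eye(2, device="cuda"),
)

x = torch.randn(64, 2, device="cuda")
score = cuda_score(energy_fn, x)  # takes the CUDA fast path when available
```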
The CUDA-aware sampler overrides `sample_step` and falls back to the base implementation for CPU tensors:

```python
class CUDALangevinDynamics(LangevinDynamics):
    """CUDA-optimized Langevin dynamics sampler."""

    def __init__(self, energy_function, step_size=0.01, noise_scale=1.0):
        super().__init__(energy_function, step_size, noise_scale)

    def sample_step(self, x):
        """Perform one step of Langevin dynamics with CUDA optimization."""
        if not torch.cuda.is_available() or not x.is_cuda:
            # Fall back to the CPU implementation
            return super().sample_step(x)

        # Use the optimized CUDA implementation
        return langevin_step_cuda(
            x, self.energy_function, self.step_size, self.noise_scale
        )
```
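It is used like the base sampler; a sketch reusing the hypothetical `energy_fn` from above:

```python
sampler = CUDALangevinDynamics(energy_fn, step_size=0.01, noise_scale=1.0)

# Chains kept on the GPU take the optimized path inside sample_step
x = torch.randn(1024, 2, device="cuda")
for _ in range(100):
    x = sampler.sample_step(x)
```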
Kernels should stage frequently reused data in shared memory to avoid repeated global-memory reads:

```cuda
__global__ void optimized_kernel(const float* global_data, int n) {
    // Declare shared memory for frequently accessed data
    __shared__ float shared_data[BLOCK_SIZE];

    // Load data into shared memory, guarding against out-of-bounds reads
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        shared_data[threadIdx.x] = global_data[idx];
    }
    __syncthreads();

    // Use shared memory for computation
    // ...
}
```
Register usage also matters: scalar locals stay in registers, while dynamically indexed arrays can spill to local memory:

```cuda
__global__ void optimized_kernel(...) {
    // Use local variables instead of arrays where possible
    float x1, x2, x3, x4;

    // Process in chunks to reduce register usage
    // ...
}
```
The launch configuration is derived from the problem size:

```cuda
// Compute optimal block size based on problem dimensions
int compute_block_size(int dim) {
    // Powers of two generally perform better
    if (dim <= 32) return 32;
    if (dim <= 64) return 64;
    if (dim <= 128) return 128;
    return 256;
}

// Launch kernel with optimal configuration
void launch_kernel(int batch_size, int dim) {
    int block_size = compute_block_size(dim);
    // Round up so every element of the batch is covered
    int grid_size = (batch_size + block_size - 1) / block_size;
    my_kernel<<<grid_size, block_size>>>(/* args */);
}
```
Mixed precision runs the energy forward pass in float16 under autocast while returning a float32 score:

```python
def mixed_precision_score(energy_fn, x):
    """Compute the score with a mixed-precision forward pass."""
    x = x.detach().requires_grad_(True)

    # Compute the energy in reduced precision under autocast
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        energy = energy_fn(x)

    # Differentiate and cast the score back to full precision
    score = torch.autograd.grad(energy.sum(), x)[0].float()
    return score
```
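A quick check with the hypothetical energy function from earlier; the forward pass saves memory bandwidth while the returned score is always float32:

```python
x = torch.randn(512, 2, device="cuda")
score = mixed_precision_score(energy_fn, x)
assert score.dtype == torch.float32  # cast back to full precision
```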
Sampling can be sharded across GPUs, with each device drawing an equal share of the chains:

```python
def distribute_sampling(energy_fn, n_samples, n_steps, device_ids):
    """Distribute sampling across multiple GPUs."""
    # Split the requested samples evenly across devices
    samples_per_device = n_samples // len(device_ids)

    results = []
    for device_id in device_ids:
        device = torch.device(f"cuda:{device_id}")

        # Create a sampler on this device
        sampler = LangevinDynamics(energy_fn).to(device)

        # Draw this device's share of the samples
        samples = sampler.sample(
            dim=energy_fn.dim,
            n_steps=n_steps,
            n_samples=samples_per_device,
        )
        # Move to a common device so the concatenation below is valid
        results.append(samples.cpu())

    # Gather results from all devices
    return torch.cat(results, dim=0)
```
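For example, on a machine with two GPUs (a sketch; note that if the sampler synchronizes internally, the loop visits devices one after another, and one thread per device would be needed for true overlap):

```python
# Draw 4096 samples split evenly across two devices
samples = distribute_sampling(energy_fn, n_samples=4096, n_steps=100, device_ids=[0, 1])
print(samples.shape)  # (4096, dim), gathered on the CPU
```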
Independent batches can be scored concurrently with CUDA streams:

```python
def parallel_score_computation(energy_fn, samples_list):
    """Compute scores for multiple sample batches in parallel."""
    # Create one stream per batch for concurrent execution
    streams = [torch.cuda.Stream() for _ in range(len(samples_list))]

    # Launch each score computation in its own stream
    results = []
    for i, samples in enumerate(samples_list):
        with torch.cuda.stream(streams[i]):
            results.append(energy_fn.score(samples))

    # Wait for all streams to finish before using the results
    for stream in streams:
        stream.synchronize()

    return results
```
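A usage sketch; concurrent streams only pay off when each batch is too small to saturate the GPU on its own:

```python
batches = [torch.randn(256, 2, device="cuda") for _ in range(4)]
scores = parallel_score_computation(energy_fn, batches)  # one stream per batch
```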
New kernels follow a standard template: the CUDA kernel itself, a C++ binding that launches it, and a pybind11 registration:

```cuda
// In custom_kernel.cu
#include <torch/extension.h>
#include "common.cuh"

// CUDA kernel
__global__ void custom_kernel(...) {
    // Kernel implementation
}

// C++ binding function
torch::Tensor custom_kernel_cuda(...) {
    // Binding implementation
    // ...
    return result;
}

// Register function for Python binding
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("custom_kernel", &custom_kernel_cuda, "Custom kernel implementation");
}
```
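Once compiled, the binding is callable from Python like any other function. A hypothetical call, assuming the extension was built as `torchebm.cuda.kernels` via the `setup.py` above:

```python
import torch
from torchebm.cuda import kernels  # compiled extension module

x = torch.randn(128, 32, device="cuda")
out = kernels.custom_kernel(x)  # hypothetical signature; match custom_kernel_cuda
```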