Visualizing sampling trajectories helps you understand how different sampling algorithms explore the energy landscape. This example creates a multimodal energy function and visualizes multiple sampling chains as they traverse the landscape.
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchebm.samplers.langevin_dynamics import LangevinDynamics


class MultimodalEnergy:
    """
    A 2D energy function with multiple local minima to demonstrate sampling behavior.

    The energy is the negative log of an unnormalized mixture of isotropic,
    unit-variance Gaussians with centers ``c_k`` and weights ``w_k``:

        E(x) = -log( sum_k w_k * exp(-0.5 * ||x - c_k||^2) )

    Both the energy and its gradient are evaluated in log space, so they stay
    finite for points far away from every center (where the plain ``exp``
    terms would underflow to zero).
    """

    def __init__(self, device=None, dtype=torch.float32):
        """Create the mixture on ``device`` (CUDA if available when omitted)."""
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype

        # Define centers and weights for multiple Gaussian components
        self.centers = torch.tensor(
            [[-1.0, -1.0], [1.0, 1.0], [-0.5, 1.0], [1.0, -0.5]],
            device=self.device,
            dtype=self.dtype,
        )
        self.weights = torch.tensor(
            [1.0, 0.8, 0.6, 0.7], device=self.device, dtype=self.dtype
        )

    def _log_components(self, x: torch.Tensor):
        """Return ``(x - c_k, log(w_k) - 0.5 * ||x - c_k||^2)`` for every component."""
        # Ensure input has correct dtype and batch shape
        x = x.to(dtype=self.dtype)
        if x.dim() == 1:
            x = x.view(1, -1)

        # (..., 1, 2) - (K, 2) -> (..., K, 2): offset from each center
        diff = x.unsqueeze(-2) - self.centers
        log_terms = torch.log(self.weights) - 0.5 * diff.pow(2).sum(dim=-1)
        return diff, log_terms

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return the energy of each point in ``x``.

        Args:
            x: A single point of shape ``(2,)`` or a batch of shape ``(batch, 2)``.

        Returns:
            Tensor of shape ``(batch,)`` (``(1,)`` for a single point).
        """
        _, log_terms = self._log_components(x)
        # logsumexp avoids log(0) = -inf when every exp term underflows
        return -torch.logsumexp(log_terms, dim=-1)

    def gradient(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gradient of the energy with respect to ``x``.

        Args:
            x: A single point of shape ``(2,)`` or a batch of shape ``(batch, 2)``.

        Returns:
            Tensor of shape ``(batch, 2)`` (``(1, 2)`` for a single point).
        """
        diff, log_terms = self._log_components(x)
        # softmax gives w_k * exp(...) / normalizer without a 0 / 0 division
        responsibilities = torch.softmax(log_terms, dim=-1)
        return torch.sum(responsibilities.unsqueeze(-1) * diff, dim=-2)

    def to(self, device):
        """Move the centers and weights to ``device`` and return ``self``."""
        self.device = device
        self.centers = self.centers.to(device)
        self.weights = self.weights.to(device)
        return self
def visualize_energy_landscape_and_sampling():
    """Plot the multimodal energy landscape with Langevin dynamics chains overlaid.

    Draws contour lines of ``MultimodalEnergy`` on a 100 x 100 grid over
    [-3, 3] x [-3, 3], then runs one single-chain Langevin run of 200 steps
    per seed and plots each trajectory in its own color. The first point of
    each chain is marked with a large circle and the last with a star.

    The figure is left as the current pyplot figure; this function neither
    shows nor saves it and returns nothing, so call ``plt.show()`` or
    ``plt.savefig(...)`` afterwards.
    """
    # Set up device and dtype
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    # Create energy function
    energy_fn = MultimodalEnergy(device=device, dtype=dtype)

    # Initialize the standard Langevin dynamics sampler from the library
    sampler = LangevinDynamics(
        energy_function=energy_fn,
        step_size=0.01,
        noise_scale=0.1,
        device=device,
    )

    # Create grid for energy landscape visualization
    x = np.linspace(-3, 3, 100)
    y = np.linspace(-3, 3, 100)
    X, Y = np.meshgrid(x, y)

    # Calculate energy values
    grid_points = torch.tensor(
        np.stack([X.flatten(), Y.flatten()], axis=1),
        device=device,
        dtype=dtype,
    )
    # detach() keeps the NumPy conversion valid even if autograd is tracking
    energy_values = energy_fn(grid_points).detach().cpu().numpy().reshape(X.shape)

    # Set up sampling parameters
    dim = 2  # 2D energy function
    n_steps = 200

    # Create figure
    plt.figure(figsize=(10, 8))

    # Plot energy landscape with clear contours
    contour = plt.contour(X, Y, energy_values, levels=20, cmap="viridis")
    plt.colorbar(contour, label="Energy")

    # Run multiple independent chains from different starting points.
    # Each seed gives a different random starting position; the number of
    # chains is derived from the seed list so the two can never disagree.
    seeds = [42, 123, 456, 789, 999]
    n_chains = len(seeds)

    # Define distinct colors for the chains
    colors = plt.cm.tab10(np.linspace(0, 1, n_chains))

    for i, seed in enumerate(seeds):
        # Set the seed for reproducibility
        torch.manual_seed(seed)

        # Run one chain using the standard API
        trajectory = sampler.sample(
            dim=dim,  # 2D space
            n_samples=1,  # Single chain
            n_steps=n_steps,  # Number of steps
            return_trajectory=True,  # Return full trajectory
        )

        # Extract trajectory data; squeeze(0) removes the n_samples dimension
        traj_np = trajectory.detach().cpu().numpy().squeeze(0)

        # Plot the trajectory
        plt.plot(
            traj_np[:, 0],
            traj_np[:, 1],
            'o-',
            color=colors[i],
            alpha=0.6,
            markersize=3,
            label=f"Chain {i+1}",
        )

        # Mark the start and end points
        plt.plot(traj_np[0, 0], traj_np[0, 1], 'o', color=colors[i], markersize=8)
        plt.plot(traj_np[-1, 0], traj_np[-1, 1], '*', color=colors[i], markersize=10)

    # Add labels and title
    plt.title("Energy Landscape and Langevin Dynamics Sampling \nTrajectories")
    plt.xlabel("x₁")
    plt.ylabel("x₂")
    plt.grid(True, alpha=0.3)
    plt.legend()
When you run this example, you'll see a contour plot of the energy landscape with multiple chains of Langevin dynamics samples overlaid. The visualization shows:
Energy landscape: Contour lines representing the multimodal energy function
Multiple sampling chains: Different colored trajectories starting from random initial points
Trajectory progression: You can see how samples move from high-energy regions to low-energy regions
Langevin Dynamics Sampling Trajectories
The key insights from this visualization:
Sampling chains are attracted to areas of low energy (high probability)
Chains can get trapped in local minima and have difficulty crossing energy barriers
The stochastic nature of Langevin dynamics helps chains occasionally escape local minima
Sampling efficiency depends on starting points and energy landscape geometry
# Initialize the sampler
sampler = LangevinDynamics(
    energy_function=my_energy_fn,
    step_size=0.01,
)

# Run sampling with trajectory tracking
trajectory = sampler.sample(
    dim=2,  # Dimension of the space
    n_samples=10,  # Number of parallel chains
    n_steps=100,  # Number of steps to run
    return_trajectory=True,  # Return the full trajectory rather than just final points
)
When running the example, you'll see a visualization of the energy landscape with multiple sampling chains:
This visualization shows a multimodal energy landscape (contour lines) with five independent
Langevin dynamics sampling chains (colored trajectories). Each chain starts from a random position
(marked by a circle) and evolves for 200 steps, ending at the point marked by a star. The trajectories show how the
chains are attracted to the energy function's local minima. Note how some chains follow the gradient to the
nearest minimum, while others may explore multiple regions of the space.