import torch
import numpy as np
import matplotlib.pyplot as plt
from torchebm.core import DoubleWellEnergy

# Create the energy function
energy_fn = DoubleWellEnergy(barrier_height=2.0)

# Create a grid for visualization
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)

# Compute energy values in one batched call instead of 10,000 single-point
# calls. The per-point version fed (1, 2) tensors, so energy_fn is assumed to
# map an (N, 2) batch to N energies -- TODO confirm for other energy functions.
grid_points = torch.tensor(
    np.column_stack([X.ravel(), Y.ravel()]), dtype=torch.float32
)
with torch.no_grad():  # values are only plotted; no gradients are needed
    energies = energy_fn(grid_points)
# ravel() and reshape() both use C order, so Z[i, j] pairs with X[i, j], Y[i, j].
# Cast to float64 to match the dtype of the previous np.zeros_like(X) buffer.
Z = energies.reshape(X.shape).double().cpu().numpy()

# Create contour plot
plt.figure(figsize=(10, 8))
contour = plt.contourf(X, Y, Z, 50, cmap="viridis")
plt.colorbar(label="Energy")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Double Well Energy Landscape")
# Create figure with multiple plots: ground truth, sample density, raw samples
fig = plt.figure(figsize=(15, 5))

# Ground truth contour
# NOTE(review): this panel is labelled as a density, so Z is assumed to hold
# density values here (e.g. a normalized exp(-energy)). If Z is still the raw
# energy grid from the contour example, low-energy wells will be drawn as
# low density -- convert it first. TODO confirm against the preceding code.
ax1 = fig.add_subplot(131)
contour = ax1.contourf(X, Y, Z, 50, cmap="Blues")
fig.colorbar(contour, ax=ax1, label="Density")
ax1.set_title("Ground Truth Density")

# Sample density, estimated with a normalized 2-D histogram (density=True).
# Note: this is a histogram, not a kernel density estimate.
ax2 = fig.add_subplot(132)
h = ax2.hist2d(
    samples_np[:, 0], samples_np[:, 1], bins=50, cmap="Reds", density=True
)
# hist2d returns (counts, xedges, yedges, image); the image is the mappable
# the colorbar needs.
fig.colorbar(h[3], ax=ax2, label="Density")
ax2.set_title("Sampled Distribution")

# Scatter plot of samples
ax3 = fig.add_subplot(133)
ax3.scatter(samples_np[:, 0], samples_np[:, 1], alpha=0.5, s=3)
ax3.set_title("Sample Points")

# Keep the three panels and their colorbars from overlapping.
fig.tight_layout()
# Extract trajectory coordinates of the first chain.
# NOTE(review): assumes trajectory is (n_chains, n_steps, dim) -- TODO confirm.
# .detach().cpu() is needed because .numpy() raises for tensors that track
# gradients or live on a GPU.
traj_x = trajectory[0, :, 0].detach().cpu().numpy()
traj_y = trajectory[0, :, 1].detach().cpu().numpy()

# Plot trajectory with colormap based on step number
plt.figure(figsize=(10, 8))
contour = plt.contourf(X, Y, Z, 50, cmap="viridis", alpha=0.7)  # Energy landscape
points = plt.scatter(
    traj_x, traj_y, c=np.arange(len(traj_x)), cmap="plasma", s=5, alpha=0.7
)
plt.colorbar(points, label="Sampling Step")

# Every `step`-th point, draw an arrow for its displacement to the next point
# (i -> i + 1) to show the direction of travel.
step = 50
plt.quiver(
    traj_x[:-1:step],
    traj_y[:-1:step],
    traj_x[1::step] - traj_x[:-1:step],
    traj_y[1::step] - traj_y[:-1:step],
    scale_units="xy",
    angles="xy",
    scale=1,  # with xy units/angles, arrows are drawn at their true data length
    color="red",
    alpha=0.7,
)
# Plot contour of the energy landscape
plt.figure(figsize=(12, 10))
contour = plt.contourf(X, Y, Z, 50, cmap="viridis", alpha=0.7)
plt.colorbar(label="Energy")

# Plot each trajectory with a different color. The palette cycles, so more
# than len(colors) chains no longer raises IndexError.
colors = ["red", "blue", "green", "orange", "purple"]
for i in range(num_chains):
    color = colors[i % len(colors)]
    # NOTE(review): assumes trajectories is (n_chains, n_steps, dim) -- TODO confirm.
    # .detach().cpu() lets .numpy() work on GPU / grad-tracking tensors.
    traj_x = trajectories[i, :, 0].detach().cpu().numpy()
    traj_y = trajectories[i, :, 1].detach().cpu().numpy()
    plt.plot(traj_x, traj_y, alpha=0.7, linewidth=1, c=color, label=f"Chain {i+1}")

    # Mark start (black circle) and end (star in the chain's color) points
    plt.scatter(traj_x[0], traj_y[0], c="black", s=50, marker="o")
    plt.scatter(traj_x[-1], traj_y[-1], c=color, s=100, marker="*")

# The chain lines carry labels; without a legend those labels are never shown.
plt.legend()
# Track the trajectory and energy manually.
# The buffer takes the starting point's dtype so float64 runs are not silently
# downcast. It is allocated on the default device (CPU unless changed), so it
# can later be converted with .numpy().
trajectory = torch.zeros((1, n_steps, dim), dtype=initial_point.dtype)
energy_values = torch.zeros(n_steps)
current_sample = initial_point.clone()

# Run the sampling steps and store each position and energy
for i in range(n_steps):
    current_sample = sampler.langevin_step(
        current_sample, torch.randn_like(current_sample)
    )
    # Slice assignment copies into the buffer, so no extra clone() is needed.
    trajectory[:, i, :] = current_sample.detach()
    # The energy is only logged, so skip building an autograd graph for it.
    with torch.no_grad():
        energy_values[i] = energy_fn(current_sample).item()

# Plot energy evolution
plt.figure(figsize=(10, 6))
plt.plot(energy_values.numpy())
plt.xlabel("Step")
plt.ylabel("Energy")
plt.title("Energy Evolution During Sampling")
plt.grid(True, alpha=0.3)
# Show every example that examples/main.py knows how to run
python examples/main.py --list

# Basic visualization examples, run one after another
for example in contour_plots distribution_comparison; do
    python examples/main.py "visualization/basic/${example}"
done

# Advanced visualization examples, run one after another
for example in trajectory_animation parallel_chains energy_over_time; do
    python examples/main.py "visualization/advanced/${example}"
done