Data visualization is an essential tool for understanding, analyzing, and communicating the behavior of energy-based models. This guide covers various visualization techniques available in TorchEBM to help you gain insights into energy landscapes, sampling processes, and model performance.
Visualizing energy landscapes is crucial for understanding the structure of the probability distribution a model defines. TorchEBM models can be evaluated on a grid of points and plotted in both 2D and 3D, for example as an energy surface or as a density contour.
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchebm.core import DoubleWellModel

model = DoubleWellModel(barrier_height=2.0)

x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
Z = np.zeros_like(X)

for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        point = torch.tensor([X[i, j], Y[i, j]], dtype=torch.float32).unsqueeze(0)
        Z[i, j] = model(point).item()

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('Energy')
ax.set_title('Double Well Energy Landscape')
plt.colorbar(surf, ax=ax, shrink=0.5, aspect=5)
plt.tight_layout()
plt.show()
```
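The double loop above calls the model once per grid point, which gets slow on fine grids. A common alternative is to flatten the grid into one batch of points and evaluate it in a single call. The sketch below is NumPy-only so it runs without TorchEBM; `double_well_energy` and `evaluate_on_grid` are hypothetical helpers, and the formula h * sum_i (x_i^2 - 1)^2 is an assumed double-well form, not necessarily what `DoubleWellModel` computes.

```python
import numpy as np


def double_well_energy(points, barrier_height=2.0):
    """Hypothetical stand-in energy: E(x) = h * sum_i (x_i^2 - 1)^2.

    `points` has shape (N, 2); returns shape (N,). This is an assumed
    double-well form, not necessarily the one DoubleWellModel uses.
    """
    return barrier_height * np.sum((points ** 2 - 1.0) ** 2, axis=1)


def evaluate_on_grid(energy_fn, plot_range=3.0, grid_size=100):
    """Evaluate a batched energy function on a square grid in one call."""
    coords = np.linspace(-plot_range, plot_range, grid_size)
    X, Y = np.meshgrid(coords, coords)
    # Stack the grid into a (grid_size**2, 2) batch of points
    points = np.stack([X.ravel(), Y.ravel()], axis=1)
    # Reshape the (grid_size**2,) energies back onto the grid
    Z = energy_fn(points).reshape(X.shape)
    return X, Y, Z


X, Y, Z = evaluate_on_grid(double_well_energy)
print("lowest energy on the grid:", Z.min())
```

To plot a TorchEBM model instead, you could pass something like `lambda pts: model(torch.from_numpy(pts).float()).detach().numpy()`, assuming the model accepts an `(N, 2)` batch and returns one energy per row, as the `unsqueeze(0)` in the example above suggests.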
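Because an energy-based model defines an unnormalized density p(x) ∝ exp(-E(x)), the same grid of energies can also be shown as a 2D contour plot of the density:

```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchebm.core import DoubleWellModel

model = DoubleWellModel(barrier_height=2.0)

# Evaluation grid
grid_size = 100
plot_range = 3.0
x_coords = np.linspace(-plot_range, plot_range, grid_size)
y_coords = np.linspace(-plot_range, plot_range, grid_size)
X, Y = np.meshgrid(x_coords, y_coords)

# Energy at every grid point
Z = np.zeros_like(X)
for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        point = torch.tensor([X[i, j], Y[i, j]], dtype=torch.float32).unsqueeze(0)
        Z[i, j] = model(point).item()

# Unnormalized density exp(-E). Subtracting the maximum log-probability
# before exponentiating keeps exp() in a safe numeric range; it only
# rescales the plot.
log_prob_values = -Z
log_prob_values = log_prob_values - np.max(log_prob_values)
prob_density = np.exp(log_prob_values)

plt.figure(figsize=(10, 8))
contour = plt.contourf(X, Y, prob_density, levels=50, cmap='viridis')
plt.colorbar(label='exp(-Energy) (unnormalized density)')
plt.xlabel('X1')
plt.ylabel('X2')
plt.title('Double Well Probability Density')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
```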
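The contour plot of exp(-E) is only defined up to the partition function, the integral of exp(-E(x)) over x. In two dimensions that integral can be estimated with a Riemann sum over the grid cells, giving an approximately normalized density whose values can be compared across models. A minimal NumPy sketch, assuming a uniform grid built with `np.meshgrid`; `normalized_density_on_grid` is a hypothetical helper, not a TorchEBM function. The example adds 1000 to every energy: without the max shift, exp(-E) would underflow to zero everywhere, while with it the result is unchanged.

```python
import numpy as np


def normalized_density_on_grid(X, Y, energy):
    """Approximate p(x) = exp(-E(x)) / integral(exp(-E)) on a uniform grid.

    Hypothetical helper. X and Y come from np.meshgrid (default 'xy'
    indexing) and `energy` holds E at each grid point.
    """
    dx = X[0, 1] - X[0, 0]
    dy = Y[1, 0] - Y[0, 0]
    # Shift by the max log-probability so exp() neither overflows nor
    # underflows to all zeros; the shift cancels after normalization.
    log_unnorm = -energy
    log_unnorm = log_unnorm - log_unnorm.max()
    unnorm = np.exp(log_unnorm)
    # Riemann-sum estimate of the partition function over the grid cells
    return unnorm / (unnorm.sum() * dx * dy)


# Example: standard 2D Gaussian energy E(x) = |x|^2 / 2, shifted by a large constant
coords = np.linspace(-5, 5, 201)
GX, GY = np.meshgrid(coords, coords)
energy = 0.5 * (GX ** 2 + GY ** 2) + 1000.0
density = normalized_density_on_grid(GX, GY, energy)
print("density at the origin:", density[100, 100])  # close to 1 / (2 * pi)
```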
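To check that a sampler reaches the intended distribution, draw samples with `LangevinDynamics` from a `GaussianModel` whose density is known in closed form, and compare them with the analytic density:

```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from torchebm.core import GaussianModel
from torchebm.samplers import LangevinDynamics

# Target: a correlated 2D Gaussian
mean = torch.tensor([1.0, -1.0])
cov = torch.tensor([[1.0, 0.5], [0.5, 1.0]])
model = GaussianModel(mean=mean, cov=cov)

sampler = LangevinDynamics(model=model, step_size=0.01)

n_samples = 5000
burn_in = 200
x = torch.randn(n_samples, 2)  # initial states drawn from N(0, I)
samples = sampler.sample(x=x, n_steps=1000, burn_in=burn_in, return_trajectory=False)

# Detach and move to CPU first, in case the samples carry autograd state or live on a GPU
samples_np = samples.detach().cpu().numpy()
mean_np = mean.numpy()
cov_np = cov.numpy()

# Ground-truth density on a grid covering the target
x_grid = np.linspace(-3, 5, 100)
y_grid = np.linspace(-5, 3, 100)
X, Y = np.meshgrid(x_grid, y_grid)
pos = np.dstack((X, Y))
rv = stats.multivariate_normal(mean_np, cov_np)
Z = rv.pdf(pos)

fig = plt.figure(figsize=(15, 5))

# Left: analytic density
ax1 = fig.add_subplot(131)
ax1.contourf(X, Y, Z, 50, cmap='Blues')
ax1.set_title('Ground Truth Density')
ax1.set_xlabel('x')
ax1.set_ylabel('y')

# Middle: normalized 2D histogram of the samples
ax2 = fig.add_subplot(132)
h = ax2.hist2d(samples_np[:, 0], samples_np[:, 1], bins=50, cmap='Reds', density=True)
plt.colorbar(h[3], ax=ax2, label='Density')
ax2.set_title('Sampled Distribution')
ax2.set_xlabel('x')
ax2.set_ylabel('y')

# Right: raw samples, drawn on the same axis limits as the histogram
ax3 = fig.add_subplot(133)
ax3.scatter(samples_np[:, 0], samples_np[:, 1], alpha=0.5, s=3)
ax3.set_title('Sample Points')
ax3.set_xlabel('x')
ax3.set_ylabel('y')
ax3.set_xlim(ax2.get_xlim())
ax3.set_ylim(ax2.get_ylim())

plt.tight_layout()
plt.savefig('distribution_comparison_updated.png')
plt.show()
```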
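The side-by-side plots give a qualitative check. A quick quantitative complement is to compare the sample mean and covariance with the target's. The sketch below swaps in a plain NumPy implementation of the unadjusted Langevin update x <- x - eps * grad E(x) + sqrt(2 * eps) * noise for the Gaussian energy E(x) = 0.5 (x - mu)^T Sigma^{-1} (x - mu), so it runs without TorchEBM. It is not the implementation of `LangevinDynamics`, whose step and noise conventions may differ; `langevin_sample_gaussian` and `moment_errors` are hypothetical helpers.

```python
import numpy as np


def langevin_sample_gaussian(mean, cov, n_samples=5000, n_steps=1000,
                             step_size=0.01, seed=0):
    """Unadjusted Langevin dynamics for E(x) = 0.5 (x - mean)^T cov^{-1} (x - mean).

    Hypothetical NumPy helper, not TorchEBM's LangevinDynamics. Starts all
    chains from N(0, I) and returns only their final states.
    """
    rng = np.random.default_rng(seed)
    precision = np.linalg.inv(cov)
    x = rng.standard_normal((n_samples, mean.shape[0]))
    for _ in range(n_steps):
        # Row-wise gradient of E; valid because the precision matrix is symmetric
        grad = (x - mean) @ precision
        noise = rng.standard_normal(x.shape)
        x = x - step_size * grad + np.sqrt(2.0 * step_size) * noise
    return x


def moment_errors(samples, mean, cov):
    """Largest absolute deviation of the sample mean and covariance from the target."""
    emp_mean = samples.mean(axis=0)
    emp_cov = np.cov(samples, rowvar=False)
    return np.abs(emp_mean - mean).max(), np.abs(emp_cov - cov).max()


# Same target as the TorchEBM example: mean (1, -1), correlation 0.5
mean = np.array([1.0, -1.0])
cov = np.array([[1.0, 0.5], [0.5, 1.0]])
samples = langevin_sample_gaussian(mean, cov)
mean_err, cov_err = moment_errors(samples, mean, cov)
print(f"max |mean error| = {mean_err:.3f}, max |cov error| = {cov_err:.3f}")
```

`moment_errors` also accepts `samples_np`, `mean_np`, and `cov_np` from the TorchEBM example directly. With a finite step size the unadjusted update is slightly biased (here it inflates the variances by up to about one percent), so expect small but nonzero covariance errors even after the chains have converged.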
Effective visualization is key to understanding and debugging energy-based models. Energy landscape plots show where a model places its low-energy modes and the barriers between them, while comparing samples against a known density reveals whether a sampler is reaching the intended distribution. The same approach extends to sampling trajectories and training diagnostics, helping you improve both the model and the sampler.
Adapt these examples to your specific needs. For higher-dimensional models, you can project samples or inputs to two dimensions with a dimensionality reduction technique before plotting, or create specialized plots for your particular application.
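For models over more than two dimensions, one simple option is to project samples onto their top two principal components and plot the projection. A minimal NumPy sketch of PCA via the singular value decomposition; `pca_project_2d` is a hypothetical helper, and nonlinear methods such as t-SNE or UMAP are alternatives when a linear projection hides the structure.

```python
import numpy as np


def pca_project_2d(samples):
    """Project (N, D) samples onto their top two principal components.

    Hypothetical helper. Returns the (N, 2) projection and the fraction of
    total variance explained by each of the two components.
    """
    centered = samples - samples.mean(axis=0)
    # Rows of vt are the principal directions, sorted by decreasing singular value
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    projected = centered @ vt[:2].T
    explained = s ** 2 / np.sum(s ** 2)
    return projected, explained[:2]


# Example: 10-dimensional points that mostly vary within a 2D subspace
rng = np.random.default_rng(0)
latent = rng.standard_normal((1000, 2)) * np.array([3.0, 1.0])
data = latent @ rng.standard_normal((2, 10)) + 0.01 * rng.standard_normal((1000, 10))
projected, explained = pca_project_2d(data)
print("variance explained by the first two components:", explained.sum())
```

For sampler outputs, convert with `.detach().cpu().numpy()` first, then scatter `projected[:, 0]` against `projected[:, 1]`. If the explained-variance fraction is low, the 2D plot may hide important structure.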