import torch
import numpy as np
import matplotlib.pyplot as plt
from torchebm.core import DoubleWellEnergy

# Create the energy function
energy_fn = DoubleWellEnergy(barrier_height=2.0)

# Create a grid for visualization
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)

# Compute energy values for the whole grid in ONE batched call rather than
# 10,000 single-point calls. Rows of `points` follow X.ravel() order, so the
# result reshapes straight back onto the (100, 100) grid.
# NOTE(review): assumes energy_fn maps an (N, 2) batch to N energies, as the
# per-point `unsqueeze(0)` / `.item()` pattern this replaces implies.
points = torch.tensor(np.column_stack([X.ravel(), Y.ravel()]), dtype=torch.float32)
with torch.no_grad():  # plotting only: no autograd graph is needed
    Z = energy_fn(points).detach().cpu().numpy().reshape(X.shape)

# Create 3D surface plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")
surf = ax.plot_surface(X, Y, Z, cmap="viridis", alpha=0.8)
class MultimodalEnergy:
    """
    A 2D energy function with multiple local minima to demonstrate sampling behavior.

    The energy is the negative log of an unnormalized, weighted mixture of four
    isotropic unit-variance Gaussian bumps:

        E(x) = -log( sum_k w_k * exp(-0.5 * ||x - c_k||^2) )

    Args:
        device: Device holding the component parameters. Defaults to ``"cuda"``
            when available, otherwise ``"cpu"``.
        dtype: Floating-point dtype of the component parameters.
    """

    def __init__(self, device=None, dtype=torch.float32):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype
        # Define centers and weights for multiple Gaussian components
        self.centers = torch.tensor(
            [[-1.0, -1.0], [1.0, 1.0], [-0.5, 1.0], [1.0, -0.5]],
            device=self.device,
            dtype=self.dtype,
        )
        self.weights = torch.tensor(
            [1.0, 0.8, 0.6, 0.7], device=self.device, dtype=self.dtype
        )

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return the energy of each point in ``x``.

        Args:
            x: Points of shape ``(..., 2)``; a single point of shape ``(2,)``
                is accepted too. Must live on the same device as the centers.

        Returns:
            Energies of shape ``x.shape[:-1]``.
        """
        # Squared distance to every center, shape (..., K). Computed directly
        # instead of cdist(...)**2: no sqrt/square round trip, and no reliance
        # on cdist's matmul-based fast path, which trades away precision.
        sq_dists = (x.unsqueeze(-2) - self.centers).pow(2).sum(dim=-1)
        # log(sum_k w_k * exp(-d_k^2 / 2)) via logsumexp. The naive form
        # underflows to log(0) = -inf far from every center, which made the
        # energy +inf and its gradients NaN.
        return -torch.logsumexp(torch.log(self.weights) - 0.5 * sq_dists, dim=-1)
# Create a figure with multiple subplots
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.flatten()

# Flatten the X/Y meshgrid into an (N, 2) batch once; rows follow X.ravel()
# order so each result reshapes straight back onto the grid.
grid_points = torch.tensor(np.column_stack([X.ravel(), Y.ravel()]), dtype=torch.float32)

# Calculate energy landscapes for different barrier heights
barrier_heights = [0.5, 1.0, 2.0, 4.0]
for i, barrier_height in enumerate(barrier_heights):
    # Create energy function with the specified barrier height
    energy_fn = DoubleWellEnergy(barrier_height=barrier_height)

    # Compute energy values for THIS energy function. Without this step every
    # panel would plot the same stale Z.
    # NOTE(review): assumes energy_fn maps an (N, 2) batch to N energies.
    with torch.no_grad():
        Z = energy_fn(grid_points).detach().cpu().numpy().reshape(X.shape)

    # Create contour plot
    contour = axes[i].contourf(X, Y, Z, 50, cmap="viridis")
    fig.colorbar(contour, ax=axes[i], label="Energy")
    axes[i].set_title(f"Double Well Energy (Barrier Height = {barrier_height})")
# Print the names of all available energy function examples
python examples/main.py --list

# Run individual examples by name
python examples/main.py energy_functions/energy_functions/landscape_2d
python examples/main.py energy_functions/energy_functions/multimodal
python examples/main.py energy_functions/energy_functions/parametric