Skip to content

Energy Function Examples

This section demonstrates the various energy functions available in TorchEBM and how to visualize them.

Basic Energy Landscapes

The landscape_2d.py example shows how to create and visualize basic energy functions:

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchebm.core import DoubleWellEnergy

# Create the energy function
energy_fn = DoubleWellEnergy(barrier_height=2.0)

# Create a grid for visualization
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)
Z = np.zeros_like(X)

# Compute energy values
for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        point = torch.tensor([X[i, j], Y[i, j]], dtype=torch.float32).unsqueeze(0)
        Z[i, j] = energy_fn(point).item()

# Create 3D surface plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")
surf = ax.plot_surface(X, Y, Z, cmap="viridis", alpha=0.8)

Multimodal Energy Functions

The multimodal.py example demonstrates more complex energy functions with multiple local minima:

class MultimodalEnergy:
    """
    A 2D energy function with multiple local minima to demonstrate sampling behavior.
    """
    def __init__(self, device=None, dtype=torch.float32):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype

        # Define centers and weights for multiple Gaussian components
        self.centers = torch.tensor(
            [[-1.0, -1.0], [1.0, 1.0], [-0.5, 1.0], [1.0, -0.5]],
            device=self.device,
            dtype=self.dtype,
        )

        self.weights = torch.tensor(
            [1.0, 0.8, 0.6, 0.7], device=self.device, dtype=self.dtype
        )

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Calculate energy as negative log of mixture of Gaussians
        dists = torch.cdist(x, self.centers)
        energy = -torch.log(
            torch.sum(self.weights * torch.exp(-0.5 * dists.pow(2)), dim=-1)
        )
        return energy

Parametric Energy Functions

The parametric.py example shows how to create energy functions with adjustable parameters:

# Create a figure with multiple subplots
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.flatten()

# Calculate energy landscapes for different barrier heights
barrier_heights = [0.5, 1.0, 2.0, 4.0]

for i, barrier_height in enumerate(barrier_heights):
    # Create energy function with the specified barrier height
    energy_fn = DoubleWellEnergy(barrier_height=barrier_height)

    # Compute energy values
    # ...

    # Create contour plot
    contour = axes[i].contourf(X, Y, Z, 50, cmap="viridis")
    fig.colorbar(contour, ax=axes[i], label="Energy")
    axes[i].set_title(f"Double Well Energy (Barrier Height = {barrier_height})")

Running the Examples

To run these examples:

# List available energy function examples
python examples/main.py --list

# Run a specific example
python examples/main.py core/energy_functions/landscape_2d
python examples/main.py core/energy_functions/multimodal
python examples/main.py core/energy_functions/parametric

Additional Resources

For more information on energy functions, see: