# Visualization Tools
This section demonstrates the visualization tools and techniques available in TorchEBM for inspecting energy functions and sampling processes.
## Basic Visualizations
### Contour Plots
The `contour_plots.py` example demonstrates basic contour plots of energy functions:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchebm.core import DoubleWellEnergy

# Create the energy function
energy_fn = DoubleWellEnergy(barrier_height=2.0)

# Create a grid for visualization
x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)

# Compute energy values for every grid point in ONE batched call instead of
# one call per point (10,000 calls for a 100x100 grid). Each row of
# `grid_points` is one (x, y) point, i.e. the same (batch, 2) layout a
# per-point loop would build with `.unsqueeze(0)`.
# NOTE(review): assumes energy_fn returns one energy per input row.
grid_points = torch.tensor(
    np.column_stack([X.ravel(), Y.ravel()]), dtype=torch.float32
)
with torch.no_grad():  # values are only plotted, so skip autograd bookkeeping
    Z = energy_fn(grid_points).reshape(X.shape).cpu().numpy()

# Create contour plot
plt.figure(figsize=(10, 8))
plt.contourf(X, Y, Z, 50, cmap="viridis")
plt.colorbar(label="Energy")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Double Well Energy Landscape")
plt.show()
### Distribution Comparison
The `distribution_comparison.py` example compares sampled distributions with their ground truth:
# Create figure with multiple plots
fig = plt.figure(figsize=(15, 5))

# Ground truth contour.
# NOTE: this panel and its colorbar are labeled "Density", so Z here should
# hold density values (e.g. np.exp(-energy)), not raw energy values.
ax1 = fig.add_subplot(131)
contour = ax1.contourf(X, Y, Z, 50, cmap="Blues")
fig.colorbar(contour, ax=ax1, label="Density")
ax1.set_title("Ground Truth Density")

# Empirical sample density as a normalized 2D histogram (density=True).
# samples_np is indexed as [sample, coordinate] -- presumably an (n, 2) array.
ax2 = fig.add_subplot(132)
h = ax2.hist2d(samples_np[:, 0], samples_np[:, 1], bins=50, cmap="Reds", density=True)
fig.colorbar(h[3], ax=ax2, label="Density")  # h[3] is the histogram image
ax2.set_title("Sampled Distribution")

# Scatter plot of samples
ax3 = fig.add_subplot(133)
ax3.scatter(samples_np[:, 0], samples_np[:, 1], alpha=0.5, s=3)
ax3.set_title("Sample Points")

fig.tight_layout()
plt.show()
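## Advanced Visualizations
### Trajectory Animation
The `trajectory_animation.py` example visualizes a sampling trajectory on an energy landscape as a static plot. Points are colored by sampling step, and arrows drawn every 50 steps show the direction of motion: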
# Extract trajectory coordinates of the first chain. `trajectory` is indexed
# as [chain, step, coordinate]. detach().cpu() lets .numpy() work even when the
# tensor lives on a GPU or is still attached to the autograd graph.
traj = trajectory[0].detach().cpu().numpy()
traj_x = traj[:, 0]
traj_y = traj[:, 1]

# Plot trajectory with colormap based on step number
plt.figure(figsize=(10, 8))
plt.contourf(X, Y, Z, 50, cmap="viridis", alpha=0.7)  # Energy landscape
points = plt.scatter(
    traj_x, traj_y, c=np.arange(len(traj_x)), cmap="plasma", s=5, alpha=0.7
)
plt.colorbar(points, label="Sampling Step")

# Plot arrows to show direction of trajectory. Each arrow starts at every
# `step`-th sample and points at the sample right after it; the `[:-1:step]`
# and `[1::step]` slices always have the same length.
step = 50  # Plot an arrow every 50 steps
plt.quiver(
    traj_x[:-1:step],
    traj_y[:-1:step],
    traj_x[1::step] - traj_x[:-1:step],
    traj_y[1::step] - traj_y[:-1:step],
    scale_units="xy",
    angles="xy",
    scale=1,
    color="red",
    alpha=0.7,
)
plt.show()
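### Parallel Chains
The `parallel_chains.py` example visualizes multiple sampling chains on the same energy landscape. Each chain's start point is marked with a black circle and its end point with a star: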
# Plot contour
plt.figure(figsize=(12, 10))
plt.contourf(X, Y, Z, 50, cmap="viridis", alpha=0.7)
plt.colorbar(label="Energy")

# Plot each trajectory with its own color. The palette is cycled with a
# modulo so more than len(colors) chains no longer raises IndexError.
colors = ["red", "blue", "green", "orange", "purple"]
for i in range(num_chains):
    color = colors[i % len(colors)]
    # `trajectories` is indexed as [chain, step, coordinate]; detach().cpu()
    # lets .numpy() work for GPU or autograd-tracked tensors.
    traj_x = trajectories[i, :, 0].detach().cpu().numpy()
    traj_y = trajectories[i, :, 1].detach().cpu().numpy()
    plt.plot(traj_x, traj_y, alpha=0.7, linewidth=1, c=color, label=f"Chain {i+1}")
    # Mark start (black circle) and end (star in the chain's color) points
    plt.scatter(traj_x[0], traj_y[0], c="black", s=50, marker="o")
    plt.scatter(traj_x[-1], traj_y[-1], c=color, s=100, marker="*")

# Each chain's line carries a label; a legend is needed to actually show them
plt.legend()
plt.show()
### Energy Over Time
The `energy_over_time.py` example records the energy of the sample after every sampling step and plots how it evolves:
# Track the trajectory and energy manually
trajectory = torch.zeros((1, n_steps, dim))
energy_values = torch.zeros(n_steps)
current_sample = initial_point.clone()

# Run the sampling steps and store each position and energy.
# The leading dimension of 1 in `trajectory` and the `.item()` call below
# both assume a single-sample batch.
for i in range(n_steps):
    current_sample = sampler.langevin_step(
        current_sample, torch.randn_like(current_sample)
    )
    # Slice assignment copies the data, so no extra clone() is needed;
    # detach() keeps `trajectory` out of the autograd graph.
    trajectory[:, i, :] = current_sample.detach()
    # The energy is only logged, so evaluate it without building a graph
    with torch.no_grad():
        energy_values[i] = energy_fn(current_sample).item()

# Plot energy evolution
plt.figure(figsize=(10, 6))
plt.plot(energy_values.numpy())
plt.xlabel("Step")
plt.ylabel("Energy")
plt.title("Energy Evolution During Sampling")
plt.grid(True, alpha=0.3)
plt.show()
## Common Visualization Utilities
The `utils.py` file provides common visualization functions that can be reused across examples:
def plot_2d_energy_landscape(
    energy_fn,
    title=None,
    x_range=(-3, 3),
    y_range=(-3, 3),
    resolution=100,
    device="cpu",
    save_path=None,
):
    """
    Draw a two-dimensional energy landscape as a contour plot.

    Args:
        energy_fn: Energy function whose landscape is drawn.
        title: Optional plot title.
        x_range: ``(min, max)`` limits of the x-axis.
        y_range: ``(min, max)`` limits of the y-axis.
        resolution: How many grid points to sample along each axis.
        device: Device used for the tensor computations.
        save_path: Optional file path the figure is saved to.

    Returns:
        The figure object.
    """
    # Implementation omitted here for brevity; see utils.py.
def plot_sample_trajectories(
    trajectories,
    energy_fn=None,
    title=None,
    device="cpu",
    save_path=None,
):
    """
    Draw sampling trajectories, optionally over an energy landscape.

    Args:
        trajectories: Tensor shaped ``(n_samples, n_steps, dim)``.
        energy_fn: Optional energy function drawn as the background.
        title: Optional plot title.
        device: Device used for the tensor computations.
        save_path: Optional file path the figure is saved to.

    Returns:
        The figure object.
    """
    # Implementation omitted here for brevity; see utils.py.
## Running the Examples
To run these examples, use the following commands from the directory that contains `examples/`:
# List the available examples
python examples/main.py --list

# Basic visualization examples, run one after another
for example in contour_plots distribution_comparison; do
    python examples/main.py "visualization/basic/${example}"
done

# Advanced visualization examples, run one after another
for example in trajectory_animation parallel_chains energy_over_time; do
    python examples/main.py "visualization/advanced/${example}"
done
## Additional Resources
For more information on visualization tools, see: