Bases: EnergyFunction
Energy function for a Gaussian distribution. E(x) = 0.5 * (x-μ)ᵀ Σ⁻¹ (x-μ).
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mean` | `Tensor` | Mean vector (μ) of the Gaussian distribution. | *required* |
| `cov` | `Tensor` | Covariance matrix (Σ) of the Gaussian distribution. | *required* |
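A minimal usage sketch (assuming `GaussianEnergy` is importable from `torchebm.core.energy_function`, matching the source path below; the mean and covariance values are illustrative):

```python
import torch

from torchebm.core.energy_function import GaussianEnergy

# Illustrative 2-D Gaussian: unit variances, mild positive correlation.
mean = torch.tensor([0.0, 0.0])
cov = torch.tensor([[1.0, 0.3],
                    [0.3, 1.0]])
energy_fn = GaussianEnergy(mean, cov)

x = torch.randn(5, 2)   # batch of 5 two-dimensional samples
energy = energy_fn(x)   # shape: (5,); lower energy means higher density
```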
Source code in torchebm/core/energy_function.py
```python
class GaussianEnergy(EnergyFunction):
    """
    Energy function for a Gaussian distribution. E(x) = 0.5 * (x-μ)ᵀ Σ⁻¹ (x-μ).

    Args:
        mean (torch.Tensor): Mean vector (μ) of the Gaussian distribution.
        cov (torch.Tensor): Covariance matrix (Σ) of the Gaussian distribution.
    """

    def __init__(self, mean: torch.Tensor, cov: torch.Tensor):
        super().__init__()
        if mean.ndim != 1:
            raise ValueError("Mean must be a 1D tensor.")
        if cov.ndim != 2 or cov.shape[0] != cov.shape[1]:
            raise ValueError("Covariance must be a 2D square matrix.")
        if mean.shape[0] != cov.shape[0]:
            raise ValueError(
                "Mean vector dimension must match covariance matrix dimension."
            )

        # Register the mean and the covariance inverse as buffers.
        # Buffers are part of the module's state (`state_dict`) and are moved
        # by `.to()`, but are not treated as parameters by optimizers.
        self.register_buffer("mean", mean)
        try:
            cov_inv = torch.inverse(cov)
            self.register_buffer("cov_inv", cov_inv)
        except RuntimeError as e:
            raise ValueError(
                f"Failed to invert covariance matrix: {e}. Ensure it is invertible."
            ) from e

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes the Gaussian energy: 0.5 * (x-μ)ᵀ Σ⁻¹ (x-μ)."""
        # Ensure x has shape (batch_size, dim).
        if x.ndim == 1:  # Handle the single-sample case.
            x = x.unsqueeze(0)
        if x.ndim != 2 or x.shape[1] != self.mean.shape[0]:
            raise ValueError(
                f"Input x expected shape (batch_size, {self.mean.shape[0]}), but got {x.shape}"
            )

        # Move mean and cov_inv to the same device as x. The dtype is left
        # unchanged because gradient() has already converted x to float32.
        mean = self.mean.to(device=x.device)
        cov_inv = self.cov_inv.to(device=x.device)

        # Center the inputs. Use x directly, without detaching or casting,
        # so gradient tracking is preserved.
        delta = x - mean

        # Evaluate the quadratic form for each batch element with einsum;
        # the contraction is differentiable and handles batching in one call.
        energy = 0.5 * torch.einsum("bi,ij,bj->b", delta, cov_inv, delta)
        return energy
```
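Because `forward` preserves gradient tracking, the score can be obtained via autograd. A sketch (reusing the illustrative `energy_fn` from the example above) that checks the autograd gradient against the closed form ∇E(x) = Σ⁻¹ (x − μ):

```python
import torch

x = torch.randn(4, 2, requires_grad=True)
energy = energy_fn(x)                          # shape: (4,)
(grad,) = torch.autograd.grad(energy.sum(), x)

# With a symmetric Σ⁻¹, the analytic gradient is Σ⁻¹ (x - μ),
# equivalently (x - μ) @ Σ⁻¹ applied row-wise.
expected = (x - energy_fn.mean) @ energy_fn.cov_inv
assert torch.allclose(grad, expected, atol=1e-5)
```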
forward
forward(x: Tensor) -> Tensor
Computes the Gaussian energy: 0.5 * (x-μ)ᵀ Σ⁻¹ (x-μ).
Source code in torchebm/core/energy_function.py
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Computes the Gaussian energy: 0.5 * (x-μ)ᵀ Σ⁻¹ (x-μ)."""
    # Ensure x has shape (batch_size, dim).
    if x.ndim == 1:  # Handle the single-sample case.
        x = x.unsqueeze(0)
    if x.ndim != 2 or x.shape[1] != self.mean.shape[0]:
        raise ValueError(
            f"Input x expected shape (batch_size, {self.mean.shape[0]}), but got {x.shape}"
        )

    # Move mean and cov_inv to the same device as x. The dtype is left
    # unchanged because gradient() has already converted x to float32.
    mean = self.mean.to(device=x.device)
    cov_inv = self.cov_inv.to(device=x.device)

    # Center the inputs. Use x directly, without detaching or casting,
    # so gradient tracking is preserved.
    delta = x - mean

    # Evaluate the quadratic form for each batch element with einsum;
    # the contraction is differentiable and handles batching in one call.
    energy = 0.5 * torch.einsum("bi,ij,bj->b", delta, cov_inv, delta)
    return energy
```
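The contraction `"bi,ij,bj->b"` evaluates one quadratic form per batch row. A small sketch of the equivalence with an explicit matrix product (tensor names and shapes are illustrative):

```python
import torch

delta = torch.randn(8, 3)
cov_inv = 2.0 * torch.eye(3)

# Form used in forward(): per-row quadratic form 0.5 · deltaᵀ Σ⁻¹ delta.
e1 = 0.5 * torch.einsum("bi,ij,bj->b", delta, cov_inv, delta)

# Equivalent explicit version: matrix-multiply, then a row-wise dot product.
e2 = 0.5 * ((delta @ cov_inv) * delta).sum(dim=-1)

assert torch.allclose(e1, e2)
```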