GaussianEnergy

Methods and Attributes

Bases: EnergyFunction

Energy function for a Gaussian distribution. E(x) = 0.5 * (x-μ)ᵀ Σ⁻¹ (x-μ).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `mean` | `Tensor` | Mean vector (μ) of the Gaussian distribution. | required |
| `cov` | `Tensor` | Covariance matrix (Σ) of the Gaussian distribution. | required |
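
The constructor arguments map directly onto the formula above: a 1D mean vector and a square, invertible covariance matrix of matching dimension. The following is a minimal usage sketch; the import path is assumed from the source file noted below, and `forward` is called explicitly so the example does not rely on any base-class call convention.

```python
import torch

# Import path assumed from "Source code in torchebm/core/energy_function.py" below.
from torchebm.core.energy_function import GaussianEnergy

# 2D Gaussian: 1D mean vector and a square, invertible covariance matrix.
mean = torch.tensor([0.0, 1.0])
cov = torch.tensor([[2.0, 0.5],
                    [0.5, 1.0]])

energy_fn = GaussianEnergy(mean, cov)

# forward() accepts a batch of shape (batch_size, dim); a single 1D sample is also accepted.
x = torch.randn(4, 2)
energy = energy_fn.forward(x)
print(energy.shape)  # torch.Size([4])
```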
Source code in torchebm/core/energy_function.py
class GaussianEnergy(EnergyFunction):
    """
    Energy function for a Gaussian distribution. E(x) = 0.5 * (x-μ)ᵀ Σ⁻¹ (x-μ).

    Args:
        mean (torch.Tensor): Mean vector (μ) of the Gaussian distribution.
        cov (torch.Tensor): Covariance matrix (Σ) of the Gaussian distribution.
    """

    def __init__(self, mean: torch.Tensor, cov: torch.Tensor):
        super().__init__()
        if mean.ndim != 1:
            raise ValueError("Mean must be a 1D tensor.")
        if cov.ndim != 2 or cov.shape[0] != cov.shape[1]:
            raise ValueError("Covariance must be a 2D square matrix.")
        if mean.shape[0] != cov.shape[0]:
            raise ValueError(
                "Mean vector dimension must match covariance matrix dimension."
            )

        # Register mean and covariance inverse as buffers.
        # Buffers are part of the module's state (`state_dict`) and moved by `.to()`,
        # but are not considered parameters by optimizers.
        self.register_buffer("mean", mean)
        try:
            cov_inv = torch.inverse(cov)
            self.register_buffer("cov_inv", cov_inv)
        except RuntimeError as e:
            raise ValueError(
                f"Failed to invert covariance matrix: {e}. Ensure it is invertible."
            ) from e

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes the Gaussian energy: 0.5 * (x-μ)ᵀ Σ⁻¹ (x-μ)."""
        # Ensure x is compatible shape (batch_size, dim)
        if x.ndim == 1:  # Handle single sample case
            x = x.unsqueeze(0)
        if x.ndim != 2 or x.shape[1] != self.mean.shape[0]:
            raise ValueError(
                f"Input x expected shape (batch_size, {self.mean.shape[0]}), but got {x.shape}"
            )

        # Get mean and cov_inv on the same device as x
        # We don't change the dtype because gradient() already converted x to float32
        mean = self.mean.to(device=x.device)
        cov_inv = self.cov_inv.to(device=x.device)

        # Compute centered vectors
        # Important: use x directly without detaching or converting to maintain grad tracking
        delta = x - mean

        # Calculate energy
        # Use batch matrix multiplication for better numerical stability
        # We use einsum which maintains gradients through operations
        energy = 0.5 * torch.einsum("bi,ij,bj->b", delta, cov_inv, delta)

        return energy
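
The `einsum` expression `"bi,ij,bj->b"` in `forward` evaluates the batched quadratic form δᵀ Σ⁻¹ δ, producing one scalar per row of the batch. As a standalone sanity check (a sketch, not part of the library), it matches the explicit matrix-product formulation:

```python
import torch

batch_size, dim = 8, 3
delta = torch.randn(batch_size, dim)

# Symmetric positive-definite stand-in for the precision matrix Σ⁻¹.
a = torch.randn(dim, dim)
cov_inv = a @ a.T + dim * torch.eye(dim)

# Form used in forward(): one quadratic-form value per batch element.
quad_einsum = torch.einsum("bi,ij,bj->b", delta, cov_inv, delta)

# Equivalent explicit form: (δ Σ⁻¹) ⊙ δ, summed over the feature dimension.
quad_matmul = ((delta @ cov_inv) * delta).sum(dim=-1)

assert torch.allclose(quad_einsum, quad_matmul, atol=1e-5)
```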

forward

forward(x: Tensor) -> torch.Tensor

Computes the Gaussian energy: 0.5 * (x-μ)ᵀ Σ⁻¹ (x-μ).

Source code in torchebm/core/energy_function.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Computes the Gaussian energy: 0.5 * (x-μ)ᵀ Σ⁻¹ (x-μ)."""
    # Ensure x is compatible shape (batch_size, dim)
    if x.ndim == 1:  # Handle single sample case
        x = x.unsqueeze(0)
    if x.ndim != 2 or x.shape[1] != self.mean.shape[0]:
        raise ValueError(
            f"Input x expected shape (batch_size, {self.mean.shape[0]}), but got {x.shape}"
        )

    # Get mean and cov_inv on the same device as x
    # We don't change the dtype because gradient() already converted x to float32
    mean = self.mean.to(device=x.device)
    cov_inv = self.cov_inv.to(device=x.device)

    # Compute centered vectors
    # Important: use x directly without detaching or converting to maintain grad tracking
    delta = x - mean

    # Calculate energy
    # Use batch matrix multiplication for better numerical stability
    # We use einsum which maintains gradients through operations
    energy = 0.5 * torch.einsum("bi,ij,bj->b", delta, cov_inv, delta)

    return energy
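
Because `forward` keeps the computation inside autograd (the input `x` is never detached), the energy gradient ∇ₓE(x) = Σ⁻¹ (x − μ) can be recovered with plain `torch.autograd.grad`. The sketch below does that for an identity covariance, where the gradient reduces to x − μ; it does not use the `gradient()` helper mentioned in the source comments, whose exact signature is not shown here.

```python
import torch
from torchebm.core.energy_function import GaussianEnergy  # import path assumed

mean = torch.zeros(2)
cov = torch.eye(2)
energy_fn = GaussianEnergy(mean, cov)

x = torch.randn(5, 2, requires_grad=True)
energy = energy_fn.forward(x)  # shape: (5,)

# Differentiate the summed energy; per-sample gradients stay independent.
(grad_x,) = torch.autograd.grad(energy.sum(), x)

# With Σ = I and μ = 0, the gradient ∇E(x) = Σ⁻¹ (x − μ) is just x.
assert torch.allclose(grad_x, x.detach() - mean, atol=1e-5)
```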