Skip to content

EnergySampleMetric

Methods and Attributes

Bases: BaseMetric

Base class for metrics that evaluate energy functions and samples.

This class extends `BaseMetric` for metrics that work specifically with energy-based models, evaluating both the energy function and the quality of the samples it produces.

Source code in `torchebm/core/base_metric.py`
class EnergySampleMetric(BaseMetric):
    """
    Abstract base for metrics that score an energy function together with samples.

    This specialization of ``BaseMetric`` targets energy-based models: a
    concrete metric receives both the energy function and a batch of samples,
    so it can assess the energy function as well as the quality of the samples.
    """

    def __init__(self, name: str, lower_is_better: bool = True):
        """
        Configure the metric by forwarding its settings to ``BaseMetric``.

        Args:
            name (str): Identifier of the metric.
            lower_is_better (bool): ``True`` when smaller metric values mean
                better performance. Defaults to ``True``.
        """
        super().__init__(lower_is_better=lower_is_better, name=name)

    @abstractmethod
    def __call__(
        self, energy_fn: torch.nn.Module, samples: torch.Tensor, *args, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Evaluate the metric on an energy function and a set of samples.

        Concrete subclasses must override this method; the base version
        does nothing.

        Args:
            energy_fn: The energy function under evaluation.
            samples: Tensor of samples with shape ``(n_samples, *dim)``.
            *args: Extra positional arguments for concrete implementations.
            **kwargs: Extra keyword arguments for concrete implementations.

        Returns:
            Dict[str, torch.Tensor]: The computed metric values; the keys are
            chosen by the concrete subclass.
        """
        return None