Bases: BaseMetric
Base class for metrics that evaluate energy functions and samples.
This class extends BaseMetric for metrics that specifically work with
energy-based models by evaluating energy functions and sample quality.
Source code in torchebm/core/base_metric.py
class EnergySampleMetric(BaseMetric):
    """
    Abstract base for metrics that score an energy function together with samples.

    Subclasses of this ``BaseMetric`` specialization target energy-based models:
    they receive both the energy function and a batch of samples, and must
    implement ``__call__`` to produce their metric values.
    """

    def __init__(self, name: str, lower_is_better: bool = True):
        """
        Set up the metric by forwarding its configuration to ``BaseMetric``.

        Args:
            name (str): Identifier of the metric.
            lower_is_better (bool): True when smaller metric values mean
                better performance (the default).
        """
        super().__init__(name=name, lower_is_better=lower_is_better)

    @abstractmethod
    def __call__(
        self, energy_fn: torch.nn.Module, samples: torch.Tensor, *args, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Evaluate the metric for an energy function and a set of samples.

        Concrete subclasses must override this method.

        Args:
            energy_fn: Energy function (a ``torch.nn.Module``) under evaluation.
            samples: Sample tensor, expected to be shaped ``(n_samples, *dim)``.

        Returns:
            Dict[str, torch.Tensor]: Mapping from string keys to the computed
            metric values.
        """
        ...