Bases: BaseMetric
Base class for metrics that evaluate the quality of samples.
This class extends BaseMetric for metrics that specifically evaluate
sample quality by comparing the samples against reference data.
Source code in torchebm/core/base_metric.py
class SampleQualityMetric(BaseMetric):
    """Base class for metrics that score samples against reference data.

    Concrete subclasses must implement ``__call__``. It receives a batch of
    samples together with a batch of reference data and returns a dictionary
    of computed metric values.
    """

    def __init__(self, name: str, lower_is_better: bool = True):
        """Set up the sample quality metric.

        Args:
            name (str): Identifier of the metric.
            lower_is_better (bool): ``True`` when smaller metric values
                indicate better performance. Defaults to ``True``.
        """
        # Forward both settings unchanged to BaseMetric.
        super(SampleQualityMetric, self).__init__(
            name=name,
            lower_is_better=lower_is_better,
        )

    @abstractmethod
    def __call__(
        self, samples: torch.Tensor, reference: torch.Tensor, *args, **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Evaluate ``samples`` against ``reference``.

        Args:
            samples: Samples to evaluate, shape ``(n_samples, *dim)``.
            reference: Reference data, shape ``(n_reference, *dim)``.
            *args: Extra positional arguments for the concrete metric.
            **kwargs: Extra keyword arguments for the concrete metric.

        Returns:
            Dict[str, torch.Tensor]: Mapping from result name to the
            computed metric value.
        """
        ...