Skip to content

SampleQualityMetric

Methods and Attributes

Bases: BaseMetric

Base class for metrics that evaluate the quality of samples.

This class extends BaseMetric for metrics that specifically evaluate sample quality by comparing the samples against reference data.

Source code in torchebm/core/base_metric.py
class SampleQualityMetric(BaseMetric):
    """
    Abstract base for metrics that score samples against reference data.

    Concrete subclasses implement ``__call__``, which receives a batch of
    samples together with a batch of reference data and returns the
    computed metric values as a dictionary of tensors.
    """

    def __init__(self, name: str, lower_is_better: bool = True):
        """
        Configure the metric by forwarding its settings to BaseMetric.

        Args:
            name (str): Identifier of the metric.
            lower_is_better (bool): Whether smaller metric values indicate
                better performance. Defaults to True.
        """
        super(SampleQualityMetric, self).__init__(
            name=name,
            lower_is_better=lower_is_better,
        )

    @abstractmethod
    def __call__(
        self,
        samples: torch.Tensor,
        reference: torch.Tensor,
        *args,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Evaluate ``samples`` against ``reference`` (to be implemented by subclasses).

        Args:
            samples: Samples to evaluate, shape (n_samples, *dim).
            reference: Reference data, shape (n_reference, *dim).
            *args: Additional positional arguments for concrete implementations.
            **kwargs: Additional keyword arguments for concrete implementations.

        Returns:
            Dict[str, torch.Tensor]: Dictionary of the computed metric values.
        """