Skip to content

BaseMetric

Methods and Attributes

Bases: ABC

Abstract base class for all evaluation metrics in torchebm.

This class defines the standard interface for implementing evaluation metrics for energy-based models. All metrics should inherit from this class and implement the required methods.

Source code in torchebm/core/base_metric.py
class BaseMetric(ABC):
    """
    Abstract base class for all evaluation metrics in torchebm.

    This class defines the standard interface for implementing evaluation metrics
    for energy-based models. All metrics should inherit from this class and implement
    the required methods.

    Attributes:
        name (str): The name of the metric
        lower_is_better (bool): Whether lower values indicate better performance
    """

    def __init__(self, name: str, lower_is_better: bool = True):
        """
        Initialize the base metric.

        Args:
            name (str): Name of the metric
            lower_is_better (bool): Whether lower values indicate better performance
        """
        self.name = name
        self.lower_is_better = lower_is_better
        self._device = None

    @property
    def device(self) -> Optional[torch.device]:
        """Returns the device associated with the metric."""
        return self._device

    def to(self, device: Union[str, torch.device]) -> "BaseMetric":
        """
        Move the metric to the specified device.

        Args:
            device: The device to move the metric to

        Returns:
            self: The metric instance moved to the device
        """
        if isinstance(device, str):
            device = torch.device(device)
        self._device = device
        return self

    @abstractmethod
    def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Compute the metric value.

        Returns:
            Dict[str, torch.Tensor]: Dictionary containing the computed metric values
        """
        pass

    def compute(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Alias for __call__ to match sklearn and other libraries' conventions.

        Returns:
            Dict[str, torch.Tensor]: Dictionary containing the computed metric values
        """
        return self(*args, **kwargs)

name instance-attribute

name = name

lower_is_better instance-attribute

lower_is_better = lower_is_better

_device instance-attribute

_device = None

device property

device: Optional[device]

Returns the device associated with the metric.

to

to(device: Union[str, device]) -> BaseMetric

Move the metric to the specified device.

Parameters:

Name Type Description Default
device Union[str, device]

The device to move the metric to

required

Returns:

Name Type Description
self BaseMetric

The metric instance moved to the device

Source code in torchebm/core/base_metric.py
def to(self, device: Union[str, torch.device]) -> "BaseMetric":
    """
    Move the metric to the specified device.

    Args:
        device: The device to move the metric to

    Returns:
        self: The metric instance moved to the device
    """
    if isinstance(device, str):
        device = torch.device(device)
    self._device = device
    return self

compute

compute(*args, **kwargs) -> Dict[str, torch.Tensor]

Alias for call to match sklearn and other libraries' conventions.

Returns:

Type Description
Dict[str, Tensor]

Dict[str, torch.Tensor]: Dictionary containing the computed metric values

Source code in torchebm/core/base_metric.py
def compute(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
    """
    Alias for __call__ to match sklearn and other libraries' conventions.

    Returns:
        Dict[str, torch.Tensor]: Dictionary containing the computed metric values
    """
    return self(*args, **kwargs)