
BaseLoss


Bases: DeviceMixin, Module, ABC

Abstract base class for loss functions used in energy-based models.

Parameters:

- dtype (torch.dtype): Data type for computations. Default: float32
- device (Optional[Union[str, torch.device]]): Device for computations. Default: None
- use_mixed_precision (bool): Whether to use mixed precision training. Default: False
- clip_value (Optional[float]): Optional value to clamp the loss. Default: None
Source code in torchebm/core/base_loss.py
class BaseLoss(DeviceMixin, nn.Module, ABC):
    """
    Abstract base class for loss functions used in energy-based models.

    Args:
        dtype (torch.dtype): Data type for computations.
        device (Optional[Union[str, torch.device]]): Device for computations.
        use_mixed_precision (bool): Whether to use mixed precision training.
        clip_value (Optional[float]): Optional value to clamp the loss.
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
        use_mixed_precision: bool = False,
        clip_value: Optional[float] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the base loss class."""
        super().__init__(device=device, *args, **kwargs)

        # if isinstance(device, str):
        #     device = torch.device(device)
        self.dtype = dtype
        self.clip_value = clip_value
        self.setup_mixed_precision(use_mixed_precision)


    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Computes the loss value.

        Args:
            x (torch.Tensor): Input data tensor from the target distribution.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The computed scalar loss value.
        """
        pass

    def __repr__(self):
        """Return a string representation of the loss function."""
        return f"{self.__class__.__name__}()"

    def __str__(self):
        """Return a string representation of the loss function."""
        return self.__repr__()

    def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Calls the forward method of the loss function.

        Args:
            x (torch.Tensor): Input data tensor.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The computed loss value.
        """
        x = x.to(device=self.device, dtype=self.dtype)

        with self.autocast_context():
            loss = self.forward(x, *args, **kwargs)

        if self.clip_value:
            loss = torch.clamp(loss, -self.clip_value, self.clip_value)
        return loss
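
As a brief usage sketch (not part of the library source: the MeanSquareLoss subclass, its quadratic formula, and the import path torchebm.core.base_loss are assumptions for illustration), a concrete subclass only needs to implement forward; the __call__ wrapper then takes care of device/dtype casting, the autocast context, and optional clipping:

from typing import Any

import torch

from torchebm.core.base_loss import BaseLoss


class MeanSquareLoss(BaseLoss):
    """Toy loss for illustration: mean squared value of the batch."""

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        # Any scalar tensor is a valid return value for a BaseLoss subclass.
        return x.pow(2).mean()


loss_fn = MeanSquareLoss(dtype=torch.float32, clip_value=5.0)

x = torch.randn(64, 2, dtype=torch.float64)  # deliberately a different dtype
loss = loss_fn(x)  # x is cast to float32, forward runs under autocast_context, result clamped to [-5, 5]
print(loss.item())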

Attributes:

- dtype (instance attribute): Data type used for loss computations, set from the dtype argument.
- clip_value (instance attribute): Optional value used to clamp the computed loss, set from the clip_value argument.

forward abstractmethod

forward(x: Tensor, *args, **kwargs) -> torch.Tensor

Computes the loss value.

Parameters:

- x (Tensor): Input data tensor from the target distribution. Required.
- *args: Additional positional arguments.
- **kwargs: Additional keyword arguments.

Returns:

- torch.Tensor: The computed scalar loss value.

Source code in torchebm/core/base_loss.py
@abstractmethod
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Computes the loss value.

    Args:
        x (torch.Tensor): Input data tensor from the target distribution.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        torch.Tensor: The computed scalar loss value.
    """
    pass
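
For illustration of this contract (a hedged sketch: the ContrastiveToyLoss name, the x_noise keyword argument, and the squared-norm "energy" are invented here, not part of torchebm), an override receives the data batch x, may accept extra arguments forwarded from __call__, and must return a scalar tensor:

from typing import Any, Optional

import torch

from torchebm.core.base_loss import BaseLoss


class ContrastiveToyLoss(BaseLoss):
    """Illustrative forward override: drive data energy below noise energy."""

    def forward(
        self, x: torch.Tensor, x_noise: Optional[torch.Tensor] = None, **kwargs: Any
    ) -> torch.Tensor:
        # Treat the squared norm as a stand-in energy function.
        energy_data = x.pow(2).sum(dim=-1).mean()
        if x_noise is None:
            return energy_data
        energy_noise = x_noise.pow(2).sum(dim=-1).mean()
        return energy_data - energy_noise


loss_fn = ContrastiveToyLoss()
loss = loss_fn(torch.randn(32, 2), x_noise=torch.randn(32, 2))  # extra kwargs are passed through to forward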