
BaseLoss


Bases: DeviceMixin, Module, ABC

Abstract base class for loss functions used in energy-based models.

This class builds on torch.nn.Module to allow loss functions to be part of PyTorch's computational graph and have trainable parameters if needed. It serves as the foundation for all loss functions in TorchEBM.

Inheriting from torch.nn.Module ensures compatibility with PyTorch's training infrastructure, including device placement, parameter management, and gradient computation.

Subclasses must implement the forward method to define the loss computation.

Parameters:

    dtype (dtype, default: float32)
        Data type for computations.

    device (Optional[Union[str, device]], default: None)
        Device for computations.

    use_mixed_precision (bool, default: False)
        Whether to use mixed precision training (requires PyTorch 1.6+).

    clip_value (Optional[float], default: None)
        If set, the computed loss is clamped to the range [-clip_value, clip_value].
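
As a minimal illustration of the subclassing contract described above, a concrete loss might look like the following sketch. The ToyContrastiveLoss name, the energy_function callable, and the data-versus-noise contrast are illustrative assumptions for this example, not part of TorchEBM's API:

import torch
from torchebm.core.base_loss import BaseLoss

class ToyContrastiveLoss(BaseLoss):
    """Illustrative loss: drives the energy of data below the energy of noise."""

    def __init__(self, energy_function, **kwargs):
        super().__init__(**kwargs)
        # energy_function is assumed to map a (batch, dim) tensor to (batch,) energies.
        self.energy_function = energy_function

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        noise = torch.randn_like(x)  # stand-in negative samples
        # Scalar loss: mean data energy minus mean noise energy.
        return self.energy_function(x).mean() - self.energy_function(noise).mean()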
Source code in torchebm/core/base_loss.py
class BaseLoss(DeviceMixin, nn.Module, ABC):
    """
    Abstract base class for loss functions used in energy-based models.

    This class builds on torch.nn.Module to allow loss functions to be part of PyTorch's
    computational graph and have trainable parameters if needed. It serves as the foundation
    for all loss functions in TorchEBM.

    Inheriting from torch.nn.Module ensures compatibility with PyTorch's training
    infrastructure, including device placement, parameter management, and gradient
    computation.

    Subclasses must implement the forward method to define the loss computation.

    Args:
        dtype: Data type for computations
        device: Device for computations
        use_mixed_precision: Whether to use mixed precision training (requires PyTorch 1.6+)
        clip_value: Optional bound; if set, the computed loss is clamped to [-clip_value, clip_value]
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
        use_mixed_precision: bool = False,
        clip_value: Optional[float] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Initialize the base loss class."""
        super().__init__(device=device, *args, **kwargs)

        self.dtype = dtype
        self.clip_value = clip_value
        self.use_mixed_precision = use_mixed_precision

        if self.use_mixed_precision:
            try:
                from torch.cuda.amp import autocast

                self.autocast_available = True
            except ImportError:
                warnings.warn(
                    "Mixed precision training requested but torch.cuda.amp not available. "
                    "Falling back to full precision. Requires PyTorch 1.6+.",
                    UserWarning,
                )
                self.use_mixed_precision = False
                self.autocast_available = False
        else:
            self.autocast_available = False

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Compute the loss value given input data.

        Args:
            x: Input data tensor, typically real samples from the target distribution.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The computed scalar loss value.
        """
        pass

    def __repr__(self):
        """Return a string representation of the loss function."""
        return f"{self.__class__.__name__}()"

    def __str__(self):
        """Return a string representation of the loss function."""
        return self.__repr__()

    def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Call the forward method of the loss function.

        Args:
            x: Input data tensor.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The computed loss value.
        """
        x = x.to(device=self.device, dtype=self.dtype)

        if (
            hasattr(self, "use_mixed_precision")
            and self.use_mixed_precision
            and self.autocast_available
        ):
            from torch.cuda.amp import autocast

            with autocast():
                loss = self.forward(x, *args, **kwargs)
        else:
            loss = self.forward(x, *args, **kwargs)

        if self.clip_value is not None:
            loss = torch.clamp(loss, -self.clip_value, self.clip_value)
        return loss
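
Because __call__ casts inputs and applies optional clipping around forward, a subclass instance can be used directly on raw tensors. A hedged usage sketch, reusing the illustrative ToyContrastiveLoss defined above (the quadratic energy is an assumption, not a TorchEBM API):

import torch

def energy(z: torch.Tensor) -> torch.Tensor:
    # Toy quadratic energy: maps (batch, dim) -> (batch,)
    return (z ** 2).sum(dim=-1)

loss_fn = ToyContrastiveLoss(energy, dtype=torch.float32, clip_value=5.0)
x = torch.randn(16, 2, dtype=torch.float64)  # deliberately the "wrong" dtype
loss = loss_fn(x)  # __call__ casts x to float32, runs forward, clamps to [-5.0, 5.0]
print(loss.shape)  # torch.Size([]) -- a scalar tensor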

Attributes:

    dtype (instance attribute)
        Data type used for loss computations.

    clip_value (instance attribute)
        Optional bound used to clamp the computed loss.

    use_mixed_precision (instance attribute)
        Whether mixed precision is enabled.

    autocast_available (instance attribute)
        True if mixed precision was requested and torch.cuda.amp.autocast
        could be imported; otherwise False.

forward abstractmethod

forward(x: Tensor, *args, **kwargs) -> torch.Tensor

Compute the loss value given input data.

Parameters:

    x (Tensor, required)
        Input data tensor, typically real samples from the target distribution.

    *args
        Additional positional arguments.

    **kwargs
        Additional keyword arguments.

Returns:

    torch.Tensor
        The computed scalar loss value.

Source code in torchebm/core/base_loss.py
@abstractmethod
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Compute the loss value given input data.

    Args:
        x: Input data tensor, typically real samples from the target distribution.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        torch.Tensor: The computed scalar loss value.
    """
    pass
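
Implementations should return a zero-dimensional tensor so that the clamping in __call__ and a downstream loss.backward() behave as expected. A quick check against the illustrative subclass sketched earlier:

out = ToyContrastiveLoss(energy).forward(torch.randn(4, 2))
assert out.dim() == 0  # the computed loss is a scalar tensor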