BaseLoss


Bases: Module, ABC

Abstract base class for loss functions used in energy-based models.

This class builds on torch.nn.Module to allow loss functions to be part of PyTorch's computational graph and have trainable parameters if needed. It serves as the foundation for all loss functions in TorchEBM.

Inheriting from torch.nn.Module ensures compatibility with PyTorch's training infrastructure, including device placement, parameter management, and gradient computation.

Subclasses must implement the forward method to define the loss computation.

Parameters:

| Name                | Type                         | Description                                                     | Default |
|---------------------|------------------------------|-----------------------------------------------------------------|---------|
| dtype               | dtype                        | Data type for computations                                      | float32 |
| device              | Optional[Union[str, device]] | Device for computations                                         | None    |
| use_mixed_precision | bool                         | Whether to use mixed precision training (requires PyTorch 1.6+) | False   |
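Since BaseLoss is abstract, a concrete subclass only needs to implement forward. A minimal hedged sketch (MeanEnergyLoss and its energy_fn argument are illustrative inventions, not part of TorchEBM's API):

import torch

from torchebm.core.base_loss import BaseLoss


class MeanEnergyLoss(BaseLoss):
    """Hypothetical loss: mean energy of the batch (illustration only)."""

    def __init__(self, energy_fn, **kwargs):
        super().__init__(**kwargs)  # forwards dtype/device/use_mixed_precision
        self.energy_fn = energy_fn  # callable mapping (batch, dim) -> (batch,)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Reduce per-sample energies to the scalar the BaseLoss contract expects.
        return self.energy_fn(x).mean()


loss_fn = MeanEnergyLoss(
    energy_fn=lambda x: (x ** 2).sum(dim=-1),
    dtype=torch.float32,
    device="cpu",  # strings are converted to torch.device in __init__
)
loss = loss_fn(torch.randn(8, 2))  # __call__ casts inputs before calling forward
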
Source code in torchebm/core/base_loss.py

import warnings
from abc import ABC, abstractmethod
from typing import Optional, Union

import torch
import torch.nn as nn


class BaseLoss(nn.Module, ABC):
    """
    Abstract base class for loss functions used in energy-based models.

    This class builds on torch.nn.Module to allow loss functions to be part of PyTorch's
    computational graph and have trainable parameters if needed. It serves as the foundation
    for all loss functions in TorchEBM.

    Inheriting from torch.nn.Module ensures compatibility with PyTorch's training
    infrastructure, including device placement, parameter management, and gradient
    computation.

    Subclasses must implement the forward method to define the loss computation.

    Args:
        dtype: Data type for computations
        device: Device for computations
        use_mixed_precision: Whether to use mixed precision training (requires PyTorch 1.6+)
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
        use_mixed_precision: bool = False,
    ):
        """Initialize the base loss class."""
        super().__init__()
        if isinstance(device, str):
            device = torch.device(device)
        self.dtype = dtype
        self.device = device or (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        self.use_mixed_precision = use_mixed_precision

        # Check if mixed precision is available
        if self.use_mixed_precision:
            try:
                from torch.cuda.amp import autocast

                self.autocast_available = True
            except ImportError:
                warnings.warn(
                    "Mixed precision training requested but torch.cuda.amp not available. "
                    "Falling back to full precision. Requires PyTorch 1.6+.",
                    UserWarning,
                )
                self.use_mixed_precision = False
                self.autocast_available = False
        else:
            self.autocast_available = False

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Compute the loss value given input data.

        Args:
            x: Input data tensor, typically real samples from the target distribution.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The computed scalar loss value.
        """
        pass

    def to(
        self, device: Union[str, torch.device], dtype: Optional[torch.dtype] = None
    ) -> "BaseLoss":
        """
        Move the loss function to the specified device and optionally change its dtype.

        Args:
            device: Target device for computations.
            dtype: Optional data type to convert to.

        Returns:
            The loss function instance moved to the specified device/dtype.
        """
        # Normalize strings to torch.device, mirroring __init__.
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device
        if dtype is not None:
            self.dtype = dtype
        # Forward dtype as well so parameters/buffers are actually converted.
        return super().to(device=device, dtype=dtype)

    def __repr__(self):
        """Return a string representation of the loss function."""
        return f"{self.__class__.__name__}()"

    def __str__(self):
        """Return a string representation of the loss function."""
        return self.__repr__()

    def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Call the forward method of the loss function.

        Args:
            x: Input data tensor.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The computed loss value.
        """
        # Ensure x is on the correct device and has the correct dtype
        x = x.to(device=self.device, dtype=self.dtype)

        # Apply mixed precision context if enabled
        if (
            hasattr(self, "use_mixed_precision")
            and self.use_mixed_precision
            and self.autocast_available
        ):
            from torch.cuda.amp import autocast

            with autocast():
                return self.forward(x, *args, **kwargs)
        else:
            return self.forward(x, *args, **kwargs)
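Note that __call__ both coerces the input onto the configured device/dtype and, when use_mixed_precision is enabled, wraps forward in an autocast context, so mixed precision is purely a constructor flag. A hedged usage sketch (QuadraticLoss is a hypothetical stand-in for a concrete subclass):

import torch

from torchebm.core.base_loss import BaseLoss


class QuadraticLoss(BaseLoss):
    """Hypothetical squared-norm loss (illustration only)."""

    def forward(self, x, *args, **kwargs):
        return (x ** 2).sum(dim=-1).mean()


loss_fn = QuadraticLoss(use_mixed_precision=True)  # warns and falls back on PyTorch < 1.6
x = torch.randn(64, 2, dtype=torch.float64)  # deliberately the "wrong" dtype
loss = loss_fn(x)  # x is cast to loss_fn.dtype/device; forward runs under autocast on CUDA
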

dtype instance-attribute

dtype = dtype

device instance-attribute

device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))

use_mixed_precision instance-attribute

use_mixed_precision = use_mixed_precision

autocast_available instance-attribute

autocast_available = True (only when use_mixed_precision is set and torch.cuda.amp is importable; otherwise False)

forward abstractmethod

forward(x: Tensor, *args, **kwargs) -> Tensor

Compute the loss value given input data.

Parameters:

| Name     | Type   | Description                                                              | Default  |
|----------|--------|--------------------------------------------------------------------------|----------|
| x        | Tensor | Input data tensor, typically real samples from the target distribution. | required |
| *args    |        | Additional positional arguments.                                         | ()       |
| **kwargs |        | Additional keyword arguments.                                            | {}       |

Returns:

| Type   | Description                     |
|--------|---------------------------------|
| Tensor | The computed scalar loss value. |

Source code in torchebm/core/base_loss.py
@abstractmethod
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Compute the loss value given input data.

    Args:
        x: Input data tensor, typically real samples from the target distribution.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        torch.Tensor: The computed scalar loss value.
    """
    pass
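One detail worth knowing when implementing forward: __call__ only casts x onto the configured device/dtype, so any extra tensors passed through *args/**kwargs must be moved by the subclass itself. A hedged sketch (ContrastiveLoss and energy_fn are illustrative, not TorchEBM API):

import torch

from torchebm.core.base_loss import BaseLoss


class ContrastiveLoss(BaseLoss):
    """Hypothetical contrastive-divergence-style loss (illustration only)."""

    def __init__(self, energy_fn, **kwargs):
        super().__init__(**kwargs)
        self.energy_fn = energy_fn

    def forward(self, x: torch.Tensor, x_neg: torch.Tensor, **kwargs) -> torch.Tensor:
        # Only x was cast by __call__; move the negative samples ourselves.
        x_neg = x_neg.to(device=self.device, dtype=self.dtype)
        return self.energy_fn(x).mean() - self.energy_fn(x_neg).mean()


loss_fn = ContrastiveLoss(energy_fn=lambda x: (x ** 2).sum(dim=-1))
loss = loss_fn(torch.randn(16, 2), torch.randn(16, 2))
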

to

to(device: Union[str, device], dtype: Optional[dtype] = None) -> BaseLoss

Move the loss function to the specified device and optionally change its dtype.

Parameters:

| Name   | Type               | Description                       | Default  |
|--------|--------------------|-----------------------------------|----------|
| device | Union[str, device] | Target device for computations.   | required |
| dtype  | Optional[dtype]    | Optional data type to convert to. | None     |

Returns:

| Type     | Description                                                      |
|----------|------------------------------------------------------------------|
| BaseLoss | The loss function instance moved to the specified device/dtype. |

Source code in torchebm/core/base_loss.py
def to(
    self, device: Union[str, torch.device], dtype: Optional[torch.dtype] = None
) -> "BaseLoss":
    """
    Move the loss function to the specified device and optionally change its dtype.

    Args:
        device: Target device for computations.
        dtype: Optional data type to convert to.

    Returns:
        The loss function instance moved to the specified device/dtype.
    """
    # Normalize strings to torch.device, mirroring __init__.
    if isinstance(device, str):
        device = torch.device(device)
    self.device = device
    if dtype is not None:
        self.dtype = dtype
    # Forward dtype as well so parameters/buffers are actually converted.
    return super().to(device=device, dtype=dtype)
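
A brief hedged usage sketch (ZeroLoss is a throwaway hypothetical subclass, defined only so the example is self-contained):

import torch

from torchebm.core.base_loss import BaseLoss


class ZeroLoss(BaseLoss):
    """Hypothetical no-op loss, used only to demonstrate to() (illustration only)."""

    def forward(self, x, *args, **kwargs):
        return x.sum() * 0.0


loss_fn = ZeroLoss()
if torch.cuda.is_available():
    # Moves parameters/buffers and updates the bookkeeping attributes.
    loss_fn = loss_fn.to("cuda", dtype=torch.float16)
print(loss_fn.device, loss_fn.dtype)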