BaseLoss

Methods and Attributes

Bases: Module, ABC

Abstract base class for loss functions used in energy-based models.

This class builds on torch.nn.Module so that a loss function can own trainable parameters and have them tracked by autograd and optimizers like those of any other module. It serves as the foundation for all loss functions in TorchEBM.

Inheriting from torch.nn.Module ensures compatibility with PyTorch's training infrastructure, including device placement, parameter management, and gradient computation.

Subclasses must implement the forward method to define the loss computation.
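The subclassing contract can be sketched as follows. To stay self-contained, the snippet declares a stand-in `BaseLoss` with the same shape as the class documented here instead of importing it from TorchEBM. `ScaledSquaredNormLoss`, its `log_scale` parameter, and the helper `base_is_abstract` are hypothetical names chosen for illustration and are not part of the library.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class BaseLoss(nn.Module, ABC):
    """Stand-in mirroring the documented class (assumption: not imported from TorchEBM)."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Return a scalar loss for the batch x."""


class ScaledSquaredNormLoss(BaseLoss):
    """Hypothetical loss: mean squared norm of x times a learnable scale."""

    def __init__(self):
        super().__init__()
        # Registered as a parameter, so optimizers and state_dict() see it.
        self.log_scale = nn.Parameter(torch.zeros(()))

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        per_sample = x.pow(2).sum(dim=-1)                # shape: (batch,)
        return self.log_scale.exp() * per_sample.mean()  # scalar tensor


def base_is_abstract() -> bool:
    """Return True if instantiating BaseLoss directly raises TypeError."""
    try:
        BaseLoss()
    except TypeError:
        return True
    return False


loss_fn = ScaledSquaredNormLoss()
x = torch.ones(4, 3)
loss = loss_fn(x)   # nn.Module.__call__ dispatches to forward
loss.backward()     # the gradient reaches the trainable parameter
```

Because the base class combines `nn.Module` with `ABC`, a subclass that forgets to define `forward` fails at construction time rather than at the first training step.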

Source code in torchebm/core/base_loss.py
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class BaseLoss(nn.Module, ABC):
    """
    Abstract base class for loss functions used in energy-based models.

    This class builds on torch.nn.Module to allow loss functions to be part of PyTorch's
    computational graph and have trainable parameters if needed. It serves as the foundation
    for all loss functions in TorchEBM.

    Inheriting from torch.nn.Module ensures compatibility with PyTorch's training
    infrastructure, including device placement, parameter management, and gradient
    computation.

    Subclasses must implement the forward method to define the loss computation.
    """

    def __init__(self):
        """Initialize the base loss class."""
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Compute the loss value given input data.

        Args:
            x: Input data tensor, typically real samples from the target distribution.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The computed scalar loss value.
        """
        pass

    def to(self, device):
        """
        Move the loss function to the specified device.

        Args:
            device: Target device for computations.

        Returns:
            The loss function instance moved to the specified device.
        """
        self.device = device
        return self

forward abstractmethod

forward(x: torch.Tensor, *args, **kwargs) -> torch.Tensor

Compute the loss value given input data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | Tensor | Input data tensor, typically real samples from the target distribution. | required |
| `*args` | | Additional positional arguments. | `()` |
| `**kwargs` | | Additional keyword arguments. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| Tensor | The computed scalar loss value. |

Source code in torchebm/core/base_loss.py
@abstractmethod
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Compute the loss value given input data.

    Args:
        x: Input data tensor, typically real samples from the target distribution.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        torch.Tensor: The computed scalar loss value.
    """
    pass
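As a rough illustration of how the extra `*args` and `**kwargs` slots might be used, the sketch below passes a second batch positionally and a reduction mode by keyword. It uses a stand-in `BaseLoss` rather than the TorchEBM import, and `ContrastLoss`, `energy_fn`, `negatives`, `reduction`, and `quadratic_energy` are all hypothetical names. The objective only loosely resembles a contrastive-divergence-style difference of energies and is not TorchEBM's implementation.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class BaseLoss(nn.Module, ABC):
    """Stand-in with the documented signature (not imported from TorchEBM)."""

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        ...


class ContrastLoss(BaseLoss):
    """Hypothetical loss: energy of data samples minus energy of negatives."""

    def __init__(self, energy_fn):
        super().__init__()
        # Hypothetical callable mapping (batch, dim) -> (batch,) energies.
        self.energy_fn = energy_fn

    def forward(self, x, negatives, *, reduction="mean"):
        diff = self.energy_fn(x) - self.energy_fn(negatives)
        return diff.mean() if reduction == "mean" else diff.sum()


def quadratic_energy(z: torch.Tensor) -> torch.Tensor:
    """Toy energy: half the squared norm of each row."""
    return 0.5 * z.pow(2).sum(dim=-1)


loss_fn = ContrastLoss(quadratic_energy)
data = torch.zeros(2, 2)       # stands in for real samples
negatives = torch.ones(2, 2)   # stands in for model samples

# Extra positional and keyword arguments pass through __call__ to forward.
loss_mean = loss_fn(data, negatives)
loss_sum = loss_fn(data, negatives, reduction="sum")
```

Narrowing the signature in the subclass keeps call sites explicit while remaining compatible with the base declaration, since `x` is still the first argument.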

to

to(device)

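Record the target device on the loss function and return the instance. In the source shown below, this override only sets self.device; it does not call torch.nn.Module.to, so parameters and buffers owned by a subclass are not moved by this method alone.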

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `device` | | Target device for computations. | required |

Returns:

| Type | Description |
| --- | --- |
| (unannotated) | The loss function instance moved to the specified device. |

Source code in torchebm/core/base_loss.py
def to(self, device):
    """
    Move the loss function to the specified device.

    Args:
        device: Target device for computations.

    Returns:
        The loss function instance moved to the specified device.
    """
    self.device = device
    return self
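The effect of this `to` override can be checked with a small experiment. The sketch below copies the documented `forward` and `to` definitions into a stand-in `BaseLoss` (it does not import TorchEBM) and compares it with a hypothetical subclass that additionally calls `nn.Module.to`. `ScaledMeanLoss`, `MovingScaledMeanLoss`, and `scale` are illustrative names, and the `"meta"` device is used only because it is available without a GPU.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class BaseLoss(nn.Module, ABC):
    """Stand-in copying the documented forward/to definitions."""

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        ...

    def to(self, device):
        self.device = device  # as documented: record the device, nothing else
        return self


class ScaledMeanLoss(BaseLoss):
    """Hypothetical loss with one trainable parameter."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(()))

    def forward(self, x, *args, **kwargs):
        return self.scale * x.mean()


class MovingScaledMeanLoss(ScaledMeanLoss):
    """Hypothetical variant that also relocates its parameters."""

    def to(self, device):
        nn.Module.to(self, device)   # move parameters and buffers
        return super().to(device)    # then record the device as before


recorded = ScaledMeanLoss().to("meta")
moved = MovingScaledMeanLoss().to("meta")

# Compare the recorded device with where the parameter actually lives.
print(recorded.device, recorded.scale.device.type)
print(moved.device, moved.scale.device.type)
```

Under these assumptions, a loss with trainable parameters may need an override along the lines of the second class, or an explicit `nn.Module.to(loss_fn, device)` call, before its parameters sit on the same device as the data.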