Bases: Module, ABC
Abstract base class for loss functions used in energy-based models.
This class builds on torch.nn.Module to allow loss functions to be part of PyTorch's
computational graph and have trainable parameters if needed. It serves as the foundation
for all loss functions in TorchEBM.
Inheriting from torch.nn.Module ensures compatibility with PyTorch's training
infrastructure, including device placement, parameter management, and gradient
computation.
Subclasses must implement the forward method to define the loss computation.
Source code in torchebm/core/base_loss.py
class BaseLoss(nn.Module, ABC):
    """
    Abstract base class for loss functions used in energy-based models.

    This class builds on torch.nn.Module to allow loss functions to be part of PyTorch's
    computational graph and have trainable parameters if needed. It serves as the foundation
    for all loss functions in TorchEBM.

    Inheriting from torch.nn.Module ensures compatibility with PyTorch's training
    infrastructure, including device placement, parameter management, and gradient
    computation.

    Subclasses must implement the forward method to define the loss computation.
    """

    def __init__(self):
        """Initialize the base loss class."""
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Compute the loss value given input data.

        Args:
            x: Input data tensor, typically real samples from the target distribution.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The computed scalar loss value.
        """
        pass

    def to(self, device: Union[str, torch.device]) -> "BaseLoss":
        """
        Move the loss function to the specified device.

        Args:
            device: Target device for computations.

        Returns:
            The loss function instance moved to the specified device.
        """
        self.device = device
        return self

    def __repr__(self):
        """Return a string representation of the loss function."""
        return f"{self.__class__.__name__}()"

    def __str__(self):
        """Return a string representation of the loss function."""
        return self.__repr__()

    def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Call the forward method of the loss function.

        Args:
            x: Input data tensor.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The computed loss value.
        """
        return self.forward(x, *args, **kwargs)
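Example: a minimal sketch of a concrete subclass, showing the forward contract described above. The subclass name ExampleMeanEnergyLoss and its energy_function argument are illustrative only and are not part of TorchEBM's API; the import path simply follows the source location shown above.

import torch
import torch.nn as nn

from torchebm.core.base_loss import BaseLoss


class ExampleMeanEnergyLoss(BaseLoss):
    """Illustrative loss: minimize the mean energy assigned to real samples."""

    def __init__(self, energy_function: nn.Module):
        super().__init__()
        # Stored as a submodule, so its parameters are visible to optimizers.
        self.energy_function = energy_function

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # A practical EBM loss would also include a contrastive term from model
        # samples; this sketch only pushes the energy of real data down.
        return self.energy_function(x).mean()


energy_net = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 1))
loss_fn = ExampleMeanEnergyLoss(energy_net)
loss = loss_fn(torch.randn(16, 2))  # scalar loss via BaseLoss.__call__ -> forward
loss.backward()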
forward
abstractmethod
forward(x: Tensor, *args, **kwargs) -> torch.Tensor
Compute the loss value given input data.
Parameters:

Name     | Type   | Description                                                              | Default
x        | Tensor | Input data tensor, typically real samples from the target distribution. | required
*args    |        | Additional positional arguments.                                         | ()
**kwargs |        | Additional keyword arguments.                                            | {}
Returns:

Type   | Description
Tensor | torch.Tensor: The computed scalar loss value.
Source code in torchebm/core/base_loss.py
@abstractmethod
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Compute the loss value given input data.

    Args:
        x: Input data tensor, typically real samples from the target distribution.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        torch.Tensor: The computed scalar loss value.
    """
    pass
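Usage note: forward is rarely called directly; BaseLoss.__call__ dispatches to it, so a concrete loss is used as a callable. A hypothetical snippet, reusing the illustrative subclass sketched above:

x_real = torch.randn(32, 2)  # a batch of real samples
loss = loss_fn(x_real)       # __call__ dispatches to forward
loss.backward()              # gradients flow into the energy function's parameters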
to
to(device: Union[str, device]) -> BaseLoss
Move the loss function to the specified device.
Parameters:

Name   | Type               | Description                     | Default
device | Union[str, device] | Target device for computations. | required
Returns:

Type     | Description
BaseLoss | The loss function instance moved to the specified device.
Source code in torchebm/core/base_loss.py
def to(self, device: Union[str, torch.device]) -> "BaseLoss":
    """
    Move the loss function to the specified device.

    Args:
        device: Target device for computations.

    Returns:
        The loss function instance moved to the specified device.
    """
    self.device = device
    return self
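A short usage sketch for to. Note that, as the source above shows, this override records the target on self.device and returns the instance rather than calling nn.Module.to. The "cuda" target below is an assumption and is only valid when a GPU is available:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loss_fn = loss_fn.to(device)  # returns self, so assignment and chaining both work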