BaseLoss
Methods and Attributes
Bases: Module, ABC
Abstract base class for loss functions used in energy-based models.
This class builds on torch.nn.Module so that loss functions can participate in PyTorch's computational graph and hold trainable parameters when needed. It serves as the foundation for all loss functions in TorchEBM.
Inheriting from torch.nn.Module ensures compatibility with PyTorch's training infrastructure, including device placement, parameter management, and gradient computation.
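The sketch below shows what the nn.Module base provides. It defines a stand-in BaseLoss with the documented bases (Module and ABC) and a hypothetical TemperatureScaledLoss that has one learnable parameter. Both classes are assumptions for illustration, not TorchEBM's source; in real code you would subclass torchebm's own BaseLoss. The parameter is registered automatically, appears in named_parameters(), and receives a gradient after backward().

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


# Stand-in mirroring the documented BaseLoss (Module + ABC, abstract forward);
# in real code, import BaseLoss from torchebm instead.
class BaseLoss(nn.Module, ABC):
    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        ...


# Hypothetical loss with a learnable temperature, for illustration only.
class TemperatureScaledLoss(BaseLoss):
    def __init__(self) -> None:
        super().__init__()
        # Registered as a parameter because it is an nn.Parameter attribute.
        self.log_temperature = nn.Parameter(torch.zeros(()))

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Scalar loss: mean squared norm of x divided by a learnable temperature.
        temperature = self.log_temperature.exp()
        return (x.pow(2).sum(dim=-1) / temperature).mean()


loss_fn = TemperatureScaledLoss()
x = torch.randn(8, 2)
loss = loss_fn(x)
loss.backward()

print([name for name, _ in loss_fn.named_parameters()])  # ['log_temperature']
print(loss.dim())  # 0 (scalar)
print(loss_fn.log_temperature.grad is not None)  # True
```

Because the temperature is an ordinary module parameter, an optimizer built from loss_fn.parameters() would update it alongside the model.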
Subclasses must implement the forward method to define the loss computation.
Source code in torchebm/core/base_loss.py
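The abstract forward contract can be sketched with the same kind of stand-in. The stand-in is an assumption that mirrors the documented bases; it is not the library source. A subclass that omits forward cannot be instantiated. A hypothetical MeanSquaredNormLoss that implements forward returns a scalar tensor.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


# Stand-in mirroring the documented BaseLoss; import torchebm's class in real code.
class BaseLoss(nn.Module, ABC):
    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        ...


class IncompleteLoss(BaseLoss):
    pass  # forward not implemented, so the class stays abstract


class MeanSquaredNormLoss(BaseLoss):  # hypothetical example subclass
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Average over the batch of each sample's squared L2 norm.
        return x.pow(2).sum(dim=-1).mean()


try:
    IncompleteLoss()
    instantiated = True
except TypeError:
    # ABCMeta blocks instantiation while an abstract method is unimplemented.
    instantiated = False

print(instantiated)  # False
print(MeanSquaredNormLoss()(torch.ones(4, 3)))  # tensor(3.)
```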
forward(x, *args, **kwargs) (abstract method)
Compute the loss value given input data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input data tensor, typically real samples from the target distribution. | required |
| `*args` | | Additional positional arguments. | `()` |
| `**kwargs` | | Additional keyword arguments. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | The computed scalar loss value. |
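The sketch below shows how the forward arguments are typically used. It assumes a stand-in BaseLoss and a hypothetical contrastive-divergence-style loss. Here x carries real data, and negative samples arrive through **kwargs under the assumed name x_neg. Calling the module instead of forward directly goes through nn.Module.__call__, which passes positional and keyword arguments on to forward.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class BaseLoss(nn.Module, ABC):  # stand-in for torchebm's BaseLoss
    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        ...


class ContrastiveEnergyLoss(BaseLoss):
    """Hypothetical contrastive loss: mean E(data) - mean E(negatives)."""

    def __init__(self, energy_fn: nn.Module) -> None:
        super().__init__()
        self.energy_fn = energy_fn  # registered as a submodule

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Negative samples arrive through **kwargs; the name "x_neg" is an assumption.
        x_neg = kwargs["x_neg"]
        return self.energy_fn(x).mean() - self.energy_fn(x_neg).mean()


class QuadraticEnergy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x.pow(2).sum(dim=-1)  # one energy value per sample


loss_fn = ContrastiveEnergyLoss(QuadraticEnergy())
x_data = torch.zeros(5, 2)  # energy 0 for every sample
x_neg = torch.ones(5, 2)    # energy 0.5 * 2 = 1 for every sample
loss = loss_fn(x_data, x_neg=x_neg)  # __call__ forwards the keyword to forward
print(loss)  # tensor(-1.)
```

Minimizing this quantity lowers the energy of real samples and raises the energy of negatives. That is the usual shape of a contrastive objective for energy-based models.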
to(device)
Move the loss function to the specified device.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `device` | | Target device for computations. | required |
Returns:
| Type | Description |
|---|---|
| | The loss function instance moved to the specified device. |
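The sketch below shows device placement, again with a stand-in BaseLoss and a hypothetical ScaledNormLoss. The stand-in relies on the inherited nn.Module.to, which moves registered parameters and buffers and returns the module itself. Whether TorchEBM's own `to` does anything beyond that is not stated here. Inputs must be created on the same device as the loss.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class BaseLoss(nn.Module, ABC):  # stand-in for torchebm's BaseLoss
    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        ...


class ScaledNormLoss(BaseLoss):  # hypothetical loss with one parameter
    def __init__(self) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(()))

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return self.scale * x.pow(2).sum(dim=-1).mean()


# Pick a GPU when one is present, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loss_fn = ScaledNormLoss()
moved = loss_fn.to(device)  # nn.Module.to moves parameters and returns self

x = torch.randn(16, 2, device=device)  # inputs must live on the same device
loss = moved(x)
print(moved is loss_fn)  # True
print(loss_fn.scale.device.type == device.type)  # True
```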