BaseLoss
Methods and Attributes
Bases: `DeviceMixin`, `torch.nn.Module`, `ABC`
Abstract base class for loss functions used in energy-based models.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dtype` | `torch.dtype` | Data type used for computations. | `torch.float32` |
| `device` | `Optional[Union[str, torch.device]]` | Device used for computations. | `None` |
| `use_mixed_precision` | `bool` | Whether to use mixed-precision computation. | `False` |
| `clip_value` | `Optional[float]` | Optional value used to clamp the loss. | `None` |
Source code in `torchebm/core/base_loss.py`
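The class-level parameters can be illustrated with a minimal, self-contained stand-in. `SketchBaseLoss` and `MeanSquaredNorm` are hypothetical names, `DeviceMixin` is omitted, and the way `clip_value` and `use_mixed_precision` are applied here (symmetric clamping and `torch.autocast`) is an assumption, not the library's confirmed behavior. The sketch shows why `forward` must be overridden: the abstract base cannot be instantiated, while a subclass that implements `forward` can.

```python
from abc import ABC, abstractmethod
from typing import Optional, Union

import torch
from torch import nn


class SketchBaseLoss(nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented constructor arguments.

    DeviceMixin is omitted; device handling is reduced to storing a torch.device.
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
        use_mixed_precision: bool = False,
        clip_value: Optional[float] = None,
    ):
        super().__init__()
        self.dtype = dtype
        # Assumption: fall back to CPU when no device is given.
        self.device = torch.device(device) if device is not None else torch.device("cpu")
        self.use_mixed_precision = use_mixed_precision
        self.clip_value = clip_value

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Return a scalar loss for a batch x drawn from the target distribution."""

    def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Move the input to the configured device and dtype before computing the loss.
        x = x.to(device=self.device, dtype=self.dtype)
        # Assumption: mixed precision is handled with torch.autocast (no-op when disabled).
        with torch.autocast(device_type=self.device.type, enabled=self.use_mixed_precision):
            loss = super().__call__(x, *args, **kwargs)
        # Assumption: clip_value clamps the loss symmetrically to [-clip_value, clip_value].
        if self.clip_value is not None:
            loss = torch.clamp(loss, -self.clip_value, self.clip_value)
        return loss


class MeanSquaredNorm(SketchBaseLoss):
    """Hypothetical toy loss: batch mean of 0.5 * ||x||^2."""

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return 0.5 * (x ** 2).sum(dim=-1).mean()
```

For example, `MeanSquaredNorm(clip_value=1.0)(torch.full((4, 2), 3.0))` computes 9.0 before clamping and returns `tensor(1.)`, while `SketchBaseLoss()` raises `TypeError` because `forward` is abstract.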
forward (abstract method)
Computes the loss value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `torch.Tensor` | Input data tensor from the target distribution. | required |
| `*args` | | Additional positional arguments. | `()` |
| `**kwargs` | | Additional keyword arguments. | `{}` |
Returns:
| Type | Description |
|---|---|
| `torch.Tensor` | The computed scalar loss value. |
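A concrete subclass only needs to implement `forward` and return a scalar tensor. The sketch below is a contrastive-divergence-style objective, the mean energy of data minus the mean energy of negative samples, with the negatives supplied by the caller through `*args` or a keyword rather than drawn from a sampler. `ContrastiveEnergyLoss`, `quadratic_energy`, and the `x_neg` keyword are hypothetical, and the class subclasses `torch.nn.Module` instead of `BaseLoss` so that it runs standalone.

```python
import torch
from torch import nn


def quadratic_energy(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical energy function E(x) = 0.5 * ||x||^2, one value per sample.
    return 0.5 * (x ** 2).sum(dim=-1)


class ContrastiveEnergyLoss(nn.Module):
    """Hypothetical loss following the documented forward(x, *args, **kwargs) contract.

    In TorchEBM this would subclass BaseLoss; nn.Module is used here so the
    sketch runs on its own.
    """

    def __init__(self, energy_fn=quadratic_energy):
        super().__init__()
        self.energy_fn = energy_fn

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Negative samples arrive either as the first extra positional
        # argument or as the `x_neg` keyword (both conventions are assumptions).
        x_neg = args[0] if args else kwargs["x_neg"]
        # Lower the energy of data, raise it on negatives; the result is a 0-dim tensor.
        return self.energy_fn(x).mean() - self.energy_fn(x_neg).mean()
```

With `x = torch.zeros(8, 2)` and `x_neg = torch.ones(8, 2)`, the data energy is 0 and the negative energy is 1, so the loss is -1.0 whether `x_neg` is passed positionally or as a keyword.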