BaseLoss
Methods and Attributes
Bases: Module, ABC
Abstract base class for loss functions used in energy-based models.
Because it subclasses torch.nn.Module, a loss function can take part in PyTorch's computational graph and hold trainable parameters when needed. BaseLoss serves as the foundation for all loss functions in TorchEBM.
Inheriting from torch.nn.Module also makes every loss compatible with PyTorch's training infrastructure, including device placement, parameter management, and gradient computation.
Subclasses must implement the forward method to define the loss computation.
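The subclassing contract can be sketched as follows. So that it runs without TorchEBM installed, the sketch defines a stand-in base, `BaseLossSketch`, with the same bases (Module, ABC) and an abstract `forward`; with TorchEBM available you would subclass BaseLoss itself (its source lives in torchebm/core/base_loss.py). `MeanEnergyLoss` and its linear energy layer are hypothetical and exist only for illustration.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class BaseLossSketch(nn.Module, ABC):
    """Hypothetical stand-in mirroring BaseLoss's bases and abstract forward."""

    @abstractmethod
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        ...


class MeanEnergyLoss(BaseLossSketch):
    """Hypothetical loss: the mean output of a learnable linear energy."""

    def __init__(self, dim: int):
        super().__init__()
        # Registered as a submodule, so its weights appear in .parameters().
        self.energy = nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Reduce the per-sample energies to a single scalar loss.
        return self.energy(x).mean()


# The abstract base itself cannot be instantiated.
try:
    BaseLossSketch()
    base_is_abstract = False
except TypeError:
    base_is_abstract = True

loss_fn = MeanEnergyLoss(dim=2)
loss = loss_fn(torch.randn(8, 2))
loss.backward()  # gradients flow into the loss's own parameters
```

Leaving `forward` abstract means a subclass that forgets to implement it fails at construction time rather than at the first training step.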
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dtype | dtype | Data type for computations. | float32 |
device | Optional[Union[str, device]] | Device for computations. | None |
use_mixed_precision | bool | Whether to use mixed precision training (requires PyTorch 1.6+). | False |
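How a loss might act on these three options is sketched below. This is not TorchEBM's implementation: `ConfiguredLossSketch` is a hypothetical stand-in, falling back to CPU when `device` is None is an assumption, and modeling `use_mixed_precision` as a device-generic `torch.autocast` region (available in PyTorch 1.10+) is likewise an assumption about what the flag controls.

```python
from typing import Optional, Union

import torch
from torch import nn


class ConfiguredLossSketch(nn.Module):
    """Hypothetical stand-in that records the documented constructor options."""

    def __init__(
        self,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
        use_mixed_precision: bool = False,
    ):
        super().__init__()
        self.dtype = dtype
        # Assumption: this sketch falls back to CPU when device is None.
        self.device = torch.device(device) if device is not None else torch.device("cpu")
        self.use_mixed_precision = use_mixed_precision

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast inputs to the configured device and dtype before computing.
        x = x.to(device=self.device, dtype=self.dtype)
        # Assumption: mixed precision is modeled as an autocast region that
        # is a no-op unless use_mixed_precision=True.
        with torch.autocast(device_type=self.device.type, enabled=self.use_mixed_precision):
            return (x ** 2).mean()


loss_fn = ConfiguredLossSketch(dtype=torch.float64)
loss = loss_fn(torch.ones(4, 3))  # float32 input is cast to float64 first
```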
Source code in torchebm/core/base_loss.py
forward (abstractmethod)
Compute the loss value given input data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Input data tensor, typically real samples from the target distribution. | required |
*args | | Additional positional arguments. | () |
**kwargs | | Additional keyword arguments. | {} |
Returns:
Type | Description |
---|---|
Tensor | The computed scalar loss value. |
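A forward that uses both `x` and the extra arguments might look like the sketch below, in the spirit of contrastive divergence: the loss is the mean energy of real samples minus the mean energy of negative samples. `QuadraticEnergy`, `ContrastiveEnergyLossSketch`, and the `negatives` keyword are hypothetical names chosen for illustration, not part of the documented API.

```python
import torch
from torch import nn


class QuadraticEnergy(nn.Module):
    """Hypothetical energy E(x) = 0.5 * ||x||^2, one value per sample."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * (x ** 2).sum(dim=-1)


class ContrastiveEnergyLossSketch(nn.Module):
    """Hypothetical loss: mean energy of data minus mean energy of negatives."""

    def __init__(self, energy_fn: nn.Module):
        super().__init__()
        self.energy_fn = energy_fn

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # x holds real samples; negatives arrive through **kwargs.
        negatives = kwargs["negatives"]
        # Minimizing this pushes energy down on data and up on negatives.
        return self.energy_fn(x).mean() - self.energy_fn(negatives).mean()


loss_fn = ContrastiveEnergyLossSketch(QuadraticEnergy())
real = torch.zeros(4, 2)
negatives = torch.ones(4, 2)
loss = loss_fn(real, negatives=negatives)  # 0.0 - 1.0
```

Whatever extra inputs a subclass accepts, the return value stays a 0-dimensional tensor, so `loss.backward()` can be called on it directly.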
to
Move the loss function to the specified device and optionally change its dtype.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device | Union[str, device] | Target device for computations. | required |
dtype | Optional[dtype] | Optional data type to convert to. | None |
Returns:
Type | Description |
---|---|
BaseLoss | The loss function instance moved to the specified device/dtype. |
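One way to honor this `to(device, dtype=None)` contract is sketched below: record the new placement, then delegate to `nn.Module.to`, which moves parameters and buffers and returns the module itself. `MovableLossSketch` and its `scale` parameter are hypothetical, and storing `device`/`dtype` attributes is an assumption about how a loss might keep its inputs consistent, not a description of TorchEBM's internals.

```python
from typing import Optional, Union

import torch
from torch import nn


class MovableLossSketch(nn.Module):
    """Hypothetical stand-in following the documented to(device, dtype) contract."""

    def __init__(self):
        super().__init__()
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.scale = nn.Parameter(torch.ones(()))  # a trainable parameter to move

    def to(self, device: Union[str, torch.device], dtype: Optional[torch.dtype] = None):
        # Record the new placement so forward() casts inputs consistently.
        self.device = torch.device(device)
        if dtype is None:
            return super().to(self.device)  # move parameters, keep their dtype
        self.dtype = dtype
        return super().to(self.device, dtype=dtype)  # move and cast float parameters

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.to(device=self.device, dtype=self.dtype)
        return self.scale * x.mean()


loss_fn = MovableLossSketch()
moved = loss_fn.to("cpu", dtype=torch.float64)  # returns the same instance
loss = moved(torch.ones(3))
```

Returning the instance rather than a copy allows the usual chained style, e.g. `loss_fn = SomeLoss().to(device)`.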