Bases: BaseModel
Energy-based model for the Rastrigin function.
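Reading the expression off `forward`, the energy of a point $x \in \mathbb{R}^n$ is given below, where $a$ is the constructor parameter (default `10.0`). For $a > 0$ every summand is at least $-a$, so $E(x) \ge 0$, with the global minimum $E(0) = 0$ at the origin. The cosine term superimposes many local minima on the quadratic bowl.

$$
E(x) = a\,n + \sum_{i=1}^{n} \left( x_i^2 - a \cos(2\pi x_i) \right)
$$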
Parameters:
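| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `a` | `float` | The `a` parameter of the Rastrigin function. It scales both the constant offset `a * n` and the cosine term `a * cos(2πx_i)`. | `10.0` |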
Source code in torchebm/core/base_model.py
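```python
import math

import torch

# BaseModel is defined earlier in this same module (torchebm/core/base_model.py).


class RastriginModel(BaseModel):
    r"""
    Energy-based model for the Rastrigin function.

    Args:
        a (float): The `a` parameter of the Rastrigin function.
    """

    def __init__(self, a: float = 10.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.a = a

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Computes the Rastrigin energy."""
        # Treat a single point of shape (n,) as a batch of one: (1, n).
        if x.ndim == 1:
            x = x.unsqueeze(0)
        n = x.shape[-1]  # dimensionality of each sample
        # E(x) = a*n + sum_i (x_i^2 - a*cos(2*pi*x_i)), one value per sample.
        return self.a * n + torch.sum(
            x**2 - self.a * torch.cos(2 * math.pi * x), dim=-1
        )
```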
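As a sanity check on the expression in `forward`, the plain-Python sketch below evaluates it for a single point. The helper name `rastrigin_energy` is hypothetical and is not part of torchebm. It illustrates two properties that follow directly from the formula: the energy is zero at the origin, and at integer points `cos(2πk) = 1`, so the `a * n` offset cancels and only the squared norm remains, whatever the value of `a`.

```python
import math
from typing import Sequence


def rastrigin_energy(x: Sequence[float], a: float = 10.0) -> float:
    """Rastrigin energy a*n + sum(x_i**2 - a*cos(2*pi*x_i)) for one point.

    Hypothetical reference helper mirroring RastriginModel.forward for a
    single sample; it is not part of the torchebm API.
    """
    n = len(x)
    return a * n + sum(xi**2 - a * math.cos(2 * math.pi * xi) for xi in x)


# Global minimum at the origin: the offset a*n cancels the cosine terms.
print(rastrigin_energy([0.0, 0.0]))
# At integer points only the squared norm is left: 1**2 + (-2)**2, approximately 5.0.
print(rastrigin_energy([1.0, -2.0], a=3.0))
```

Using plain `math` instead of `torch` keeps the check free of tensor shapes, which makes it easy to compare individual values against the model's output.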
forward
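`forward(x: torch.Tensor) -> torch.Tensor`
Computes the Rastrigin energy of `x`. An input of shape `(batch, n)` yields one energy per sample, with shape `(batch,)`. A 1-D input of shape `(n,)` is treated as a batch of one and yields shape `(1,)`.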
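The shape handling in `forward` can be mimicked without tensors. In the sketch below, a flat list stands in for a 1-D tensor and is wrapped as a batch of one, and each row is reduced to a single energy, matching the sum over the last dimension. The function name `rastrigin_batch` is hypothetical and used only for illustration.

```python
import math
from typing import List, Sequence, Union

Point = Sequence[float]


def rastrigin_batch(x: Union[Point, Sequence[Point]], a: float = 10.0) -> List[float]:
    """Return one Rastrigin energy per row, mimicking forward()'s shape rules."""
    # A flat list of numbers plays the role of a 1-D tensor: wrap it as a batch of one.
    if len(x) > 0 and isinstance(x[0], (int, float)):
        x = [x]
    energies = []
    for row in x:
        n = len(row)  # dimensionality, like x.shape[-1]
        energies.append(a * n + sum(v**2 - a * math.cos(2 * math.pi * v) for v in row))
    return energies


print(rastrigin_batch([0.0, 0.0]))                # single point -> list with one energy
print(rastrigin_batch([[0.0, 0.0], [0.5, 0.5]]))  # batch of two points -> two energies
```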