
RastriginModel

Methods and Attributes

Bases: BaseModel

Energy-based model for the Rastrigin function.
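Reading the energy off `forward`, for an input $\mathbf{x} \in \mathbb{R}^n$ with components $x_i$ the model computes the following. This is a restatement of the code below, not a formula quoted from elsewhere in the torchebm docs.

$$
E(\mathbf{x}) = a\,n + \sum_{i=1}^{n} \bigl( x_i^2 - a\cos(2\pi x_i) \bigr)
$$

Since $\cos(2\pi k) = 1$ for every integer $k$, the energy at an integer lattice point reduces to $\sum_i x_i^2$. The global minimum is therefore $E(\mathbf{0}) = 0$, and the cosine term adds a regular grid of local minima near the integer points.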

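Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| a | float | Amplitude of the cosine term in the Rastrigin function. It also sets the constant offset a * n, where n is the input dimension. | 10.0 |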
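The parameter a controls how rugged the landscape is. The sketch below uses a hypothetical pure-Python helper, `rastrigin_energy`. It mirrors the formula in `forward` but is not part of torchebm. In 1-D, the energy at the midpoint x = 0.5, which sits between the minima near 0 and 1, is 2a + 0.25 because cos(π) = -1. Larger values of a therefore give taller barriers between neighbouring minima.

```python
import math


def rastrigin_energy(x, a=10.0):
    """Rastrigin energy of a single point x (a list of floats).

    Hypothetical pure-Python mirror of RastriginModel.forward,
    written for illustration only; it is not part of torchebm.
    """
    n = len(x)
    return a * n + sum(xi**2 - a * math.cos(2 * math.pi * xi) for xi in x)


# Energy rise from the global minimum at x = 0 to the midpoint x = 0.5.
# Analytically this is 2a + 0.25, so it grows linearly with a.
for a in (1.0, 10.0):
    rise = rastrigin_energy([0.5], a) - rastrigin_energy([0.0], a)
    print(f"a = {a}: energy rise to x = 0.5 is {rise}")
```

At integer points the cosine terms cancel the offset a * n, so the same helper returns the squared norm there. For example, `rastrigin_energy([1.0, -2.0])` gives 5.0.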
Source code in torchebm/core/base_model.py
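import math

import torch

# BaseModel is assumed to be defined in this same module.

class RastriginModel(BaseModel):
    r"""
    Energy-based model for the Rastrigin function.

    For an n-dimensional input x the energy is
    E(x) = a * n + sum_i (x_i**2 - a * cos(2 * pi * x_i)).

    Args:
        a (float): Amplitude of the cosine term; also sets the constant
            offset a * n. Defaults to 10.0.
    """
    def __init__(self, a: float = 10.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.a = a

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Computes the Rastrigin energy.

        Args:
            x: Input of shape (batch_size, n) or (n,).

        Returns:
            Energy of shape (batch_size,); a 1-D input is treated as a
            batch of one and yields shape (1,).
        """
        # Promote a single sample to a batch of one.
        if x.ndim == 1:
            x = x.unsqueeze(0)

        n = x.shape[-1]  # number of input dimensions
        return self.a * n + torch.sum(
            x**2 - self.a * torch.cos(2 * math.pi * x), dim=-1
        )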

a (instance attribute)

The a parameter passed to the constructor, stored unchanged as self.a and read by forward.

forward

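forward(x: torch.Tensor) -> torch.Tensor

Computes the Rastrigin energy of each sample in x. The input may have shape (batch_size, n) or (n,). A 1-D input is treated as a batch of one. Returns a tensor of shape (batch_size,), or (1,) for a 1-D input.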
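Gradient-based samplers such as Langevin dynamics use the gradient of the energy. The sketch below checks the shape behaviour of `forward`. It also compares autograd's gradient with the analytic derivative dE/dx_i = 2 x_i + 2πa sin(2π x_i). The function `rastrigin_energy` is a hypothetical standalone copy of `forward`, so the example runs with only PyTorch installed and does not require torchebm.

```python
import math

import torch


def rastrigin_energy(x: torch.Tensor, a: float = 10.0) -> torch.Tensor:
    """Hypothetical standalone copy of RastriginModel.forward (not torchebm API)."""
    if x.ndim == 1:
        x = x.unsqueeze(0)  # treat a single sample as a batch of one
    n = x.shape[-1]  # number of input dimensions
    return a * n + torch.sum(x**2 - a * torch.cos(2 * math.pi * x), dim=-1)


a = 10.0

# Batched input of shape (batch_size, n) gives one energy per sample.
x = torch.randn(4, 2, requires_grad=True)
energy = rastrigin_energy(x, a)
print(energy.shape)  # torch.Size([4])

# A 1-D input of shape (n,) is promoted to a batch of one.
print(rastrigin_energy(torch.zeros(2), a).shape)  # torch.Size([1])

# Autograd gradient vs. the analytic form 2 x_i + 2 pi a sin(2 pi x_i).
(grad,) = torch.autograd.grad(energy.sum(), x)
analytic = 2 * x.detach() + 2 * math.pi * a * torch.sin(2 * math.pi * x.detach())
print(torch.allclose(grad, analytic, atol=1e-4))
```

Summing the batch energies before calling `torch.autograd.grad` is safe because each sample's energy depends only on its own row of x. The gradient of the sum is therefore the per-sample gradient.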
