
RosenbrockEnergy

Methods and Attributes

Bases: EnergyFunction

Energy function based on the two-dimensional Rosenbrock function.

Parameters:

Name  Type   Description                                       Default
a     float  Parameter a; the minimum lies at (a, a^2).        1.0
b     float  Parameter b; weight of the (x_1 - x_0^2)^2 term.  100.0
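Written out for a point x = (x_0, x_1), the energy computed by forward and the partial derivatives returned by gradient are as follows. Setting both derivatives to zero shows that, for b > 0, the only stationary point is the global minimum E = 0 at (a, a^2), which is (1, 1) for the defaults.

```latex
E(x) = (a - x_0)^2 + b\,(x_1 - x_0^2)^2

\frac{\partial E}{\partial x_0} = -2\,(a - x_0) - 4b\,x_0\,(x_1 - x_0^2),
\qquad
\frac{\partial E}{\partial x_1} = 2b\,(x_1 - x_0^2)

\nabla E = 0 \;\Rightarrow\; x_1 = x_0^2 \text{ and } x_0 = a
\;\Rightarrow\; x^\star = (a, a^2), \quad E(x^\star) = 0
```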

Methods:

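Name      Description
forward   Computes the Rosenbrock energy of each point in x.
gradient  Computes the analytical gradient of the energy with respect to x.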
Source code in torchebm/core/energy_function.py
import torch


class RosenbrockEnergy(EnergyFunction):
    """
    Energy function for the Rosenbrock function.

    Args:
        a (float): Parameter `a` of the Rosenbrock function.
        b (float): Parameter `b` of the Rosenbrock function.
    """

    def __init__(self, a: float = 1.0, b: float = 100.0):
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # E(x) = (a - x_0)^2 + b * (x_1 - x_0^2)^2, evaluated over all leading (batch) dims.
        return (self.a - x[..., 0]) ** 2 + self.b * (x[..., 1] - x[..., 0] ** 2) ** 2

    def gradient(self, x: torch.Tensor) -> torch.Tensor:
        # Analytical gradient; entries beyond the first two coordinates stay zero.
        grad = torch.zeros_like(x)
        # dE/dx_0 = -2 (a - x_0) - 4 b x_0 (x_1 - x_0^2)
        grad[..., 0] = -2 * (self.a - x[..., 0]) - 4 * self.b * x[..., 0] * (
            x[..., 1] - x[..., 0] ** 2
        )
        # dE/dx_1 = 2 b (x_1 - x_0^2)
        grad[..., 1] = 2 * self.b * (x[..., 1] - x[..., 0] ** 2)
        return grad
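As a sanity check of the hand-written gradient, the following sketch compares it with PyTorch autograd. To stay self-contained it does not import torchebm: the hypothetical helpers rosenbrock_energy and rosenbrock_gradient simply restate the formulas from forward and gradient with the default a = 1.0 and b = 100.0, and autograd_gradient is likewise an illustrative name, not part of the library.

```python
import torch


def rosenbrock_energy(x: torch.Tensor, a: float = 1.0, b: float = 100.0) -> torch.Tensor:
    # Hypothetical standalone copy of RosenbrockEnergy.forward.
    return (a - x[..., 0]) ** 2 + b * (x[..., 1] - x[..., 0] ** 2) ** 2


def rosenbrock_gradient(x: torch.Tensor, a: float = 1.0, b: float = 100.0) -> torch.Tensor:
    # Hypothetical standalone copy of RosenbrockEnergy.gradient.
    grad = torch.zeros_like(x)
    grad[..., 0] = -2 * (a - x[..., 0]) - 4 * b * x[..., 0] * (x[..., 1] - x[..., 0] ** 2)
    grad[..., 1] = 2 * b * (x[..., 1] - x[..., 0] ** 2)
    return grad


def autograd_gradient(x: torch.Tensor, a: float = 1.0, b: float = 100.0) -> torch.Tensor:
    # Each sample's energy depends only on its own row, so differentiating the
    # summed energy yields the per-sample gradients in one backward pass.
    x = x.detach().clone().requires_grad_(True)
    (g,) = torch.autograd.grad(rosenbrock_energy(x, a, b).sum(), x)
    return g


torch.manual_seed(0)
x = torch.randn(8, 2, dtype=torch.float64)  # float64 keeps rounding differences tiny
analytic = rosenbrock_gradient(x)
numeric = autograd_gradient(x)
print(torch.allclose(analytic, numeric))  # True
```

Using float64 here is a choice for the check only; the formulas themselves work for any floating dtype.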

a instance-attribute

a = a

b instance-attribute

b = b

forward

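forward(x: torch.Tensor) -> torch.Tensor

Computes the Rosenbrock energy of each point in x and returns a tensor with the shape of x minus its last dimension. Only x[..., 0] and x[..., 1] are read.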
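A minimal sketch of how forward treats its input shape, using a standalone copy of the formula (the helper name rosenbrock_energy is hypothetical, not part of torchebm) with the default a = 1.0 and b = 100.0: any leading dimensions act as batch dimensions, and coordinates after the second are ignored.

```python
import torch


def rosenbrock_energy(x: torch.Tensor, a: float = 1.0, b: float = 100.0) -> torch.Tensor:
    # Hypothetical standalone copy of RosenbrockEnergy.forward; reads only x[..., 0] and x[..., 1].
    return (a - x[..., 0]) ** 2 + b * (x[..., 1] - x[..., 0] ** 2) ** 2


# A batch of four 2-D points: shape (4, 2) -> energies of shape (4,).
points = torch.tensor([[1.0, 1.0], [0.0, 0.0], [-1.0, 1.0], [0.0, 1.0]])
energies = rosenbrock_energy(points)
print(energies.tolist())  # [0.0, 1.0, 4.0, 101.0]

# Any number of leading dimensions is treated as batch: (3, 5, 2) -> (3, 5).
grid = torch.zeros(3, 5, 2)
print(rosenbrock_energy(grid).shape)  # torch.Size([3, 5])

# A third coordinate is silently ignored, so this point still sits at the minimum.
x3 = torch.tensor([1.0, 1.0, 42.0])
print(rosenbrock_energy(x3).item())  # 0.0
```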

gradient

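gradient(x: torch.Tensor) -> torch.Tensor

Returns the analytical gradient of the energy with respect to x, with the same shape as x. Entries for coordinates beyond the first two are zero.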
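One way a caller might consume gradient is plain gradient descent. The sketch below assumes standalone copies of the two formulas (the helpers rosenbrock_energy, rosenbrock_gradient and gradient_descent are hypothetical, not torchebm API) and an illustrative step size of 1e-4; with a step this small relative to the curvature of the valley, the energy of every batch element decreases.

```python
import torch


def rosenbrock_energy(x: torch.Tensor, a: float = 1.0, b: float = 100.0) -> torch.Tensor:
    # Hypothetical standalone copy of RosenbrockEnergy.forward.
    return (a - x[..., 0]) ** 2 + b * (x[..., 1] - x[..., 0] ** 2) ** 2


def rosenbrock_gradient(x: torch.Tensor, a: float = 1.0, b: float = 100.0) -> torch.Tensor:
    # Hypothetical standalone copy of RosenbrockEnergy.gradient.
    grad = torch.zeros_like(x)
    grad[..., 0] = -2 * (a - x[..., 0]) - 4 * b * x[..., 0] * (x[..., 1] - x[..., 0] ** 2)
    grad[..., 1] = 2 * b * (x[..., 1] - x[..., 0] ** 2)
    return grad


def gradient_descent(x0: torch.Tensor, step_size: float = 1e-4, n_steps: int = 1000) -> torch.Tensor:
    # x <- x - step_size * grad E(x), applied to every batch element at once.
    x = x0.clone()
    for _ in range(n_steps):
        x = x - step_size * rosenbrock_gradient(x)
    return x


start = torch.tensor([[0.0, 0.0], [-1.0, 1.0]], dtype=torch.float64)
end = gradient_descent(start)
print(rosenbrock_energy(start))  # starting energies 1 and 4
print(rosenbrock_energy(end))
```

Because the valley is long and flat, plain gradient descent approaches (1, 1) slowly; reaching the minimum would take far more steps than this sketch runs.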