Bases: EnergyFunction
Energy function for the two-dimensional Rosenbrock function, evaluated on the first two coordinates of the input's last dimension.
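Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `a` | `float` | Parameter `a` of the Rosenbrock function; the minimum lies at `(a, a**2)`. | `1.0` |
| `b` | `float` | Parameter `b` of the Rosenbrock function; scales the valley term `(x[..., 1] - x[..., 0]**2)**2`. | `100.0` |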
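As a worked summary of how `a` and `b` act, read directly off the expression in `forward` (writing $x_0$, $x_1$ for `x[..., 0]`, `x[..., 1]`):

$$
E(\mathbf{x}) = (a - x_0)^2 + b\,(x_1 - x_0^2)^2
$$

For $b > 0$ both terms are non-negative and vanish only when $x_0 = a$ and $x_1 = x_0^2 = a^2$, so the global minimum is $E = 0$ at $(a, a^2)$; with the defaults this is $(1, 1)$. Larger $b$ makes the walls of the curved valley $x_1 = x_0^2$ steeper.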
Methods: `forward`, `gradient`.
Source code in torchebm/core/energy_function.py
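```python
import torch

# Requires the EnergyFunction base class (see "Bases" above) to be in scope.


class RosenbrockEnergy(EnergyFunction):
    """
    Energy function for the Rosenbrock function.

    Args:
        a (float): Parameter `a` of the Rosenbrock function.
        b (float): Parameter `b` of the Rosenbrock function.
    """

    def __init__(self, a: float = 1.0, b: float = 100.0):
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # E(x) = (a - x0)^2 + b * (x1 - x0^2)^2 over the last dimension;
        # all leading dimensions are treated as batch dimensions.
        return (self.a - x[..., 0]) ** 2 + self.b * (x[..., 1] - x[..., 0] ** 2) ** 2

    def gradient(self, x: torch.Tensor) -> torch.Tensor:
        # Analytic gradient of forward(); entries beyond the first two
        # coordinates stay zero because the energy does not depend on them.
        grad = torch.zeros_like(x)
        grad[..., 0] = -2 * (self.a - x[..., 0]) - 4 * self.b * x[..., 0] * (
            x[..., 1] - x[..., 0] ** 2
        )
        grad[..., 1] = 2 * self.b * (x[..., 1] - x[..., 0] ** 2)
        return grad
```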
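forward
forward(x: Tensor) -> torch.Tensor
Returns the Rosenbrock energy of each sample in `x`, reducing over the last dimension; only its first two entries are used.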
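A minimal standalone sketch of the computation `forward` performs, assuming only PyTorch; `rosenbrock_energy` is a hypothetical helper written for illustration, not part of torchebm. It shows that leading dimensions of `x` are kept as batch dimensions, that any coordinates after the second are ignored (so this is the 2-D Rosenbrock function, not its N-D generalization), and that the minimum moves to `(a, a**2)` when `a` changes.

```python
import torch


def rosenbrock_energy(x: torch.Tensor, a: float = 1.0, b: float = 100.0) -> torch.Tensor:
    # Hypothetical helper mirroring RosenbrockEnergy.forward: reads only
    # x[..., 0] and x[..., 1]; every leading dimension is a batch dimension.
    return (a - x[..., 0]) ** 2 + b * (x[..., 1] - x[..., 0] ** 2) ** 2


# A batch of three 2-D points, shape (3, 2) -> energies of shape (3,)
points = torch.tensor([[1.0, 1.0], [0.0, 0.0], [-1.0, 1.0]])
energies = rosenbrock_energy(points)
print(energies)  # tensor([0., 1., 4.])

# Leading dimensions are preserved: shape (4, 5, 2) -> (4, 5)
grid = torch.randn(4, 5, 2)
print(rosenbrock_energy(grid).shape)  # torch.Size([4, 5])

# With a = 2 the minimum moves to (a, a**2) = (2, 4)
print(rosenbrock_energy(torch.tensor([2.0, 4.0]), a=2.0))  # tensor(0.)
```

Because the energy is reduced over the last dimension only, a batch of samples of shape `(N, 2)` yields one energy per sample, which is the shape samplers typically expect.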
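gradient
gradient(x: Tensor) -> torch.Tensor
Returns the analytic gradient of the energy with respect to `x`, with the same shape as `x`; entries beyond the first two coordinates are zero.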
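The closed form used by `gradient` is $\partial E/\partial x_0 = -2(a - x_0) - 4b\,x_0\,(x_1 - x_0^2)$ and $\partial E/\partial x_1 = 2b\,(x_1 - x_0^2)$. The sketch below, assuming only PyTorch, checks it against `torch.autograd.grad`; `energy` and `analytic_gradient` are hypothetical stand-ins that copy the expressions from `forward` and `gradient` with the default `a` and `b`.

```python
import torch

A, B = 1.0, 100.0  # the documented defaults for a and b


def energy(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: same expression as RosenbrockEnergy.forward
    return (A - x[..., 0]) ** 2 + B * (x[..., 1] - x[..., 0] ** 2) ** 2


def analytic_gradient(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: same expression as RosenbrockEnergy.gradient
    grad = torch.zeros_like(x)
    grad[..., 0] = -2 * (A - x[..., 0]) - 4 * B * x[..., 0] * (x[..., 1] - x[..., 0] ** 2)
    grad[..., 1] = 2 * B * (x[..., 1] - x[..., 0] ** 2)
    return grad


torch.manual_seed(0)
x = torch.randn(8, 2, dtype=torch.float64, requires_grad=True)

# Summing the per-sample energies makes autograd return one gradient row per
# sample, because each sample's energy depends only on its own row of x.
(autograd_grad,) = torch.autograd.grad(energy(x).sum(), x)

print(torch.allclose(autograd_grad, analytic_gradient(x.detach())))  # True
```

Using double precision keeps rounding differences well inside the default `allclose` tolerances, so any mismatch would point to an error in the closed form rather than to float32 noise.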