Skip to content

LinearScheduler

Methods and Attributes

Bases: BaseScheduler

Scheduler with linear annealing.

Parameters:

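| Name | Type | Description | Default |
|------|------|-------------|---------|
| `start_value` | `float` | Starting parameter value | *required* |
| `end_value` | `float` | Target parameter value | *required* |
| `n_steps` | `int` | Number of steps to reach the final value; must be positive (a `ValueError` is raised otherwise) | *required* |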
Source code in torchebm/core/base_scheduler.py
class LinearScheduler(BaseScheduler):
    """Scheduler with linear annealing.

    The value advances from ``start_value`` towards ``end_value`` by a fixed
    increment of ``(end_value - start_value) / n_steps`` per step. From step
    ``n_steps`` onward it is held at exactly ``end_value``.

    Args:
        start_value: Starting parameter value
        end_value: Target parameter value
        n_steps: Number of steps to reach final value. Must be positive.

    Raises:
        ValueError: If ``n_steps`` is not positive.
    """

    def __init__(self, start_value: float, end_value: float, n_steps: int):
        # NOTE(review): assumes BaseScheduler.__init__ stores start_value and
        # initialises step_count / current_value — confirm in base_scheduler.
        super().__init__(start_value)
        if n_steps <= 0:
            raise ValueError("n_steps must be positive")
        self.end_value = end_value
        self.n_steps = n_steps
        # n_steps > 0 is guaranteed by the check above, so this division is
        # always safe and needs no zero-step fallback.
        self.step_size: float = (end_value - start_value) / n_steps

    def _compute_value(self) -> float:
        """Compute, store, and return the value for the current step.

        While ``step_count < n_steps`` the value is
        ``start_value + step_size * step_count``. Once ``step_count`` reaches
        ``n_steps`` it is clamped to exactly ``end_value``, so the final value
        carries no floating-point error from the multiplication. The result is
        also written to ``self.current_value``.

        Returns:
            The scheduled value for the current ``step_count``.
        """
        if self.step_count >= self.n_steps:
            self.current_value = self.end_value
        else:
            self.current_value = self.start_value + self.step_size * self.step_count
        return self.current_value

end_value instance-attribute

end_value = end_value

n_steps instance-attribute

n_steps = n_steps

step_size instance-attribute

step_size: float = (end_value - start_value) / n_steps

_compute_value

_compute_value() -> float

Update value with linear change.

Source code in torchebm/core/base_scheduler.py
def _compute_value(self) -> float:
    """Return the linearly annealed value for the current step.

    Before ``n_steps`` steps have elapsed the value is
    ``start_value + step_size * step_count``. After that it is pinned to
    ``end_value``. The chosen value is also stored on ``self.current_value``.
    """
    reached_end = self.step_count >= self.n_steps
    annealed = (
        self.end_value
        if reached_end
        else self.start_value + self.step_size * self.step_count
    )
    self.current_value = annealed
    return self.current_value