Skip to content

MultiStepScheduler

Methods and Attributes

Bases: BaseScheduler

Source code in torchebm/core/base_scheduler.py
class MultiStepScheduler(BaseScheduler):
    """Scheduler that multiplies its value by ``gamma`` at each milestone step.

    The current value is ``initial_value * gamma ** k``, where ``k`` is the
    number of milestones ``m`` with ``step_count >= m``.

    Args:
        initial_value: Starting value, forwarded to ``BaseScheduler.__init__``.
        milestones: Positive, strictly increasing step indices at which the
            value decays. Any iterable (list, tuple, generator, ...) is
            accepted; it is copied into a list exactly once.
        gamma: Multiplicative decay factor applied at each milestone.

    Raises:
        ValueError: If any milestone is not ``> 0``, or if the milestones are
            not strictly increasing.
    """

    def __init__(self, initial_value: float, milestones: List[int], gamma: float = 0.1):
        super().__init__(initial_value)
        # Materialize once: both checks below traverse the milestones, so a
        # one-shot iterable (e.g. a generator) would otherwise be consumed by
        # the first check and then fail on len()/indexing in the second.
        milestones = list(milestones)
        if not all(m > 0 for m in milestones):
            raise ValueError("Milestone steps must be positive integers.")
        # Compare each milestone with its successor.
        if not all(a < b for a, b in zip(milestones, milestones[1:])):
            raise ValueError("Milestones must be strictly increasing.")
        # Already validated as strictly increasing, so sorting does not reorder;
        # it yields a list that the scheduler owns independently of the caller.
        self.milestones = sorted(milestones)
        self.gamma = gamma

    def _compute_value(self) -> float:
        """Return the scheduled value for the current step.

        Counts the milestones already reached (``step_count >= m``) and applies
        ``gamma`` once per reached milestone to ``initial_value``.
        ``step_count`` and ``initial_value`` are read from ``self``; they are
        presumably maintained by ``BaseScheduler`` (not visible here).
        """
        power = sum(1 for m in self.milestones if self.step_count >= m)
        return self.initial_value * (self.gamma**power)

milestones instance-attribute

milestones = sorted(milestones)

gamma instance-attribute

gamma = gamma

_compute_value

_compute_value() -> float
Source code in torchebm/core/base_scheduler.py
def _compute_value(self) -> float:
    power = sum(1 for m in self.milestones if self.step_count >= m)
    return self.initial_value * (self.gamma**power)