class MultiStepScheduler(BaseScheduler):
    """Decay a value by ``gamma`` each time the step count reaches a milestone.

    The value is ``initial_value * gamma ** k``, where ``k`` is the number of
    milestones less than or equal to the current step count.

    Args:
        initial_value: Value before any milestone is reached; forwarded to
            ``BaseScheduler.__init__``.
        milestones: Strictly increasing, positive step counts at which the
            value is multiplied by ``gamma``. Any finite iterable is accepted;
            it is copied into a new list, so later changes to the caller's
            sequence do not affect this scheduler.
        gamma: Multiplicative factor applied once per milestone reached.

    Raises:
        ValueError: If any milestone is not positive, or the milestones are
            not strictly increasing.
    """

    def __init__(
        self, initial_value: float, milestones: List[int], gamma: float = 0.1
    ):
        super().__init__(initial_value)
        # Materialize once: validation walks the milestones twice, which a
        # one-shot iterator (e.g. a generator) could not support.
        steps = list(milestones)
        if not all(m > 0 for m in steps):
            raise ValueError("Milestone steps must be positive integers.")
        if not all(a < b for a, b in zip(steps, steps[1:])):
            raise ValueError("Milestones must be strictly increasing.")
        # The strict-increasing check guarantees `steps` is already sorted,
        # so no re-sort is needed; `steps` is our own private copy.
        self.milestones = steps
        self.gamma = gamma

    def _compute_value(self) -> float:
        """Return ``initial_value`` scaled by ``gamma`` once per milestone reached.

        Reads ``self.step_count`` and ``self.initial_value``; these are
        presumably maintained by ``BaseScheduler`` (not visible here).
        """
        from bisect import bisect_right

        # Milestones are sorted, so the number of milestones <= step_count is
        # the right-hand insertion point of step_count: O(log n), no scan.
        power = bisect_right(self.milestones, self.step_count)
        return self.initial_value * (self.gamma**power)