class WarmupScheduler(BaseScheduler):
    """Scheduler wrapper that prepends a linear warmup to another scheduler.

    While ``step_count < warmup_steps`` the value is interpolated linearly
    from ``main_scheduler.initial_value * warmup_init_factor`` up to
    ``main_scheduler.initial_value``. Afterwards the value is whatever the
    wrapped scheduler computes for the number of steps elapsed since the
    warmup ended.

    NOTE(review): relies on ``BaseScheduler`` exposing ``initial_value``,
    ``step_count``, ``current_value``, ``reset()`` and ``_compute_value()``;
    that class is defined outside this block -- confirm against it.
    """

    def __init__(
        self,
        main_scheduler: BaseScheduler,
        warmup_steps: int,
        warmup_init_factor: float = 0.01,
    ):
        """Set up the warmup phase in front of ``main_scheduler``.

        Args:
            main_scheduler: Scheduler that takes over once warmup is done.
                It is reset by this constructor.
            warmup_steps: Number of steps over which to ramp up linearly.
            warmup_init_factor: Fraction of the main scheduler's initial
                value used as the starting point of the ramp.
        """
        # The warmup's own starting value is a scaled-down copy of the main
        # scheduler's initial value; it is handed to the base initializer.
        super().__init__(main_scheduler.initial_value * warmup_init_factor)
        self.main_scheduler = main_scheduler
        self.warmup_steps = warmup_steps
        self.warmup_init_factor = warmup_init_factor
        # Value the ramp must reach at the end of warmup.
        self.target_value = main_scheduler.initial_value
        # Reset main scheduler as warmup controls the initial phase.
        self.main_scheduler.reset()

    def _compute_value(self) -> float:
        """Return the scheduled value for the current ``step_count``.

        Returns:
            The linearly interpolated warmup value while
            ``step_count < warmup_steps``; otherwise the wrapped scheduler's
            value for ``step_count - warmup_steps``.
        """
        if self.step_count < self.warmup_steps:
            # Linear warmup. For non-negative step counts this branch is only
            # reached when warmup_steps > 0, so the division is safe.
            # presumably self.initial_value is set by BaseScheduler.__init__
            # from the argument passed in __init__ -- verify against the base.
            progress = self.step_count / self.warmup_steps
            span = self.target_value - self.initial_value
            return self.initial_value + progress * span

        # After warmup the main scheduler is queried as if it had been
        # running only for the steps taken *after* the warmup phase.
        inner = self.main_scheduler
        saved_step = inner.step_count
        saved_value = inner.current_value
        inner.step_count = self.step_count - self.warmup_steps
        try:
            return inner._compute_value()
        finally:
            # Always put the borrowed state back, even if the wrapped
            # scheduler raises, so it is never left at a shifted step.
            inner.step_count = saved_step
            inner.current_value = saved_value