Skip to content

WarmupScheduler

Methods and Attributes

Bases: BaseScheduler

Source code in torchebm/core/base_scheduler.py
class WarmupScheduler(BaseScheduler):
    """Scheduler that ramps up linearly, then defers to a wrapped scheduler.

    While ``step_count < warmup_steps`` the value is interpolated linearly
    from ``self.initial_value`` towards ``main_scheduler.initial_value``.
    Afterwards the value is whatever ``main_scheduler`` computes for the
    number of steps elapsed since the warmup finished.

    Args:
        main_scheduler: Scheduler that takes over once the warmup is done.
            It is reset by this constructor.
        warmup_steps: Number of steps the linear warmup lasts.
        warmup_init_factor: Fraction of ``main_scheduler.initial_value``
            passed to ``BaseScheduler.__init__`` as the starting value.
    """

    def __init__(
        self,
        main_scheduler: BaseScheduler,
        warmup_steps: int,
        warmup_init_factor: float = 0.01,
    ):
        # Start from a scaled-down version of the main scheduler's initial
        # value (presumably stored by the base class as ``self.initial_value``
        # -- confirm against BaseScheduler).
        super().__init__(main_scheduler.initial_value * warmup_init_factor)
        self.main_scheduler = main_scheduler
        self.warmup_steps = warmup_steps
        self.warmup_init_factor = warmup_init_factor
        # Value the warmup ramps towards: the main scheduler's own start.
        self.target_value = main_scheduler.initial_value

        # The warmup owns the initial phase, so put the main scheduler back
        # into its reset state.
        self.main_scheduler.reset()

    def _compute_value(self) -> float:
        """Return the scheduled value for the current ``step_count``.

        Returns:
            The linearly interpolated warmup value while
            ``step_count < warmup_steps``; otherwise the main scheduler's
            value evaluated at ``step_count - warmup_steps``.
        """
        if self.step_count < self.warmup_steps:
            # Linear warmup from initial_value towards target_value.
            progress = self.step_count / self.warmup_steps
            return self.initial_value + progress * (
                self.target_value - self.initial_value
            )

        # After warmup the main scheduler is evaluated at the number of steps
        # taken *since* the warmup ended.
        main_scheduler_step = self.step_count - self.warmup_steps

        # Temporarily override the main scheduler's state to query its value.
        # The restore happens in ``finally`` so that an exception raised by
        # the main scheduler cannot leave it with a clobbered step counter.
        original_step = self.main_scheduler.step_count
        original_value = self.main_scheduler.current_value
        self.main_scheduler.step_count = main_scheduler_step
        try:
            return self.main_scheduler._compute_value()
        finally:
            self.main_scheduler.step_count = original_step
            self.main_scheduler.current_value = original_value

main_scheduler instance-attribute

main_scheduler = main_scheduler

warmup_steps instance-attribute

warmup_steps = warmup_steps

warmup_init_factor instance-attribute

warmup_init_factor = warmup_init_factor

target_value instance-attribute

target_value = main_scheduler.initial_value

_compute_value

_compute_value() -> float
Source code in torchebm/core/base_scheduler.py
def _compute_value(self) -> float:
    """Return the scheduled value for the current ``step_count``.

    While ``step_count < warmup_steps`` the value is interpolated linearly
    from ``initial_value`` towards ``target_value``. Afterwards the main
    scheduler is evaluated at ``step_count - warmup_steps``.

    Returns:
        The warmup value during warmup, otherwise the main scheduler's value.
    """
    if self.step_count < self.warmup_steps:
        # Linear warmup from initial_value towards target_value.
        progress = self.step_count / self.warmup_steps
        return self.initial_value + progress * (
            self.target_value - self.initial_value
        )

    # After warmup the main scheduler is evaluated at the number of steps
    # taken *since* the warmup ended.
    main_scheduler_step = self.step_count - self.warmup_steps

    # Temporarily override the main scheduler's state to query its value.
    # The restore happens in ``finally`` so that an exception raised by the
    # main scheduler cannot leave it with a clobbered step counter.
    original_step = self.main_scheduler.step_count
    original_value = self.main_scheduler.current_value
    self.main_scheduler.step_count = main_scheduler_step
    try:
        return self.main_scheduler._compute_value()
    finally:
        self.main_scheduler.step_count = original_step
        self.main_scheduler.current_value = original_value