WarmupScheduler

Bases: BaseScheduler

Scheduler that combines linear warmup with another scheduler.

This scheduler implements a two-phase approach: first, it linearly increases the parameter value from a small initial value to the target value over a warmup period, then it follows the schedule defined by the main scheduler. Warmup is commonly used in deep learning to stabilize training in the initial phases.

Mathematical Formula

\[v(t) = \begin{cases} v_{init} + (v_{target} - v_{init}) \times \frac{t}{T_{warmup}}, & \text{if } t < T_{warmup} \\ \text{main\_scheduler}(t - T_{warmup}), & \text{if } t \geq T_{warmup} \end{cases}\]

where:

  • \(v_{init} = v_{target} \times \text{warmup\_init\_factor}\)
  • \(v_{target}\) is the main scheduler's start_value
  • \(T_{warmup}\) is warmup_steps
  • \(t\) is the current step count
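For concreteness, here is the warmup branch evaluated by hand with the values from the first example below (\(v_{target} = 0.1\), warmup_init_factor = 0.01, \(T_{warmup} = 100\)). This is plain arithmetic, not torchebm API:

```python
# Warmup branch of the formula, evaluated directly (illustrative arithmetic only).
v_target = 0.1             # the main scheduler's start_value
warmup_init_factor = 0.01
T_warmup = 100

v_init = v_target * warmup_init_factor  # 0.001

def v(t):
    # Linear interpolation during warmup (valid for t < T_warmup).
    return v_init + (v_target - v_init) * t / T_warmup

print(v(0))   # 0.001   (initial warmup value)
print(v(50))  # 0.0505  (halfway through warmup)
print(v(99))  # 0.09901 (last warmup step; the main scheduler takes over at t = 100)
```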

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `main_scheduler` | `BaseScheduler` | The scheduler to use after warmup. | *required* |
| `warmup_steps` | `int` | Number of warmup steps. | *required* |
| `warmup_init_factor` | `float` | Factor to determine the initial warmup value. | `0.01` |

Learning Rate Warmup + Cosine Annealing

main_scheduler = CosineScheduler(
    start_value=0.1, end_value=0.001, n_steps=1000
)
warmup_scheduler = WarmupScheduler(
    main_scheduler=main_scheduler,
    warmup_steps=100,
    warmup_init_factor=0.01
)

# First 100 steps: linear warmup from 0.001 to 0.1
# Next 1000 steps: cosine annealing from 0.1 to 0.001
for i in range(10):
    value = warmup_scheduler.step()
    print(f"Warmup step {i+1}: {value:.6f}")

MCMC Sampling with Warmup

decay_scheduler = ExponentialDecayScheduler(
    start_value=0.05, decay_rate=0.999, min_value=0.001
)
step_scheduler = WarmupScheduler(
    main_scheduler=decay_scheduler,
    warmup_steps=50,
    warmup_init_factor=0.1
)

sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=step_scheduler
)
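With these settings the Langevin step size starts at 0.005 (0.05 × 0.1), ramps linearly up to 0.05 over the first 50 sampler steps, and then decays exponentially toward the 0.001 floor.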

Noise Scale Warmup

linear_scheduler = LinearScheduler(
    start_value=1.0, end_value=0.01, n_steps=500
)
noise_scheduler = WarmupScheduler(
    main_scheduler=linear_scheduler,
    warmup_steps=25,
    warmup_init_factor=0.05
)
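Because WarmupScheduler exposes the same step() interface as the scheduler it wraps (as in the first example above), the combined curve is easy to inspect by stepping it manually. A minimal sketch; the matplotlib plotting is illustrative and not part of torchebm:

```python
import matplotlib.pyplot as plt  # assumption: used here only for inspection

# Step through warmup (25 steps) plus the full linear schedule (500 steps).
values = [noise_scheduler.step() for _ in range(550)]

plt.plot(values)
plt.xlabel("step")
plt.ylabel("noise scale")
plt.title("Linear warmup followed by linear decay")
plt.show()
```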
Source code in torchebm/core/base_scheduler.py
class WarmupScheduler(BaseScheduler):
    r"""
    Scheduler that combines linear warmup with another scheduler.

    This scheduler implements a two-phase approach: first, it linearly increases
    the parameter value from a small initial value to the target value over a
    warmup period, then it follows the schedule defined by the main scheduler.
    Warmup is commonly used in deep learning to stabilize training in the
    initial phases.

    !!! info "Mathematical Formula"
        $$v(t) = \begin{cases}
        v_{init} + (v_{target} - v_{init}) \times \frac{t}{T_{warmup}}, & \text{if } t < T_{warmup} \\
        \text{main\_scheduler}(t - T_{warmup}), & \text{if } t \geq T_{warmup}
        \end{cases}$$

        where:

        - \(v_{init} = v_{target} \times \text{warmup\_init\_factor}\)
        - \(v_{target}\) is the main scheduler's start_value
        - \(T_{warmup}\) is warmup_steps
        - \(t\) is the current step count

    Args:
        main_scheduler (BaseScheduler): The scheduler to use after warmup.
        warmup_steps (int): Number of warmup steps.
        warmup_init_factor (float, optional): Factor to determine initial warmup value.
            Defaults to 0.01.

    !!! example "Learning Rate Warmup + Cosine Annealing"
        ```python
        main_scheduler = CosineScheduler(
            start_value=0.1, end_value=0.001, n_steps=1000
        )
        warmup_scheduler = WarmupScheduler(
            main_scheduler=main_scheduler,
            warmup_steps=100,
            warmup_init_factor=0.01
        )

        # First 100 steps: linear warmup from 0.001 to 0.1
        # Next 1000 steps: cosine annealing from 0.1 to 0.001
        for i in range(10):
            value = warmup_scheduler.step()
            print(f"Warmup step {i+1}: {value:.6f}")
        ```

    !!! tip "MCMC Sampling with Warmup"
        ```python
        decay_scheduler = ExponentialDecayScheduler(
            start_value=0.05, decay_rate=0.999, min_value=0.001
        )
        step_scheduler = WarmupScheduler(
            main_scheduler=decay_scheduler,
            warmup_steps=50,
            warmup_init_factor=0.1
        )

        sampler = LangevinDynamics(
            energy_function=energy_fn,
            step_size=step_scheduler
        )
        ```

    !!! example "Noise Scale Warmup"
        ```python
        linear_scheduler = LinearScheduler(
            start_value=1.0, end_value=0.01, n_steps=500
        )
        noise_scheduler = WarmupScheduler(
            main_scheduler=linear_scheduler,
            warmup_steps=25,
            warmup_init_factor=0.05
        )
        ```
    """

    def __init__(
        self,
        main_scheduler: BaseScheduler,
        warmup_steps: int,
        warmup_init_factor: float = 0.01,
    ):
        r"""
        Initialize the warmup scheduler.

        Args:
            main_scheduler (BaseScheduler): The scheduler to use after warmup.
            warmup_steps (int): Number of warmup steps.
            warmup_init_factor (float, optional): Factor to determine initial warmup value.
                The initial value will be main_scheduler.start_value * warmup_init_factor.
                Defaults to 0.01.
        """
        # Initialize based on the main scheduler's initial value
        super().__init__(main_scheduler.start_value * warmup_init_factor)
        self.main_scheduler = main_scheduler
        self.warmup_steps = warmup_steps
        self.warmup_init_factor = warmup_init_factor
        self.target_value = main_scheduler.start_value  # Store the target after warmup

        # Reset main scheduler as warmup controls the initial phase
        self.main_scheduler.reset()

    def _compute_value(self) -> float:
        r"""
        Compute the value based on warmup phase or main scheduler.

        Returns:
            float: The parameter value from warmup or main scheduler.
        """
        if self.step_count < self.warmup_steps:
            # Linear warmup phase
            progress = self.step_count / self.warmup_steps
            return self.start_value + progress * (self.target_value - self.start_value)
        else:
            # Main scheduler phase
            # We need its value based on steps *after* warmup
            main_scheduler_step = self.step_count - self.warmup_steps
            # Temporarily set main scheduler state, get value, restore state
            original_step = self.main_scheduler.step_count
            original_value = self.main_scheduler.current_value
            self.main_scheduler.step_count = main_scheduler_step
            computed_main_value = self.main_scheduler._compute_value()
            # Restore state
            self.main_scheduler.step_count = original_step
            self.main_scheduler.current_value = original_value
            return computed_main_value
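Note the design choice in the main-scheduler phase above: rather than stepping main_scheduler itself, _compute_value temporarily overrides its step_count, reads the value, and restores the original state, so the wrapped scheduler's own bookkeeping is never advanced by the wrapper.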

main_scheduler instance-attribute

main_scheduler = main_scheduler

warmup_steps instance-attribute

warmup_steps = warmup_steps

warmup_init_factor instance-attribute

warmup_init_factor = warmup_init_factor

target_value instance-attribute

target_value = main_scheduler.start_value

_compute_value

_compute_value() -> float

Compute the value based on warmup phase or main scheduler.

Returns:

| Type | Description |
| --- | --- |
| `float` | The parameter value from warmup or main scheduler. |

Source code in torchebm/core/base_scheduler.py
def _compute_value(self) -> float:
    r"""
    Compute the value based on warmup phase or main scheduler.

    Returns:
        float: The parameter value from warmup or main scheduler.
    """
    if self.step_count < self.warmup_steps:
        # Linear warmup phase
        progress = self.step_count / self.warmup_steps
        return self.start_value + progress * (self.target_value - self.start_value)
    else:
        # Main scheduler phase
        # We need its value based on steps *after* warmup
        main_scheduler_step = self.step_count - self.warmup_steps
        # Temporarily set main scheduler state, get value, restore state
        original_step = self.main_scheduler.step_count
        original_value = self.main_scheduler.current_value
        self.main_scheduler.step_count = main_scheduler_step
        computed_main_value = self.main_scheduler._compute_value()
        # Restore state
        self.main_scheduler.step_count = original_step
        self.main_scheduler.current_value = original_value
        return computed_main_value