
BaseScheduler

Methods and Attributes

Bases: ABC

Abstract base class for parameter schedulers.

This class provides the foundation for all parameter scheduling strategies in TorchEBM. Schedulers are used to dynamically adjust parameters such as step sizes, noise scales, learning rates, and other hyperparameters during training or sampling processes.

The scheduler maintains an internal step counter and computes parameter values based on the current step. Subclasses must implement the _compute_value method to define the specific scheduling strategy.

Mathematical Foundation

A scheduler defines a function \(f: \mathbb{N} \to \mathbb{R}\) that maps step numbers to parameter values:

\[v(t) = f(t)\]

where \(t\) is the current step count and \(v(t)\) is the parameter value at step \(t\).
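As a concrete instance of \(v(t) = f(t)\), exponential decay uses \(f(t) = v_0 \gamma^t\) for a start value \(v_0\) and decay rate \(\gamma\). The sketch below is illustrative only; TorchEBM's built-in schedulers implement this via their own `_compute_value` overrides.

```python
def exponential_decay(v0: float, gamma: float, t: int) -> float:
    """Parameter value after t steps of exponential decay: v(t) = v0 * gamma**t."""
    return v0 * gamma ** t

# v(0) is the start value; each step multiplies by gamma.
values = [exponential_decay(1.0, 0.9, t) for t in range(3)]
print(values)  # approximately [1.0, 0.9, 0.81]
```

Any schedule expressible as a closed-form function of the step count fits this pattern.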

Parameters:

| Name        | Type  | Description                        | Default  |
| ----------- | ----- | ---------------------------------- | -------- |
| start_value | float | Initial parameter value at step 0. | required |

Creating a Custom Scheduler

class CustomScheduler(BaseScheduler):
    def __init__(self, start_value: float, factor: float):
        super().__init__(start_value)
        self.factor = factor

    def _compute_value(self) -> float:
        return self.start_value * (self.factor ** self.step_count)

scheduler = CustomScheduler(start_value=1.0, factor=0.9)
for i in range(5):
    value = scheduler.step()
    print(f"Step {i+1}: {value:.4f}")

State Management

scheduler = ExponentialDecayScheduler(start_value=0.1, decay_rate=0.95)
# Take some steps
for _ in range(10):
    scheduler.step()

# Save state
state = scheduler.state_dict()

# Reset and restore
scheduler.reset()
scheduler.load_state_dict(state)
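Because `state_dict()` returns the scheduler's instance `__dict__`, the base-class state holds only plain floats and an int. If a subclass likewise stores only JSON-friendly values, the state can be checkpointed as text, as in this sketch (the dictionary values here are hypothetical; `torch.save` works just as well for states containing tensors):

```python
import json

# A state dict of the kind state_dict() would return for the base class:
# plain floats and an int (values shown are illustrative).
state = {"start_value": 0.1, "current_value": 0.0599, "step_count": 10}

# Round-trip through JSON text, e.g. for a lightweight checkpoint file.
serialized = json.dumps(state)
restored = json.loads(serialized)

assert restored == state  # ready to pass to load_state_dict(restored)
```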
Source code in torchebm/core/base_scheduler.py
class BaseScheduler(ABC):
    r"""
    Abstract base class for parameter schedulers.

    This class provides the foundation for all parameter scheduling strategies in TorchEBM.
    Schedulers are used to dynamically adjust parameters such as step sizes, noise scales,
    learning rates, and other hyperparameters during training or sampling processes.

    The scheduler maintains an internal step counter and computes parameter values based
    on the current step. Subclasses must implement the `_compute_value` method to define
    the specific scheduling strategy.

    !!! info "Mathematical Foundation"
        A scheduler defines a function \(f: \mathbb{N} \to \mathbb{R}\) that maps step numbers to parameter values:

        $$v(t) = f(t)$$

        where \(t\) is the current step count and \(v(t)\) is the parameter value at step \(t\).

    Args:
        start_value (float): Initial parameter value at step 0.

    Attributes:
        start_value (float): The initial parameter value.
        current_value (float): The current parameter value.
        step_count (int): Number of steps taken since initialization or last reset.

    !!! example "Creating a Custom Scheduler"
        ```python
        class CustomScheduler(BaseScheduler):
            def __init__(self, start_value: float, factor: float):
                super().__init__(start_value)
                self.factor = factor

            def _compute_value(self) -> float:
                return self.start_value * (self.factor ** self.step_count)

        scheduler = CustomScheduler(start_value=1.0, factor=0.9)
        for i in range(5):
            value = scheduler.step()
            print(f"Step {i+1}: {value:.4f}")
        ```

    !!! tip "State Management"
        ```python
        scheduler = ExponentialDecayScheduler(start_value=0.1, decay_rate=0.95)
        # Take some steps
        for _ in range(10):
            scheduler.step()

        # Save state
        state = scheduler.state_dict()

        # Reset and restore
        scheduler.reset()
        scheduler.load_state_dict(state)
        ```
    """

    def __init__(self, start_value: float):
        r"""
        Initialize the base scheduler.

        Args:
            start_value (float): Initial parameter value. Must be a finite number.

        Raises:
            TypeError: If start_value is not a float or int.
        """
        if not isinstance(start_value, (float, int)):
            raise TypeError(
                f"{type(self).__name__} received an invalid start_value of type "
                f"{type(start_value).__name__}. Expected float or int."
            )

        self.start_value = float(start_value)
        self.current_value = self.start_value
        self.step_count = 0

    @abstractmethod
    def _compute_value(self) -> float:
        r"""
        Compute the parameter value for the current step count.

        This method must be implemented by subclasses to define the specific
        scheduling strategy. It should return the parameter value based on
        the current `self.step_count`.

        Returns:
            float: The computed parameter value for the current step.

        !!! warning "Implementation Note"
            This method is called internally by `step()` after incrementing
            the step counter. Subclasses should not call this method directly.
        """
        pass

    def step(self) -> float:
        r"""
        Advance the scheduler by one step and return the new parameter value.

        This method increments the internal step counter and computes the new
        parameter value using the scheduler's strategy. The computed value
        becomes the new current value.

        Returns:
            float: The new parameter value after stepping.

        !!! example "Basic Usage"
            ```python
            scheduler = ExponentialDecayScheduler(start_value=1.0, decay_rate=0.9)
            print(f"Initial: {scheduler.get_value()}")  # 1.0
            print(f"Step 1: {scheduler.step()}")        # 0.9
            print(f"Step 2: {scheduler.step()}")        # 0.81
            ```
        """
        self.step_count += 1
        self.current_value = self._compute_value()
        return self.current_value

    def reset(self) -> None:
        r"""
        Reset the scheduler to its initial state.

        This method resets both the step counter and current value to their
        initial states, effectively restarting the scheduling process.

        !!! example "Reset Example"
            ```python
            scheduler = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=10)
            for _ in range(5):
                scheduler.step()
            print(f"Before reset: step={scheduler.step_count}, value={scheduler.current_value}")
            scheduler.reset()
            print(f"After reset: step={scheduler.step_count}, value={scheduler.current_value}")
            ```
        """
        self.current_value = self.start_value
        self.step_count = 0

    def get_value(self) -> float:
        r"""
        Get the current parameter value without advancing the scheduler.

        This method returns the current parameter value without modifying
        the scheduler's internal state. Use this when you need to query
        the current value without stepping.

        Returns:
            float: The current parameter value.

        !!! example "Query Current Value"
            ```python
            scheduler = ConstantScheduler(start_value=0.5)
            print(scheduler.get_value())  # 0.5
            scheduler.step()
            print(scheduler.get_value())  # 0.5 (still constant)
            ```
        """
        return self.current_value

    def state_dict(self) -> Dict[str, Any]:
        r"""
        Return the state of the scheduler as a dictionary.

        This method returns a dictionary containing all the scheduler's internal
        state, which can be used to save and restore the scheduler's state.

        Returns:
            Dict[str, Any]: Dictionary containing the scheduler's state.

        !!! example "State Management"
            ```python
            scheduler = CosineScheduler(start_value=1.0, end_value=0.0, n_steps=100)
            for _ in range(50):
                scheduler.step()
            state = scheduler.state_dict()
            print(state['step_count'])  # 50
            ```
        """
        return {key: value for key, value in self.__dict__.items()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Load the scheduler's state from a dictionary.

        This method restores the scheduler's internal state from a dictionary
        previously created by `state_dict()`. This is useful for resuming
        training or sampling from a checkpoint.

        Args:
            state_dict (Dict[str, Any]): Dictionary containing the scheduler state.
                Should be an object returned from a call to `state_dict()`.

        !!! example "State Restoration"
            ```python
            scheduler1 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
            for _ in range(25):
                scheduler1.step()
            state = scheduler1.state_dict()

            scheduler2 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
            scheduler2.load_state_dict(state)
            print(scheduler2.step_count)  # 25
            ```
        """
        self.__dict__.update(state_dict)

start_value instance-attribute

start_value = float(start_value)

current_value instance-attribute

current_value = start_value

step_count instance-attribute

step_count = 0

_compute_value abstractmethod

_compute_value() -> float

Compute the parameter value for the current step count.

This method must be implemented by subclasses to define the specific scheduling strategy. It should return the parameter value based on the current self.step_count.

Returns:

| Type  | Description                                        |
| ----- | -------------------------------------------------- |
| float | The computed parameter value for the current step. |

Implementation Note

This method is called internally by step() after incrementing the step counter. Subclasses should not call this method directly.
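The ordering matters: because `step()` increments `step_count` before calling `_compute_value`, the first call to `step()` computes \(f(1)\), not \(f(0)\). The self-contained replica below condenses `BaseScheduler` to show this contract with a hypothetical linear-warmup subclass (not a TorchEBM class):

```python
from abc import ABC, abstractmethod

class MiniScheduler(ABC):
    """Condensed replica of BaseScheduler's step()/_compute_value contract."""

    def __init__(self, start_value: float):
        self.start_value = float(start_value)
        self.current_value = self.start_value
        self.step_count = 0

    @abstractmethod
    def _compute_value(self) -> float: ...

    def step(self) -> float:
        self.step_count += 1                        # increment first ...
        self.current_value = self._compute_value()  # ... then compute f(step_count)
        return self.current_value

class LinearWarmup(MiniScheduler):
    """Ramps linearly from 0 to start_value over n_steps, then holds."""

    def __init__(self, start_value: float, n_steps: int):
        super().__init__(start_value)
        self.n_steps = n_steps

    def _compute_value(self) -> float:
        frac = min(self.step_count / self.n_steps, 1.0)
        return self.start_value * frac

warmup = LinearWarmup(start_value=1.0, n_steps=4)
vals = [warmup.step() for _ in range(5)]
print(vals)  # [0.25, 0.5, 0.75, 1.0, 1.0]
```

Note that the first value is 0.25, not 0.0: `_compute_value` never observes `step_count == 0` through `step()`.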

Source code in torchebm/core/base_scheduler.py
@abstractmethod
def _compute_value(self) -> float:
    r"""
    Compute the parameter value for the current step count.

    This method must be implemented by subclasses to define the specific
    scheduling strategy. It should return the parameter value based on
    the current `self.step_count`.

    Returns:
        float: The computed parameter value for the current step.

    !!! warning "Implementation Note"
        This method is called internally by `step()` after incrementing
        the step counter. Subclasses should not call this method directly.
    """
    pass

step

step() -> float

Advance the scheduler by one step and return the new parameter value.

This method increments the internal step counter and computes the new parameter value using the scheduler's strategy. The computed value becomes the new current value.

Returns:

| Type  | Description                            |
| ----- | -------------------------------------- |
| float | The new parameter value after stepping. |

Basic Usage

scheduler = ExponentialDecayScheduler(start_value=1.0, decay_rate=0.9)
print(f"Initial: {scheduler.get_value()}")  # 1.0
print(f"Step 1: {scheduler.step()}")        # 0.9
print(f"Step 2: {scheduler.step()}")        # 0.81
Source code in torchebm/core/base_scheduler.py
def step(self) -> float:
    r"""
    Advance the scheduler by one step and return the new parameter value.

    This method increments the internal step counter and computes the new
    parameter value using the scheduler's strategy. The computed value
    becomes the new current value.

    Returns:
        float: The new parameter value after stepping.

    !!! example "Basic Usage"
        ```python
        scheduler = ExponentialDecayScheduler(start_value=1.0, decay_rate=0.9)
        print(f"Initial: {scheduler.get_value()}")  # 1.0
        print(f"Step 1: {scheduler.step()}")        # 0.9
        print(f"Step 2: {scheduler.step()}")        # 0.81
        ```
    """
    self.step_count += 1
    self.current_value = self._compute_value()
    return self.current_value

reset

reset() -> None

Reset the scheduler to its initial state.

This method resets both the step counter and current value to their initial states, effectively restarting the scheduling process.

Reset Example

scheduler = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=10)
for _ in range(5):
    scheduler.step()
print(f"Before reset: step={scheduler.step_count}, value={scheduler.current_value}")
scheduler.reset()
print(f"After reset: step={scheduler.step_count}, value={scheduler.current_value}")
Source code in torchebm/core/base_scheduler.py
def reset(self) -> None:
    r"""
    Reset the scheduler to its initial state.

    This method resets both the step counter and current value to their
    initial states, effectively restarting the scheduling process.

    !!! example "Reset Example"
        ```python
        scheduler = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=10)
        for _ in range(5):
            scheduler.step()
        print(f"Before reset: step={scheduler.step_count}, value={scheduler.current_value}")
        scheduler.reset()
        print(f"After reset: step={scheduler.step_count}, value={scheduler.current_value}")
        ```
    """
    self.current_value = self.start_value
    self.step_count = 0

get_value

get_value() -> float

Get the current parameter value without advancing the scheduler.

This method returns the current parameter value without modifying the scheduler's internal state. Use this when you need to query the current value without stepping.

Returns:

| Type  | Description                  |
| ----- | ---------------------------- |
| float | The current parameter value. |

Query Current Value

scheduler = ConstantScheduler(start_value=0.5)
print(scheduler.get_value())  # 0.5
scheduler.step()
print(scheduler.get_value())  # 0.5 (still constant)
Source code in torchebm/core/base_scheduler.py
def get_value(self) -> float:
    r"""
    Get the current parameter value without advancing the scheduler.

    This method returns the current parameter value without modifying
    the scheduler's internal state. Use this when you need to query
    the current value without stepping.

    Returns:
        float: The current parameter value.

    !!! example "Query Current Value"
        ```python
        scheduler = ConstantScheduler(start_value=0.5)
        print(scheduler.get_value())  # 0.5
        scheduler.step()
        print(scheduler.get_value())  # 0.5 (still constant)
        ```
    """
    return self.current_value

state_dict

state_dict() -> Dict[str, Any]

Return the state of the scheduler as a dictionary.

This method returns a dictionary containing all the scheduler's internal state, which can be used to save and restore the scheduler's state.

Returns:

| Type           | Description                                  |
| -------------- | -------------------------------------------- |
| Dict[str, Any] | Dictionary containing the scheduler's state. |

State Management

scheduler = CosineScheduler(start_value=1.0, end_value=0.0, n_steps=100)
for _ in range(50):
    scheduler.step()
state = scheduler.state_dict()
print(state['step_count'])  # 50
Source code in torchebm/core/base_scheduler.py
def state_dict(self) -> Dict[str, Any]:
    r"""
    Return the state of the scheduler as a dictionary.

    This method returns a dictionary containing all the scheduler's internal
    state, which can be used to save and restore the scheduler's state.

    Returns:
        Dict[str, Any]: Dictionary containing the scheduler's state.

    !!! example "State Management"
        ```python
        scheduler = CosineScheduler(start_value=1.0, end_value=0.0, n_steps=100)
        for _ in range(50):
            scheduler.step()
        state = scheduler.state_dict()
        print(state['step_count'])  # 50
        ```
    """
    return {key: value for key, value in self.__dict__.items()}

load_state_dict

load_state_dict(state_dict: Dict[str, Any]) -> None

Load the scheduler's state from a dictionary.

This method restores the scheduler's internal state from a dictionary previously created by state_dict(). This is useful for resuming training or sampling from a checkpoint.

Parameters:

| Name       | Type           | Description                                                                                          | Default  |
| ---------- | -------------- | ---------------------------------------------------------------------------------------------------- | -------- |
| state_dict | Dict[str, Any] | Dictionary containing the scheduler state. Should be an object returned from a call to state_dict(). | required |

State Restoration

scheduler1 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
for _ in range(25):
    scheduler1.step()
state = scheduler1.state_dict()

scheduler2 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
scheduler2.load_state_dict(state)
print(scheduler2.step_count)  # 25
Source code in torchebm/core/base_scheduler.py
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    r"""
    Load the scheduler's state from a dictionary.

    This method restores the scheduler's internal state from a dictionary
    previously created by `state_dict()`. This is useful for resuming
    training or sampling from a checkpoint.

    Args:
        state_dict (Dict[str, Any]): Dictionary containing the scheduler state.
            Should be an object returned from a call to `state_dict()`.

    !!! example "State Restoration"
        ```python
        scheduler1 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
        for _ in range(25):
            scheduler1.step()
        state = scheduler1.state_dict()

        scheduler2 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
        scheduler2.load_state_dict(state)
        print(scheduler2.step_count)  # 25
        ```
    """
    self.__dict__.update(state_dict)