Skip to content

BaseScheduler

Methods and Attributes

Bases: ABC

Base class for parameter schedulers.

Args: `initial_value` — the initial parameter value.

Source code in torchebm/core/base_scheduler.py
class BaseScheduler(ABC):
    """Base class for parameter schedulers.

    A scheduler tracks a scalar parameter across discrete steps. Subclasses
    implement :meth:`_compute_value`, which :meth:`step` calls after
    incrementing ``step_count``.

    Args:
        initial_value: Initial parameter value. Must be an ``int`` or
            ``float``; ``bool`` is rejected even though it subclasses ``int``.

    Raises:
        TypeError: If ``initial_value`` is not an ``int``/``float``, or is a
            ``bool``.
    """

    def __init__(self, initial_value: float):
        # ``bool`` is a subclass of ``int``, so it must be excluded explicitly;
        # otherwise a flag passed by mistake is silently accepted as 1/0.
        if isinstance(initial_value, bool) or not isinstance(
            initial_value, (float, int)
        ):
            raise TypeError(
                f"{type(self).__name__} received an invalid initial_value: "
                f"expected int or float, got {type(initial_value).__name__}"
            )

        self.initial_value = initial_value
        self.current_value = initial_value
        self.step_count = 0

    @abstractmethod
    def _compute_value(self) -> float:
        """Compute the value for the current step count. To be implemented by subclasses."""
        pass

    def step(self) -> float:
        """Advance the scheduler by one step and return the new value.

        ``step_count`` is incremented *before* :meth:`_compute_value` runs, so
        the first call computes the value for step 1.
        """
        self.step_count += 1
        self.current_value = self._compute_value()
        return self.current_value

    def reset(self) -> None:
        """Reset ``current_value`` to ``initial_value`` and ``step_count`` to 0."""
        self.current_value = self.initial_value
        self.step_count = 0

    def get_value(self) -> float:
        """Return the current value without advancing the scheduler."""
        return self.current_value

    def state_dict(self) -> Dict[str, Any]:
        """Return the scheduler state as a :class:`dict`.

        The result is a shallow copy of the instance ``__dict__``. Rebinding
        keys in it does not affect the scheduler, but mutable attribute values
        (if a subclass defines any) are shared, not copied.
        """
        return dict(self.__dict__)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load the scheduler's state.

        Every key in ``state_dict`` is written into the instance ``__dict__``.
        Attributes not present in ``state_dict`` keep their current values.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

initial_value instance-attribute

initial_value = initial_value

current_value instance-attribute

current_value = initial_value

step_count instance-attribute

step_count = 0

_compute_value abstractmethod

_compute_value() -> float

Compute the value for the current step count. To be implemented by subclasses.

Source code in torchebm/core/base_scheduler.py
@abstractmethod
def _compute_value(self) -> float:
    """Compute the value for the current step count. To be implemented by subclasses."""
    pass

step

step() -> float

Advance the scheduler by one step and return the new value.

Source code in torchebm/core/base_scheduler.py
def step(self) -> float:
    """Increment ``step_count``, recompute the value, and return it.

    The counter is bumped before ``_compute_value`` is called, so the first
    call evaluates step 1.
    """
    self.step_count += 1
    computed = self._compute_value()
    self.current_value = computed
    return self.current_value

reset

reset() -> None

Reset scheduler to initial state.

Source code in torchebm/core/base_scheduler.py
def reset(self) -> None:
    """Restore ``current_value`` from ``initial_value`` and zero ``step_count``."""
    # Right-hand side is evaluated first, then targets are assigned left to
    # right: current_value before step_count.
    self.current_value, self.step_count = self.initial_value, 0

get_value

get_value() -> float

Get current value without updating.

Source code in torchebm/core/base_scheduler.py
def get_value(self) -> float:
    """Return the stored ``current_value``; the scheduler is not advanced."""
    value = self.current_value
    return value

state_dict

state_dict() -> Dict[str, Any]

Returns the state of the scheduler as a `dict`.

Source code in torchebm/core/base_scheduler.py
def state_dict(self) -> Dict[str, Any]:
    """Return a new :class:`dict` holding a shallow copy of ``self.__dict__``."""
    return dict(self.__dict__)

load_state_dict

load_state_dict(state_dict: Dict[str, Any]) -> None

Loads the scheduler's state.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state_dict` | `dict` | Scheduler state. Should be an object returned from a call to `state_dict`. | *required* |
Source code in torchebm/core/base_scheduler.py
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Restore the scheduler's state.

    Each entry of ``state_dict`` is written straight into the instance
    ``__dict__``; attributes absent from ``state_dict`` are left unchanged.

    Args:
        state_dict (dict): scheduler state, normally one produced by
            :meth:`state_dict`.
    """
    instance_attrs = self.__dict__
    instance_attrs.update(state_dict)