Bases: ABC
Base class for parameter schedulers.
Args:
initial_value: Initial parameter value (a `float` or `int`; any other type raises `TypeError`).
Source code in torchebm/core/base_scheduler.py
class BaseScheduler(ABC):
    """Base class for parameter schedulers.

    A scheduler tracks a single scalar parameter that changes each time
    :meth:`step` is called. Subclasses define how it changes by
    implementing :meth:`_compute_value`, which :meth:`step` invokes after
    incrementing the internal step counter.

    Args:
        initial_value: Initial parameter value. Must be a ``float`` or
            ``int``.

    Raises:
        TypeError: If ``initial_value`` is not a ``float`` or ``int``.
    """

    def __init__(self, initial_value: float):
        # Note: bool is a subclass of int, so it also passes this check.
        if not isinstance(initial_value, (float, int)):
            raise TypeError(f"{type(self).__name__} received an invalid initial_value")
        # Value restored by reset(); step() never modifies it.
        self.initial_value = initial_value
        # Most recently computed value; returned by get_value().
        self.current_value = initial_value
        # Number of step() calls since construction or the last reset().
        self.step_count = 0

    @abstractmethod
    def _compute_value(self) -> float:
        """Compute the value for the current step count.

        Called by :meth:`step` after ``step_count`` has been incremented, so
        the first call after construction or :meth:`reset` sees
        ``step_count == 1``. To be implemented by subclasses.
        """
        pass

    def step(self) -> float:
        """Advance the scheduler by one step and return the new value."""
        # Increment first so _compute_value sees the post-step count.
        self.step_count += 1
        self.current_value = self._compute_value()
        return self.current_value

    def reset(self) -> None:
        """Reset the current value and step counter to their initial state."""
        self.current_value = self.initial_value
        self.step_count = 0

    def get_value(self) -> float:
        """Return the current value without advancing the scheduler."""
        return self.current_value

    def state_dict(self) -> Dict[str, Any]:
        """Return the scheduler's state as a :class:`dict`.

        The result is a shallow copy of the instance attributes, so adding,
        removing, or rebinding keys in it does not affect the scheduler.
        Mutable attribute values themselves are shared, not copied.
        """
        return dict(self.__dict__)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load the scheduler's state.

        Args:
            state_dict (dict): Scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.

        Note:
            Every key is written onto the instance as an attribute; keys
            are not validated against the scheduler's existing attributes.
        """
        self.__dict__.update(state_dict)
|
initial_value
instance-attribute
initial_value = initial_value
current_value
instance-attribute
current_value = initial_value
step_count
instance-attribute
step_count = 0
_compute_value
abstractmethod
_compute_value() -> float
Compute the value for the current step count. To be implemented by subclasses.
Source code in torchebm/core/base_scheduler.py
@abstractmethod
def _compute_value(self) -> float:
    """Return the parameter value for the current ``step_count``.

    Abstract hook that concrete schedulers must override; this base
    version has no body and returns ``None``.
    """
    ...
|
step
Advance the scheduler by one step and return the new value.
Source code in torchebm/core/base_scheduler.py
def step(self) -> float:
    """Move the scheduler forward one step.

    The step counter is bumped *before* the new value is computed, and the
    freshly computed value is both cached in ``current_value`` and returned.
    """
    self.step_count = self.step_count + 1
    fresh = self._compute_value()
    self.current_value = fresh
    return self.current_value
|
reset
Reset scheduler to initial state.
Source code in torchebm/core/base_scheduler.py
def reset(self) -> None:
    """Return the scheduler to its freshly constructed state.

    Restores ``current_value`` from ``initial_value`` and zeroes
    ``step_count``.
    """
    self.current_value, self.step_count = self.initial_value, 0
|
get_value
Get current value without updating.
Source code in torchebm/core/base_scheduler.py
def get_value(self) -> float:
    """Return ``current_value`` as-is; the scheduler is not advanced."""
    value = self.current_value
    return value
|
state_dict
state_dict() -> Dict[str, Any]
Returns the state of the scheduler as a `dict`.
Source code in torchebm/core/base_scheduler.py
def state_dict(self) -> Dict[str, Any]:
    """Return the scheduler's state as a :class:`dict`.

    The result is a shallow copy of the instance's attribute dictionary:
    adding, removing, or rebinding keys in it does not affect the
    scheduler. Mutable attribute values themselves are shared, not copied.
    """
    return dict(self.__dict__)
|
load_state_dict
load_state_dict(state_dict: Dict[str, Any]) -> None
Loads the scheduler's state.
Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state_dict` | `dict` | Scheduler state. Should be an object returned from a call to `state_dict`. | required |
Source code in torchebm/core/base_scheduler.py
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Load the scheduler's state.

    Args:
        state_dict (dict): Scheduler state, normally the object returned by
            a call to :meth:`state_dict`. Each key is written onto the
            instance as an attribute.
    """
    attributes = vars(self)
    attributes.update(state_dict)
|