Bases: ABC
Abstract base class for parameter schedulers.
This class provides the foundation for all parameter scheduling strategies in TorchEBM.
Schedulers are used to dynamically adjust parameters such as step sizes, noise scales,
learning rates, and other hyperparameters during training or sampling processes.
The scheduler maintains an internal step counter and computes parameter values based
on the current step. Subclasses must implement the `_compute_value` method to define the specific scheduling strategy.
Mathematical Foundation
A scheduler defines a function \(f: \mathbb{N} \to \mathbb{R}\) that maps step numbers to parameter values:
\[v(t) = f(t)\]
where \(t\) is the current step count and \(v(t)\) is the parameter value at step \(t\).
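The mapping \(v(t) = f(t)\) can be made concrete with a minimal sketch. The base class below mirrors the contract described above, and `HalvingScheduler` is a hypothetical subclass (not part of TorchEBM) implementing \(f(t) = \text{start\_value} \cdot 0.5^t\):

```python
from abc import ABC, abstractmethod

# Minimal stand-in for the BaseScheduler contract described above.
class BaseScheduler(ABC):
    def __init__(self, start_value: float):
        self.start_value = float(start_value)
        self.current_value = self.start_value
        self.step_count = 0

    @abstractmethod
    def _compute_value(self) -> float:
        ...

    def step(self) -> float:
        # Increment t, then evaluate v(t) = f(t).
        self.step_count += 1
        self.current_value = self._compute_value()
        return self.current_value

# Hypothetical subclass: f(t) = start_value * 0.5 ** t
class HalvingScheduler(BaseScheduler):
    def _compute_value(self) -> float:
        return self.start_value * 0.5 ** self.step_count

sched = HalvingScheduler(start_value=1.0)
values = [sched.step() for _ in range(3)]
print(values)  # [0.5, 0.25, 0.125]
```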
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `start_value` | `float` | Initial parameter value at step 0. | *required* |
Creating a Custom Scheduler
```python
class CustomScheduler(BaseScheduler):
    def __init__(self, start_value: float, factor: float):
        super().__init__(start_value)
        self.factor = factor

    def _compute_value(self) -> float:
        return self.start_value * (self.factor ** self.step_count)

scheduler = CustomScheduler(start_value=1.0, factor=0.9)
for i in range(5):
    value = scheduler.step()
    print(f"Step {i+1}: {value:.4f}")
```
State Management
```python
scheduler = ExponentialDecayScheduler(start_value=0.1, decay_rate=0.95)

# Take some steps
for _ in range(10):
    scheduler.step()

# Save state
state = scheduler.state_dict()

# Reset and restore
scheduler.reset()
scheduler.load_state_dict(state)
```
Source code in torchebm/core/base_scheduler.py

````python
class BaseScheduler(ABC):
    r"""
    Abstract base class for parameter schedulers.

    This class provides the foundation for all parameter scheduling strategies in TorchEBM.
    Schedulers are used to dynamically adjust parameters such as step sizes, noise scales,
    learning rates, and other hyperparameters during training or sampling processes.

    The scheduler maintains an internal step counter and computes parameter values based
    on the current step. Subclasses must implement the `_compute_value` method to define
    the specific scheduling strategy.

    !!! info "Mathematical Foundation"
        A scheduler defines a function \(f: \mathbb{N} \to \mathbb{R}\) that maps step numbers to parameter values:

        $$v(t) = f(t)$$

        where \(t\) is the current step count and \(v(t)\) is the parameter value at step \(t\).

    Args:
        start_value (float): Initial parameter value at step 0.

    Attributes:
        start_value (float): The initial parameter value.
        current_value (float): The current parameter value.
        step_count (int): Number of steps taken since initialization or last reset.

    !!! example "Creating a Custom Scheduler"
        ```python
        class CustomScheduler(BaseScheduler):
            def __init__(self, start_value: float, factor: float):
                super().__init__(start_value)
                self.factor = factor

            def _compute_value(self) -> float:
                return self.start_value * (self.factor ** self.step_count)

        scheduler = CustomScheduler(start_value=1.0, factor=0.9)
        for i in range(5):
            value = scheduler.step()
            print(f"Step {i+1}: {value:.4f}")
        ```

    !!! tip "State Management"
        ```python
        scheduler = ExponentialDecayScheduler(start_value=0.1, decay_rate=0.95)

        # Take some steps
        for _ in range(10):
            scheduler.step()

        # Save state
        state = scheduler.state_dict()

        # Reset and restore
        scheduler.reset()
        scheduler.load_state_dict(state)
        ```
    """

    def __init__(self, start_value: float):
        r"""
        Initialize the base scheduler.

        Args:
            start_value (float): Initial parameter value. Must be a finite number.

        Raises:
            TypeError: If start_value is not a float or int.
        """
        if not isinstance(start_value, (float, int)):
            raise TypeError(
                f"{type(self).__name__} received an invalid start_value of type "
                f"{type(start_value).__name__}. Expected float or int."
            )
        self.start_value = float(start_value)
        self.current_value = self.start_value
        self.step_count = 0

    @abstractmethod
    def _compute_value(self) -> float:
        r"""
        Compute the parameter value for the current step count.

        This method must be implemented by subclasses to define the specific
        scheduling strategy. It should return the parameter value based on
        the current `self.step_count`.

        Returns:
            float: The computed parameter value for the current step.

        !!! warning "Implementation Note"
            This method is called internally by `step()` after incrementing
            the step counter. Subclasses should not call this method directly.
        """
        pass

    def step(self) -> float:
        r"""
        Advance the scheduler by one step and return the new parameter value.

        This method increments the internal step counter and computes the new
        parameter value using the scheduler's strategy. The computed value
        becomes the new current value.

        Returns:
            float: The new parameter value after stepping.

        !!! example "Basic Usage"
            ```python
            scheduler = ExponentialDecayScheduler(start_value=1.0, decay_rate=0.9)
            print(f"Initial: {scheduler.get_value()}")  # 1.0
            print(f"Step 1: {scheduler.step()}")  # 0.9
            print(f"Step 2: {scheduler.step()}")  # 0.81
            ```
        """
        self.step_count += 1
        self.current_value = self._compute_value()
        return self.current_value

    def reset(self) -> None:
        r"""
        Reset the scheduler to its initial state.

        This method resets both the step counter and current value to their
        initial states, effectively restarting the scheduling process.

        !!! example "Reset Example"
            ```python
            scheduler = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=10)
            for _ in range(5):
                scheduler.step()
            print(f"Before reset: step={scheduler.step_count}, value={scheduler.current_value}")
            scheduler.reset()
            print(f"After reset: step={scheduler.step_count}, value={scheduler.current_value}")
            ```
        """
        self.current_value = self.start_value
        self.step_count = 0

    def get_value(self) -> float:
        r"""
        Get the current parameter value without advancing the scheduler.

        This method returns the current parameter value without modifying
        the scheduler's internal state. Use this when you need to query
        the current value without stepping.

        Returns:
            float: The current parameter value.

        !!! example "Query Current Value"
            ```python
            scheduler = ConstantScheduler(start_value=0.5)
            print(scheduler.get_value())  # 0.5
            scheduler.step()
            print(scheduler.get_value())  # 0.5 (still constant)
            ```
        """
        return self.current_value

    def state_dict(self) -> Dict[str, Any]:
        r"""
        Return the state of the scheduler as a dictionary.

        This method returns a dictionary containing all the scheduler's internal
        state, which can be used to save and restore the scheduler's state.

        Returns:
            Dict[str, Any]: Dictionary containing the scheduler's state.

        !!! example "State Management"
            ```python
            scheduler = CosineScheduler(start_value=1.0, end_value=0.0, n_steps=100)
            for _ in range(50):
                scheduler.step()
            state = scheduler.state_dict()
            print(state['step_count'])  # 50
            ```
        """
        return {key: value for key, value in self.__dict__.items()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Load the scheduler's state from a dictionary.

        This method restores the scheduler's internal state from a dictionary
        previously created by `state_dict()`. This is useful for resuming
        training or sampling from a checkpoint.

        Args:
            state_dict (Dict[str, Any]): Dictionary containing the scheduler state.
                Should be an object returned from a call to `state_dict()`.

        !!! example "State Restoration"
            ```python
            scheduler1 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
            for _ in range(25):
                scheduler1.step()
            state = scheduler1.state_dict()

            scheduler2 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
            scheduler2.load_state_dict(state)
            print(scheduler2.step_count)  # 25
            ```
        """
        self.__dict__.update(state_dict)
````
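The `isinstance` check in `__init__` above means an invalid start value fails fast at construction time. A quick sketch of that behavior, using a trivial stand-in for the base class and a `ConstantScheduler`-style subclass (the classes here are illustrative, not the library's implementation):

```python
from abc import ABC, abstractmethod

# Stand-in mirroring only the validation logic from BaseScheduler.__init__ above.
class BaseScheduler(ABC):
    def __init__(self, start_value: float):
        if not isinstance(start_value, (float, int)):
            raise TypeError(
                f"{type(self).__name__} received an invalid start_value of type "
                f"{type(start_value).__name__}. Expected float or int."
            )
        self.start_value = float(start_value)
        self.current_value = self.start_value
        self.step_count = 0

    @abstractmethod
    def _compute_value(self) -> float:
        ...

class ConstantScheduler(BaseScheduler):
    def _compute_value(self) -> float:
        return self.start_value

ok = ConstantScheduler(3)              # ints are accepted and coerced to float
print(type(ok.start_value).__name__)   # float

raised = False
try:
    ConstantScheduler("0.1")           # strings are rejected with TypeError
except TypeError:
    raised = True
print(raised)  # True
```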
start_value
instance-attribute
start_value = float(start_value)
current_value
instance-attribute
current_value = start_value
step_count
instance-attribute
step_count = 0
_compute_value
abstractmethod
_compute_value() -> float
Compute the parameter value for the current step count.
This method must be implemented by subclasses to define the specific
scheduling strategy. It should return the parameter value based on
the current `self.step_count`.
Returns:

| Type | Description |
|---|---|
| `float` | The computed parameter value for the current step. |
Implementation Note
This method is called internally by `step()` after incrementing the step counter. Subclasses should not call this method directly.
Source code in torchebm/core/base_scheduler.py
````python
@abstractmethod
def _compute_value(self) -> float:
    r"""
    Compute the parameter value for the current step count.

    This method must be implemented by subclasses to define the specific
    scheduling strategy. It should return the parameter value based on
    the current `self.step_count`.

    Returns:
        float: The computed parameter value for the current step.

    !!! warning "Implementation Note"
        This method is called internally by `step()` after incrementing
        the step counter. Subclasses should not call this method directly.
    """
    pass
````
step
step() -> float
Advance the scheduler by one step and return the new parameter value.
This method increments the internal step counter and computes the new
parameter value using the scheduler's strategy. The computed value
becomes the new current value.
Returns:

| Type | Description |
|---|---|
| `float` | The new parameter value after stepping. |
Basic Usage
```python
scheduler = ExponentialDecayScheduler(start_value=1.0, decay_rate=0.9)
print(f"Initial: {scheduler.get_value()}")  # 1.0
print(f"Step 1: {scheduler.step()}")  # 0.9
print(f"Step 2: {scheduler.step()}")  # 0.81
```
Source code in torchebm/core/base_scheduler.py
````python
def step(self) -> float:
    r"""
    Advance the scheduler by one step and return the new parameter value.

    This method increments the internal step counter and computes the new
    parameter value using the scheduler's strategy. The computed value
    becomes the new current value.

    Returns:
        float: The new parameter value after stepping.

    !!! example "Basic Usage"
        ```python
        scheduler = ExponentialDecayScheduler(start_value=1.0, decay_rate=0.9)
        print(f"Initial: {scheduler.get_value()}")  # 1.0
        print(f"Step 1: {scheduler.step()}")  # 0.9
        print(f"Step 2: {scheduler.step()}")  # 0.81
        ```
    """
    self.step_count += 1
    self.current_value = self._compute_value()
    return self.current_value
````
reset
reset() -> None
Reset the scheduler to its initial state.
This method resets both the step counter and current value to their
initial states, effectively restarting the scheduling process.
Reset Example
```python
scheduler = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=10)
for _ in range(5):
    scheduler.step()
print(f"Before reset: step={scheduler.step_count}, value={scheduler.current_value}")
scheduler.reset()
print(f"After reset: step={scheduler.step_count}, value={scheduler.current_value}")
```
Source code in torchebm/core/base_scheduler.py
````python
def reset(self) -> None:
    r"""
    Reset the scheduler to its initial state.

    This method resets both the step counter and current value to their
    initial states, effectively restarting the scheduling process.

    !!! example "Reset Example"
        ```python
        scheduler = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=10)
        for _ in range(5):
            scheduler.step()
        print(f"Before reset: step={scheduler.step_count}, value={scheduler.current_value}")
        scheduler.reset()
        print(f"After reset: step={scheduler.step_count}, value={scheduler.current_value}")
        ```
    """
    self.current_value = self.start_value
    self.step_count = 0
````
get_value
get_value() -> float
Get the current parameter value without advancing the scheduler.
This method returns the current parameter value without modifying
the scheduler's internal state. Use this when you need to query
the current value without stepping.
Returns:

| Type | Description |
|---|---|
| `float` | The current parameter value. |
Query Current Value
```python
scheduler = ConstantScheduler(start_value=0.5)
print(scheduler.get_value())  # 0.5
scheduler.step()
print(scheduler.get_value())  # 0.5 (still constant)
```
Source code in torchebm/core/base_scheduler.py
````python
def get_value(self) -> float:
    r"""
    Get the current parameter value without advancing the scheduler.

    This method returns the current parameter value without modifying
    the scheduler's internal state. Use this when you need to query
    the current value without stepping.

    Returns:
        float: The current parameter value.

    !!! example "Query Current Value"
        ```python
        scheduler = ConstantScheduler(start_value=0.5)
        print(scheduler.get_value())  # 0.5
        scheduler.step()
        print(scheduler.get_value())  # 0.5 (still constant)
        ```
    """
    return self.current_value
````
state_dict
state_dict() -> Dict[str, Any]
Return the state of the scheduler as a dictionary.
This method returns a dictionary containing all the scheduler's internal
state, which can be used to save and restore the scheduler's state.
Returns:

| Type | Description |
|---|---|
| `Dict[str, Any]` | Dictionary containing the scheduler's state. |
State Management
```python
scheduler = CosineScheduler(start_value=1.0, end_value=0.0, n_steps=100)
for _ in range(50):
    scheduler.step()
state = scheduler.state_dict()
print(state['step_count'])  # 50
```
Source code in torchebm/core/base_scheduler.py
````python
def state_dict(self) -> Dict[str, Any]:
    r"""
    Return the state of the scheduler as a dictionary.

    This method returns a dictionary containing all the scheduler's internal
    state, which can be used to save and restore the scheduler's state.

    Returns:
        Dict[str, Any]: Dictionary containing the scheduler's state.

    !!! example "State Management"
        ```python
        scheduler = CosineScheduler(start_value=1.0, end_value=0.0, n_steps=100)
        for _ in range(50):
            scheduler.step()
        state = scheduler.state_dict()
        print(state['step_count'])  # 50
        ```
    """
    return {key: value for key, value in self.__dict__.items()}
````
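Because `state_dict()` returns a plain dictionary of Python scalars, it can be serialized with any standard tool. Below is a sketch of checkpointing with the stdlib `pickle` module; `DemoScheduler` and the base-class stand-in are hypothetical stand-ins mirroring the source listed above:

```python
import pickle
from abc import ABC, abstractmethod

# Minimal stand-in mirroring the BaseScheduler source listed above.
class BaseScheduler(ABC):
    def __init__(self, start_value: float):
        self.start_value = float(start_value)
        self.current_value = self.start_value
        self.step_count = 0

    @abstractmethod
    def _compute_value(self) -> float:
        ...

    def step(self) -> float:
        self.step_count += 1
        self.current_value = self._compute_value()
        return self.current_value

    def state_dict(self):
        return {key: value for key, value in self.__dict__.items()}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

# Hypothetical schedule: f(t) = start_value / (1 + t)
class DemoScheduler(BaseScheduler):
    def _compute_value(self) -> float:
        return self.start_value / (1 + self.step_count)

sched = DemoScheduler(start_value=1.0)
for _ in range(4):
    sched.step()

blob = pickle.dumps(sched.state_dict())   # serialize state to bytes

restored = DemoScheduler(start_value=1.0)
restored.load_state_dict(pickle.loads(blob))
print(restored.step_count, restored.current_value)  # 4 0.2
```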
load_state_dict
load_state_dict(state_dict: Dict[str, Any]) -> None
Load the scheduler's state from a dictionary.
This method restores the scheduler's internal state from a dictionary
previously created by `state_dict()`. This is useful for resuming training or sampling from a checkpoint.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `state_dict` | `Dict[str, Any]` | Dictionary containing the scheduler state. Should be an object returned from a call to `state_dict()`. | *required* |
State Restoration
```python
scheduler1 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
for _ in range(25):
    scheduler1.step()
state = scheduler1.state_dict()

scheduler2 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
scheduler2.load_state_dict(state)
print(scheduler2.step_count)  # 25
```
Source code in torchebm/core/base_scheduler.py
````python
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    r"""
    Load the scheduler's state from a dictionary.

    This method restores the scheduler's internal state from a dictionary
    previously created by `state_dict()`. This is useful for resuming
    training or sampling from a checkpoint.

    Args:
        state_dict (Dict[str, Any]): Dictionary containing the scheduler state.
            Should be an object returned from a call to `state_dict()`.

    !!! example "State Restoration"
        ```python
        scheduler1 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
        for _ in range(25):
            scheduler1.step()
        state = scheduler1.state_dict()

        scheduler2 = LinearScheduler(start_value=1.0, end_value=0.0, n_steps=100)
        scheduler2.load_state_dict(state)
        print(scheduler2.step_count)  # 25
        ```
    """
    self.__dict__.update(state_dict)
````