BaseSampler

Methods and Attributes

Bases: DeviceMixin, Module, ABC

Abstract base class for samplers.

Parameters:

    model (nn.Module): The model to sample from. For MCMC samplers, this is typically a BaseModel energy function; for learned samplers it may be any nn.Module. Required.
    dtype (torch.dtype): The data type for computations. Default: float32.
    device (Optional[Union[str, torch.device]]): The device for computations. Default: None.
    use_mixed_precision (bool): Whether to use mixed precision for sampling. Default: False.
Source code in torchebm/core/base_sampler.py
class BaseSampler(DeviceMixin, nn.Module, ABC):
    """
    Abstract base class for samplers.

    Args:
        model (nn.Module): The model to sample from. For MCMC samplers, this is
            typically a `BaseModel` energy function; for learned samplers it may be
            any `nn.Module`.
        dtype (torch.dtype): The data type for computations.
        device (Optional[Union[str, torch.device]]): The device for computations.
        use_mixed_precision (bool): Whether to use mixed-precision for sampling.
    """

    def __init__(
        self,
        model: nn.Module,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
        use_mixed_precision: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__(device=device, dtype=dtype, *args, **kwargs)
        self.model = model
        self.dtype = dtype
        # if isinstance(device, str):
        #     device = torch.device(device)
        # self.device = device or torch.device(
        #     "cuda" if torch.cuda.is_available() else "cpu"
        # )
        self.setup_mixed_precision(use_mixed_precision)

        self.schedulers: Dict[str, BaseScheduler] = {}

        # Align child components using the mixin helper
        self.model = DeviceMixin.safe_to(
            self.model, device=self.device, dtype=self.dtype
        )

        # Ensure the energy function has matching precision settings
        if hasattr(self.model, "use_mixed_precision"):
            self.model.use_mixed_precision = self.use_mixed_precision

    @abstractmethod
    def sample(
        self,
        x: Optional[torch.Tensor] = None,
        dim: int = 10,
        n_steps: int = 100,
        n_samples: int = 1,
        thin: int = 1,
        return_trajectory: bool = False,
        return_diagnostics: bool = False,
        *args,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[dict]]]:
        """
        Runs the sampling process.

        Args:
            x (Optional[torch.Tensor]): The initial state to start sampling from.
            dim (int): The dimension of the state space.
            n_steps (int): The number of MCMC steps to perform.
            n_samples (int): The number of samples to generate.
            thin (int): The thinning factor for samples (currently not supported).
            return_trajectory (bool): Whether to return the full trajectory of the samples.
            return_diagnostics (bool): Whether to return diagnostics of the sampling process.

        Returns:
            Union[torch.Tensor, Tuple[torch.Tensor, List[dict]]]:
                - A tensor of samples from the model.
                - If `return_diagnostics` is `True`, a tuple containing the samples
                  and a list of diagnostics dictionaries.
        """
        raise NotImplementedError

    def register_scheduler(self, name: str, scheduler: BaseScheduler) -> None:
        """
        Registers a parameter scheduler.

        Args:
            name (str): The name of the parameter to schedule.
            scheduler (BaseScheduler): The scheduler instance.
        """
        self.schedulers[name] = scheduler

    def get_schedulers(self) -> Dict[str, BaseScheduler]:
        """
        Gets all registered schedulers.

        Returns:
            Dict[str, BaseScheduler]: A dictionary mapping parameter names to their schedulers.
        """
        return self.schedulers

    def get_scheduled_value(self, name: str) -> float:
        """
        Gets the current value for a scheduled parameter.

        Args:
            name (str): The name of the scheduled parameter.

        Returns:
            float: The current value of the parameter.

        Raises:
            KeyError: If no scheduler is registered for the parameter.
        """
        if name not in self.schedulers:
            raise KeyError(f"No scheduler registered for parameter '{name}'")
        return self.schedulers[name].get_value()

    def step_schedulers(self) -> Dict[str, float]:
        """
        Advances all schedulers by one step.

        Returns:
            Dict[str, float]: A dictionary mapping parameter names to their updated values.
        """
        return {name: scheduler.step() for name, scheduler in self.schedulers.items()}

    def reset_schedulers(self) -> None:
        """Resets all schedulers to their initial state."""
        for scheduler in self.schedulers.values():
            scheduler.reset()

    # @abstractmethod
    def _setup_diagnostics(self) -> dict:
        """
        Initialize the diagnostics dictionary.

            .. deprecated:: 1.0
               This method is deprecated and will be removed in a future version.
        """
        return {
            "energies": torch.empty(0, device=self.device, dtype=self.dtype),
            "acceptance_rate": torch.tensor(0.0, device=self.device, dtype=self.dtype),
        }
        # raise NotImplementedError

    # def to(
    #     self, device: Union[str, torch.device], dtype: Optional[torch.dtype] = None
    # ) -> "BaseSampler":
    #     """
    #     Move sampler to the specified device and optionally change its dtype.
    #
    #     Args:
    #         device: Target device for computations
    #         dtype: Optional data type to convert to
    #
    #     Returns:
    #         The sampler instance moved to the specified device/dtype
    #     """
    #     if isinstance(device, str):
    #         device = torch.device(device)
    #
    #     self.device = device
    #
    #     if dtype is not None:
    #         self.dtype = dtype
    #
    #     # Update mixed precision availability if device changed
    #     if self.use_mixed_precision and not self.device.type.startswith("cuda"):
    #         warnings.warn(
    #             f"Mixed precision active but moving to {self.device}. "
    #             f"Mixed precision requires CUDA. Disabling mixed precision.",
    #             UserWarning,
    #         )
    #         self.use_mixed_precision = False
    #
    #     # Move energy function if it has a to method
    #     if hasattr(self.model, "to") and callable(
    #         getattr(self.model, "to")
    #     ):
    #         self.model = self.model.to(
    #             device=self.device, dtype=self.dtype
    #         )
    #
    #     return self

    def apply_mixed_precision(self, func):
        """
        A decorator to apply the mixed precision context to a method.

        Args:
            func: The function to wrap.

        Returns:
            The wrapped function.
        """

        def wrapper(*args, **kwargs):
            with self.autocast_context():
                return func(*args, **kwargs)

        return wrapper

    def to(self, *args, **kwargs):
        """Moves the sampler and its components to the specified device and/or dtype."""
        # Let DeviceMixin update internal state and parent class handle movement
        result = super().to(*args, **kwargs)
        # After move, make sure energy_function follows
        self.model = DeviceMixin.safe_to(
            self.model, device=self.device, dtype=self.dtype
        )
        return result
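
As a rough sketch of the contract, a concrete sampler only needs to implement sample(). The toy random-walk update and the torchebm.core import path below are assumptions for illustration, not an actual TorchEBM sampler:

import torch
import torch.nn as nn
from torchebm.core import BaseSampler  # assumed import path

class RandomWalkSampler(BaseSampler):
    """Toy sampler using greedy random-walk proposals (illustrative only)."""

    def __init__(self, model: nn.Module, step_size: float = 0.1, **kwargs):
        super().__init__(model=model, **kwargs)
        self.step_size = step_size

    def sample(self, x=None, dim=10, n_steps=100, n_samples=1, *args, **kwargs):
        # Start from Gaussian noise when no initial state is provided.
        if x is None:
            x = torch.randn(n_samples, dim, device=self.device, dtype=self.dtype)
        for _ in range(n_steps):
            proposal = x + self.step_size * torch.randn_like(x)
            # Keep proposals that lower the energy (a deliberately crude rule).
            accept = (self.model(proposal) < self.model(x)).view(-1, 1)
            x = torch.where(accept, proposal, x)
        return x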

dtype instance-attribute

dtype = dtype

schedulers instance-attribute

schedulers: Dict[str, BaseScheduler] = {}

model instance-attribute

model = safe_to(model, device=device, dtype=dtype)

sample abstractmethod

sample(x: Optional[Tensor] = None, dim: int = 10, n_steps: int = 100, n_samples: int = 1, thin: int = 1, return_trajectory: bool = False, return_diagnostics: bool = False, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, List[dict]]]

Runs the sampling process.

Parameters:

    x (Optional[torch.Tensor]): The initial state to start sampling from. Default: None.
    dim (int): The dimension of the state space. Default: 10.
    n_steps (int): The number of MCMC steps to perform. Default: 100.
    n_samples (int): The number of samples to generate. Default: 1.
    thin (int): The thinning factor for samples (currently not supported). Default: 1.
    return_trajectory (bool): Whether to return the full trajectory of the samples. Default: False.
    return_diagnostics (bool): Whether to return diagnostics of the sampling process. Default: False.

Returns:

    Union[torch.Tensor, Tuple[torch.Tensor, List[dict]]]: A tensor of samples from the model. If return_diagnostics is True, a tuple containing the samples and a list of diagnostics dictionaries.

Source code in torchebm/core/base_sampler.py
@abstractmethod
def sample(
    self,
    x: Optional[torch.Tensor] = None,
    dim: int = 10,
    n_steps: int = 100,
    n_samples: int = 1,
    thin: int = 1,
    return_trajectory: bool = False,
    return_diagnostics: bool = False,
    *args,
    **kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[dict]]]:
    """
    Runs the sampling process.

    Args:
        x (Optional[torch.Tensor]): The initial state to start sampling from.
        dim (int): The dimension of the state space.
        n_steps (int): The number of MCMC steps to perform.
        n_samples (int): The number of samples to generate.
        thin (int): The thinning factor for samples (currently not supported).
        return_trajectory (bool): Whether to return the full trajectory of the samples.
        return_diagnostics (bool): Whether to return diagnostics of the sampling process.

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor, List[dict]]]:
            - A tensor of samples from the model.
            - If `return_diagnostics` is `True`, a tuple containing the samples
              and a list of diagnostics dictionaries.
    """
    raise NotImplementedError
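
Hypothetical call pattern for a concrete subclass; the exact output shape and the diagnostics content are sampler-specific:

samples = sampler.sample(dim=2, n_samples=64, n_steps=500)
samples, diagnostics = sampler.sample(
    dim=2, n_samples=64, n_steps=500, return_diagnostics=True
)
# samples: torch.Tensor, e.g. of shape (64, 2); diagnostics: list of per-step dicts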

register_scheduler

register_scheduler(name: str, scheduler: BaseScheduler) -> None

Registers a parameter scheduler.

Parameters:

    name (str): The name of the parameter to schedule. Required.
    scheduler (BaseScheduler): The scheduler instance. Required.
Source code in torchebm/core/base_sampler.py
def register_scheduler(self, name: str, scheduler: BaseScheduler) -> None:
    """
    Registers a parameter scheduler.

    Args:
        name (str): The name of the parameter to schedule.
        scheduler (BaseScheduler): The scheduler instance.
    """
    self.schedulers[name] = scheduler
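
A short sketch of registering a scheduler; ExponentialDecayScheduler is a hypothetical name used for illustration and may not match a concrete TorchEBM class:

scheduler = ExponentialDecayScheduler(initial_value=0.1, decay_rate=0.99)  # hypothetical scheduler
sampler.register_scheduler("step_size", scheduler)
sampler.get_scheduled_value("step_size")  # returns the scheduler's current value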

get_schedulers

get_schedulers() -> Dict[str, BaseScheduler]

Gets all registered schedulers.

Returns:

    Dict[str, BaseScheduler]: A dictionary mapping parameter names to their schedulers.

Source code in torchebm/core/base_sampler.py
def get_schedulers(self) -> Dict[str, BaseScheduler]:
    """
    Gets all registered schedulers.

    Returns:
        Dict[str, BaseScheduler]: A dictionary mapping parameter names to their schedulers.
    """
    return self.schedulers

get_scheduled_value

get_scheduled_value(name: str) -> float

Gets the current value for a scheduled parameter.

Parameters:

    name (str): The name of the scheduled parameter. Required.

Returns:

    float: The current value of the parameter.

Raises:

    KeyError: If no scheduler is registered for the parameter.

Source code in torchebm/core/base_sampler.py
def get_scheduled_value(self, name: str) -> float:
    """
    Gets the current value for a scheduled parameter.

    Args:
        name (str): The name of the scheduled parameter.

    Returns:
        float: The current value of the parameter.

    Raises:
        KeyError: If no scheduler is registered for the parameter.
    """
    if name not in self.schedulers:
        raise KeyError(f"No scheduler registered for parameter '{name}'")
    return self.schedulers[name].get_value()

step_schedulers

step_schedulers() -> Dict[str, float]

Advances all schedulers by one step.

Returns:

    Dict[str, float]: A dictionary mapping parameter names to their updated values.

Source code in torchebm/core/base_sampler.py
def step_schedulers(self) -> Dict[str, float]:
    """
    Advances all schedulers by one step.

    Returns:
        Dict[str, float]: A dictionary mapping parameter names to their updated values.
    """
    return {name: scheduler.step() for name, scheduler in self.schedulers.items()}
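
A typical loop, sketched under the assumption that a "step_size" scheduler was registered earlier: reset once, then advance every scheduler per iteration and read back the values keyed by parameter name:

sampler.reset_schedulers()
for _ in range(100):
    values = sampler.step_schedulers()  # e.g. {"step_size": 0.099, ...}
    step_size = values["step_size"]     # or sampler.get_scheduled_value("step_size")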

reset_schedulers

reset_schedulers() -> None

Resets all schedulers to their initial state.

Source code in torchebm/core/base_sampler.py
def reset_schedulers(self) -> None:
    """Resets all schedulers to their initial state."""
    for scheduler in self.schedulers.values():
        scheduler.reset()

_setup_diagnostics

_setup_diagnostics() -> dict

Initialize the diagnostics dictionary.

Deprecated since version 1.0: This method is deprecated and will be removed in a future version.
Source code in torchebm/core/base_sampler.py
def _setup_diagnostics(self) -> dict:
    """
    Initialize the diagnostics dictionary.

        .. deprecated:: 1.0
           This method is deprecated and will be removed in a future version.
    """
    return {
        "energies": torch.empty(0, device=self.device, dtype=self.dtype),
        "acceptance_rate": torch.tensor(0.0, device=self.device, dtype=self.dtype),
    }

apply_mixed_precision

apply_mixed_precision(func)

A decorator to apply the mixed precision context to a method.

Parameters:

    func: The function to wrap. Required.

Returns:

    The wrapped function.

Source code in torchebm/core/base_sampler.py
def apply_mixed_precision(self, func):
    """
    A decorator to apply the mixed precision context to a method.

    Args:
        func: The function to wrap.

    Returns:
        The wrapped function.
    """

    def wrapper(*args, **kwargs):
        with self.autocast_context():
            return func(*args, **kwargs)

    return wrapper
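
A sketch of wrapping a helper so it runs inside the sampler's autocast context (only effective when mixed precision is enabled); energy_step is a hypothetical function:

def energy_step(x):
    return sampler.model(x)  # evaluate energies for a batch of states

energy_step_amp = sampler.apply_mixed_precision(energy_step)
energies = energy_step_amp(torch.randn(8, 2, device=sampler.device))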

to

to(*args, **kwargs)

Moves the sampler and its components to the specified device and/or dtype.

Source code in torchebm/core/base_sampler.py
def to(self, *args, **kwargs):
    """Moves the sampler and its components to the specified device and/or dtype."""
    # Let DeviceMixin update internal state and parent class handle movement
    result = super().to(*args, **kwargs)
    # After move, make sure energy_function follows
    self.model = DeviceMixin.safe_to(
        self.model, device=self.device, dtype=self.dtype
    )
    return result
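
Sketch of device and dtype movement; the wrapped model follows automatically via DeviceMixin.safe_to:

if torch.cuda.is_available():
    sampler = sampler.to("cuda")           # moves the sampler and its model
sampler = sampler.to(dtype=torch.float64)  # standard nn.Module.to dtype handling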