BaseSampler
Methods and Attributes
Bases: ABC
Base class for samplers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
energy_function | BaseEnergyFunction | Energy function to sample from. | required |
dtype | dtype | Data type to use for the computations. | float32 |
device | str | Device to run the computations on (e.g., "cpu" or "cuda"). | None |
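The constructor arguments above are the state every concrete sampler shares, and `Bases: ABC` means the class cannot be instantiated until `sample` is implemented. A minimal sketch of that pattern, using only the standard library: `SketchSampler` and `ConstantSampler` are hypothetical names, the energy function is a plain Python callable, and `dtype` is a string placeholder where torchebm would presumably use a torch dtype. This is an illustration, not torchebm's implementation.

```python
from abc import ABC, abstractmethod
from typing import Callable, Optional


class SketchSampler(ABC):
    """Hypothetical stand-in for BaseSampler (illustration only)."""

    def __init__(
        self,
        energy_function: Callable[[float], float],
        dtype: str = "float32",  # placeholder; real code would take a torch dtype
        device: Optional[str] = None,
    ) -> None:
        # Store the configuration shared by all concrete samplers.
        self.energy_function = energy_function
        self.dtype = dtype
        self.device = device

    @abstractmethod
    def sample(self, *args, **kwargs):
        """Subclasses must implement the sampling loop."""


class ConstantSampler(SketchSampler):
    # Trivial subclass: returns n_samples copies of the starting point.
    def sample(self, x: float = 0.0, n_samples: int = 1, **kwargs):
        return [x] * n_samples


try:
    SketchSampler(lambda x: x * x)
except TypeError as err:
    # The abstract base refuses instantiation because sample() is abstract.
    print("cannot instantiate abstract base:", err)

sampler = ConstantSampler(lambda x: x * x, device="cpu")
print(sampler.sample(x=1.5, n_samples=3))  # [1.5, 1.5, 1.5]
```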
Methods:
Name | Description |
---|---|
sample | Run the sampling process. |
sample_chain | Run the sampling process. |
_setup_diagnostics | Initialize the diagnostics dictionary. |
to | Move sampler to specified device. |
Source code in torchebm/core/base_sampler.py
sample (abstract method)
Run the sampling process.
sample(x: Optional[Tensor] = None, dim: int = 10, n_steps: int = 100, n_samples: int = 1, thin: int = 1, return_trajectory: bool = False, return_diagnostics: bool = False, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, List[dict]]]
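The argument list suggests how a concrete subclass is expected to drive a chain. A minimal sketch under stated assumptions: it uses plain unadjusted Langevin dynamics (a swapped-in example algorithm, not necessarily what any torchebm sampler uses), Python lists instead of tensors, a hypothetical `langevin_sample` function with extra `step_size` and `seed` arguments, and it assumes `thin` keeps every `thin`-th step when `return_trajectory=True`. When `x` is given, its length overrides `n_samples`; `return_diagnostics` is not handled here.

```python
import math
import random
from typing import Callable, List, Optional

Vector = List[float]


def langevin_sample(
    grad_energy: Callable[[Vector], Vector],
    x: Optional[List[Vector]] = None,
    dim: int = 10,
    n_steps: int = 100,
    n_samples: int = 1,
    thin: int = 1,
    return_trajectory: bool = False,
    step_size: float = 0.01,  # hypothetical extra argument
    seed: int = 0,  # hypothetical extra argument
):
    """Hypothetical unadjusted Langevin sampler mirroring sample()'s arguments."""
    rng = random.Random(seed)
    # Start from x if given, otherwise from N(0, I) noise of shape (n_samples, dim).
    if x is None:
        x = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_samples)]
    else:
        x = [list(row) for row in x]
    noise_scale = math.sqrt(2.0 * step_size)
    trajectory: List[List[Vector]] = [[] for _ in x]
    for step in range(1, n_steps + 1):
        for i, xi in enumerate(x):
            g = grad_energy(xi)
            # x <- x - eps * grad E(x) + sqrt(2 * eps) * N(0, I)
            x[i] = [a - step_size * b + noise_scale * rng.gauss(0.0, 1.0)
                    for a, b in zip(xi, g)]
            # Record every `thin`-th state when a trajectory is requested.
            if return_trajectory and step % thin == 0:
                trajectory[i].append(list(x[i]))
    return trajectory if return_trajectory else x


# Standard Gaussian target: E(x) = 0.5 * ||x||^2, so grad E(x) = x.
final = langevin_sample(lambda v: list(v), dim=2, n_steps=50, n_samples=4)
traj = langevin_sample(lambda v: list(v), dim=2, n_steps=50, n_samples=4,
                       thin=10, return_trajectory=True)
print(len(final), len(final[0]))  # 4 2
print(len(traj), len(traj[0]), len(traj[0][0]))  # 4 5 2
```

With `n_steps=50` and `thin=10`, only steps 10, 20, 30, 40 and 50 are kept, so each chain's trajectory holds five states.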
_setup_diagnostics
Initialize the diagnostics dictionary.
Deprecated since version 1.0: this method will be removed in a future version.
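The return type of `sample(..., return_diagnostics=True)` is a tensor paired with a `List[dict]`. One way such a structure could be produced: buffer per-step values in a dictionary, then transpose it into one dict per step. This is a sketch only; the function names and the keys `mean_energy` and `mean_norm` are hypothetical and are not torchebm's documented diagnostics.

```python
from statistics import mean
from typing import Callable, Dict, List

Vector = List[float]


def setup_diagnostics() -> Dict[str, List[float]]:
    """Hypothetical diagnostics buffers: one list of per-step values per key."""
    return {"mean_energy": [], "mean_norm": []}


def record_step(diag: Dict[str, List[float]],
                energy: Callable[[Vector], float],
                batch: List[Vector]) -> None:
    # Append batch-level summaries for the current step.
    diag["mean_energy"].append(mean(energy(x) for x in batch))
    diag["mean_norm"].append(mean(sum(v * v for v in x) ** 0.5 for x in batch))


def to_step_dicts(diag: Dict[str, List[float]]) -> List[dict]:
    # Transpose {key: [per-step values]} into [{key: value}, ...], one dict per step.
    n = len(next(iter(diag.values()), []))
    return [{k: v[i] for k, v in diag.items()} for i in range(n)]


energy = lambda x: 0.5 * sum(v * v for v in x)  # E(x) = 0.5 * ||x||^2
diag = setup_diagnostics()
record_step(diag, energy, [[3.0, 4.0], [0.0, 0.0]])
record_step(diag, energy, [[1.0, 0.0]])
print(to_step_dicts(diag))
# [{'mean_energy': 6.25, 'mean_norm': 2.5}, {'mean_energy': 0.5, 'mean_norm': 1.0}]
```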