torchebm.datasets
CheckerboardDataset
Bases: BaseSyntheticDataset
Generates points in a 2D checkerboard pattern using rejection sampling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | The target number of samples. | 2000 |
| range_limit | float | Defines the square region from which points are sampled. | 4.0 |
| noise | float | The standard deviation of the small Gaussian noise added to the points. | 0.01 |
| device | Optional[Union[str, device]] | The device for the tensor. | None |
| dtype | dtype | The data type for the tensor. | float32 |
| seed | Optional[int] | A random seed for reproducibility. | None |
Source code in torchebm/datasets/generators.py
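The library source is not reproduced on this page. As an independent illustration of rejection sampling for a checkerboard pattern, here is a minimal standard-library sketch. It is not the torchebm implementation: the function name `checkerboard_points`, the sampling square `[-range_limit, range_limit]^2`, the unit cell size, and the parity rule are all assumptions made for the example.

```python
import math
import random


def checkerboard_points(n_samples=2000, range_limit=4.0, noise=0.01, seed=None):
    """Rejection-sample 2D points that fall on alternating unit cells.

    Hypothetical helper for illustration only; not part of torchebm.
    """
    rng = random.Random(seed)
    points = []
    while len(points) < n_samples:
        # Propose a point uniformly in the (assumed) square region.
        x = rng.uniform(-range_limit, range_limit)
        y = rng.uniform(-range_limit, range_limit)
        # Accept the proposal only if its cell indices have an even sum
        # (assumed parity rule); otherwise reject and draw again.
        if (math.floor(x) + math.floor(y)) % 2 == 0:
            # Jitter accepted points with small Gaussian noise.
            points.append((x + rng.gauss(0.0, noise), y + rng.gauss(0.0, noise)))
    return points
```

Because proposals are rejected until enough are accepted, `n_samples` acts as a target that the loop keeps drawing toward, which matches the "target number of samples" wording above.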
CircleDataset
Bases: BaseSyntheticDataset
Generates points sampled uniformly on a circle with noise.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | The number of samples. | 2000 |
| noise | float | The standard deviation of the Gaussian noise to add. | 0.05 |
| radius | float | The radius of the circle. | 1.0 |
| device | Optional[Union[str, device]] | The device for the tensor. | None |
| dtype | dtype | The data type for the tensor. | float32 |
| seed | Optional[int] | A random seed for reproducibility. | None |
Source code in torchebm/datasets/generators.py
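The library source is not reproduced on this page. To make the roles of `radius` and `noise` concrete, the following standard-library sketch draws a uniform angle, places the point on the circle, and adds Gaussian jitter to each coordinate. The function name `circle_points` is hypothetical, and the sketch is an assumption about the general technique rather than the torchebm code.

```python
import math
import random


def circle_points(n_samples=2000, noise=0.05, radius=1.0, seed=None):
    """Sample points uniformly by angle on a circle, then jitter them.

    Hypothetical helper for illustration only; not part of torchebm.
    """
    rng = random.Random(seed)
    points = []
    for _ in range(n_samples):
        theta = rng.uniform(0.0, 2.0 * math.pi)  # uniform angle on the circle
        x = radius * math.cos(theta) + rng.gauss(0.0, noise)
        y = radius * math.sin(theta) + rng.gauss(0.0, noise)
        points.append((x, y))
    return points
```

With `noise=0.0` every point lies exactly on the circle of the given radius; increasing `noise` thickens the ring.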
EightGaussiansDataset
Bases: BaseSyntheticDataset
Generates samples from the '8 Gaussians' mixture distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | The total number of samples. | 2000 |
| std | float | The standard deviation of each component. | 0.02 |
| scale | float | A scaling factor for the centers of the Gaussians. | 2.0 |
| device | Optional[Union[str, device]] | The device for the tensor. | None |
| dtype | dtype | The data type for the tensor. | float32 |
| seed | Optional[int] | A random seed for reproducibility. | None |
Source code in torchebm/datasets/generators.py
GaussianMixtureDataset
Bases: BaseSyntheticDataset
Generates a 2D Gaussian mixture dataset with components arranged in a circle.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | The total number of samples. | 2000 |
| n_components | int | The number of Gaussian components (modes). | 8 |
| std | float | The standard deviation of each Gaussian component. | 0.05 |
| radius | float | The radius of the circle on which the centers lie. | 1.0 |
| device | Optional[Union[str, device]] | The device for the tensor. | None |
| dtype | dtype | The data type for the tensor. | float32 |
| seed | Optional[int] | A random seed for reproducibility. | None |
Source code in torchebm/datasets/generators.py
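The library source is not reproduced on this page. The sketch below illustrates what "components arranged in a circle" can mean: centers are placed at equally spaced angles on a circle of the given `radius`, and each sample is a center plus isotropic Gaussian noise with standard deviation `std`. The names `mixture_centers` and `gaussian_mixture_points`, the equal spacing, and the equal mixture weights are assumptions for illustration, not confirmed details of torchebm.

```python
import math
import random


def mixture_centers(n_components=8, radius=1.0):
    """Return component centers equally spaced on a circle (assumed layout)."""
    return [
        (
            radius * math.cos(2.0 * math.pi * k / n_components),
            radius * math.sin(2.0 * math.pi * k / n_components),
        )
        for k in range(n_components)
    ]


def gaussian_mixture_points(n_samples=2000, n_components=8, std=0.05,
                            radius=1.0, seed=None):
    """Draw samples from an equal-weight 2D Gaussian mixture.

    Hypothetical helper for illustration only; not part of torchebm.
    """
    rng = random.Random(seed)
    centers = mixture_centers(n_components, radius)
    points = []
    for _ in range(n_samples):
        cx, cy = rng.choice(centers)  # pick a mode with equal probability
        points.append((cx + rng.gauss(0.0, std), cy + rng.gauss(0.0, std)))
    return points
```

A small `std` relative to `radius` keeps the modes well separated, which is the usual reason such mixtures are used to check whether a sampler covers every mode.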
GridDataset
Bases: BaseSyntheticDataset
Generates points on a 2D grid.
Note: The total number of samples will be n_samples_per_dim ** 2.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples_per_dim | int | The number of points along each dimension. | 10 |
| range_limit | float | Defines the square region covered by the grid. | 1.0 |
| noise | float | The standard deviation of the Gaussian noise to add. | 0.01 |
| device | Optional[Union[str, device]] | The device for the tensor. | None |
| dtype | dtype | The data type for the tensor. | float32 |
| seed | Optional[int] | A random seed for reproducibility (primarily affects noise). | None |
Source code in torchebm/datasets/generators.py
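The library source is not reproduced on this page. The note above says the sample count is `n_samples_per_dim ** 2`; the standard-library sketch below shows why, by taking the Cartesian product of one axis with itself. The names `linspace` and `grid_points`, and the assumption that the grid spans `[-range_limit, range_limit]` on each axis, are illustrative and not taken from torchebm.

```python
import random


def linspace(start, stop, num):
    """Return `num` evenly spaced values from start to stop, inclusive."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def grid_points(n_samples_per_dim=10, range_limit=1.0, noise=0.01, seed=None):
    """Build a noisy 2D grid with n_samples_per_dim ** 2 points.

    Hypothetical helper for illustration only; not part of torchebm.
    """
    rng = random.Random(seed)
    axis = linspace(-range_limit, range_limit, n_samples_per_dim)  # assumed extent
    # Every x is paired with every y, hence the squared sample count.
    return [
        (x + rng.gauss(0.0, noise), y + rng.gauss(0.0, noise))
        for x in axis
        for y in axis
    ]
```

In this sketch the grid positions are deterministic and only the jitter is random, which is consistent with the remark that the seed primarily affects the noise.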
PinwheelDataset
Bases: BaseSyntheticDataset
Generates the pinwheel dataset with curved blades.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | The total number of samples. | 2000 |
| n_classes | int | The number of 'blades' in the pinwheel. | 5 |
| noise | float | The standard deviation of the final additive Cartesian noise. | 0.05 |
| radial_scale | float | Controls the maximum radius/length of the blades. | 2.0 |
| angular_scale | float | Controls the standard deviation of the angle noise (thickness). | 0.1 |
| spiral_scale | float | Controls the tightness of the spiral. | 5.0 |
| device | Optional[Union[str, device]] | The device for the tensor. | None |
| dtype | dtype | The data type for the tensor. | float32 |
| seed | Optional[int] | A random seed for reproducibility. | None |
Source code in torchebm/datasets/generators.py
SwissRollDataset
Bases: BaseSyntheticDataset
Generates a 2D Swiss roll dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | The number of samples. | 2000 |
| noise | float | The standard deviation of the Gaussian noise to add. | 0.05 |
| arclength | float | A factor controlling how many rolls the spiral has. | 3.0 |
| device | Optional[Union[str, device]] | The device for the tensor. | None |
| dtype | dtype | The data type for the tensor. | float32 |
| seed | Optional[int] | A random seed for reproducibility. | None |
Source code in torchebm/datasets/generators.py
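The library source is not reproduced on this page. As a rough illustration of how a factor like `arclength` can control the number of rolls, the sketch below samples a spiral parameter `t` in `[0, arclength * pi]` and maps it to the point `(t cos t, t sin t)` before adding noise. The exact parameterization used by torchebm is not stated here, so the formula and the function name `swiss_roll_points` are assumptions.

```python
import math
import random


def swiss_roll_points(n_samples=2000, noise=0.05, arclength=3.0, seed=None):
    """Sample a noisy 2D spiral (one possible Swiss roll parameterization).

    Hypothetical helper for illustration only; not part of torchebm.
    """
    rng = random.Random(seed)
    points = []
    for _ in range(n_samples):
        # Assumed form: t ranges over [0, arclength * pi), so a larger
        # arclength lets the spiral wind through more turns.
        t = arclength * math.pi * rng.random()
        x = t * math.cos(t) + rng.gauss(0.0, noise)
        y = t * math.sin(t) + rng.gauss(0.0, noise)
        points.append((x, y))
    return points
```

Under this assumed form the distance from the origin equals `t`, so the spiral's outer radius grows linearly with `arclength`.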
TwoMoonsDataset
Bases: BaseSyntheticDataset
Generates the 'two moons' dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_samples | int | The total number of samples. | 2000 |
| noise | float | The standard deviation of the Gaussian noise to add. | 0.05 |
| device | Optional[Union[str, device]] | The device for the tensor. | None |
| dtype | dtype | The data type for the tensor. | float32 |
| seed | Optional[int] | A random seed for reproducibility. | None |
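
For readers unfamiliar with the 'two moons' shape, the standard-library sketch below builds it from two interleaved half circles, following the geometry used by scikit-learn's `make_moons`. Whether torchebm uses the same offsets, and how it splits `n_samples` between the moons, is an assumption here; the function name `two_moons_points` is hypothetical.

```python
import math
import random


def two_moons_points(n_samples=2000, noise=0.05, seed=None):
    """Sample two interleaved half circles and jitter them.

    Hypothetical helper for illustration only; not part of torchebm.
    """
    rng = random.Random(seed)
    n_outer = n_samples // 2           # first moon
    n_inner = n_samples - n_outer      # second moon takes any odd remainder
    clean = []
    for _ in range(n_outer):
        t = rng.uniform(0.0, math.pi)
        # Upper half of the unit circle centered at the origin.
        clean.append((math.cos(t), math.sin(t)))
    for _ in range(n_inner):
        t = rng.uniform(0.0, math.pi)
        # Lower half of a unit circle centered at (1, 0.5).
        clean.append((1.0 - math.cos(t), 0.5 - math.sin(t)))
    # Add independent Gaussian noise to each coordinate.
    return [(x + rng.gauss(0.0, noise), y + rng.gauss(0.0, noise)) for x, y in clean]
```

The first half of the returned list belongs to one moon and the remainder to the other, so a caller who needs labels can derive them from the index.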