EightGaussiansDataset
Methods and Attributes
Bases: BaseSyntheticDataset
Generates samples from the '8 Gaussians' mixture.
This is a fixed arrangement of eight Gaussian modes in two dimensions, a toy distribution commonly used in the energy-based modeling literature.
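The helper below is a minimal sketch of how the eight mode centers can be laid out. It assumes the common convention from the literature of eight evenly spaced points on a circle of radius `scale` (four on the axes, four on the diagonals). The function name `eight_gaussian_centers` is hypothetical and is not part of this class's API.

```python
import math


def eight_gaussian_centers(scale: float = 2.0) -> list[tuple[float, float]]:
    """Hypothetical helper: 8 centers evenly spaced on a circle of radius `scale`.

    Assumption: angles 0, pi/4, ..., 7*pi/4, i.e. the four axis points plus
    the four diagonal points, all multiplied by `scale`.
    """
    return [
        (scale * math.cos(k * math.pi / 4), scale * math.sin(k * math.pi / 4))
        for k in range(8)
    ]


# Print the assumed centers for the default scale of 2.0.
for x, y in eight_gaussian_centers(2.0):
    print(f"({x:.3f}, {y:.3f})")
```

Under this assumption, every center lies exactly `scale` away from the origin, so `scale` controls how far apart the modes sit while `std` controls how wide each one is.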
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`n_samples` | `int` | Total number of samples to generate. | `2000` |
`std` | `float` | Standard deviation of each Gaussian component. | `0.02` |
`scale` | `float` | Scaling factor applied to the component centers (often 2). | `2.0` |
`device` | `Optional[Union[str, torch.device]]` | Device on which the data tensor is created. | `None` |
`dtype` | `torch.dtype` | Data type of the data tensor. | `torch.float32` |
`seed` | `Optional[int]` | Random seed for reproducibility. | `None` |
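To show how `n_samples`, `std`, `scale`, and `seed` interact, here is a minimal pure-Python sketch of an 8 Gaussians sampler. It is not this class's implementation: it assumes centers evenly spaced on a circle of radius `scale` and assumes each sample picks a mode uniformly at random; `sample_eight_gaussians` is a hypothetical name. It returns plain tuples rather than a tensor, so `device` and `dtype` are not modeled.

```python
import math
import random


def sample_eight_gaussians(n_samples=2000, std=0.02, scale=2.0, seed=None):
    """Hypothetical sampler mirroring the documented parameters (sketch only)."""
    # A dedicated, seeded generator makes results reproducible for a fixed seed.
    rng = random.Random(seed)
    # Assumption: 8 centers evenly spaced on a circle of radius `scale`.
    centers = [
        (scale * math.cos(k * math.pi / 4), scale * math.sin(k * math.pi / 4))
        for k in range(8)
    ]
    samples = []
    for _ in range(n_samples):
        cx, cy = rng.choice(centers)  # assumption: mode chosen uniformly at random
        # Isotropic Gaussian noise with standard deviation `std` around the center.
        samples.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return samples


a = sample_eight_gaussians(n_samples=1000, seed=0)
b = sample_eight_gaussians(n_samples=1000, seed=0)
print(len(a), a == b)
```

Passing the same `seed` twice yields identical samples, which is the reproducibility the `seed` parameter is meant to provide; leaving it as `None` draws fresh randomness on each call.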