Skip to content

EightGaussiansDataset

Methods and Attributes

Bases: BaseSyntheticDataset

Generates samples from the specific '8 Gaussians' mixture.

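This places 8 Gaussian modes evenly around a circle of radius `scale` centered at the origin (four on the coordinate axes and four on the diagonals), an arrangement commonly used in the energy-based modeling literature.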

Parameters:

Name Type Description Default
n_samples int

Total number of samples. Default: 2000.

2000
std float

Standard deviation of each component. Default: 0.02.

0.02
scale float

Scaling factor for the centers (often 2). Default: 2.0.

2.0
device Optional[Union[str, device]]

Device for the tensor.

None
dtype dtype

Data type for the tensor. Default: torch.float32.

float32
seed Optional[int]

Random seed for reproducibility.

None
Source code in torchebm/datasets/generators.py
class EightGaussiansDataset(BaseSyntheticDataset):
    """
    Synthetic 2D dataset drawn from the classic '8 Gaussians' mixture.

    Eight Gaussian modes are positioned on a circle of radius ``scale`` around
    the origin: four on the coordinate axes and four on the diagonals. This
    arrangement is commonly used in the energy-based modeling literature.
    Samples are divided as evenly as possible between the modes and returned
    in shuffled order.

    Args:
        n_samples (int): Total number of samples. Default: 2000.
        std (float): Standard deviation of each component. Default: 0.02.
        scale (float): Scaling factor for the centers (often 2). Default: 2.0.
        device (Optional[Union[str, torch.device]]): Device for the tensor.
        dtype (torch.dtype): Data type for the tensor. Default: torch.float32.
        seed (Optional[int]): Random seed for reproducibility.
    """

    def __init__(
        self,
        n_samples: int = 2000,
        std: float = 0.02,
        scale: float = 2.0,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        seed: Optional[int] = None,
    ):
        # These attributes are assigned before delegating to the base class.
        # NOTE(review): presumably the base constructor triggers
        # `_generate_data`, which reads them -- confirm against
        # BaseSyntheticDataset.
        self.std = std
        self.scale = scale
        self.n_components = 8

        # Unit-circle positions of the eight modes (axes first, then diagonals).
        diag = 1.0 / np.sqrt(2)
        unit_centers = np.array(
            [
                (1, 0),
                (-1, 0),
                (0, 1),
                (0, -1),
                (diag, diag),
                (diag, -diag),
                (-diag, diag),
                (-diag, -diag),
            ],
            dtype=np.float32,
        )
        centers_np = unit_centers * self.scale
        self.centers_torch = torch.from_numpy(centers_np)

        super().__init__(n_samples=n_samples, device=device, dtype=dtype, seed=seed)

    def _generate_data(self) -> torch.Tensor:
        """
        Draw ``n_samples`` 2D points from the eight-mode mixture.

        Each mode receives ``n_samples // n_components`` points, and the first
        ``n_samples % n_components`` modes receive one additional point. Every
        point is its mode center plus standard normal noise multiplied by
        ``std``. The rows are shuffled before being returned.

        Returns:
            torch.Tensor: Tensor of shape ``(n_samples, 2)`` on ``self.device``
            with dtype ``self.dtype``.
        """
        device, dtype = self.device, self.dtype
        mode_centers = self.centers_torch.to(device=device, dtype=dtype)

        samples = torch.empty(self.n_samples, 2, device=device, dtype=dtype)

        # Even split; the leftover samples go to the lowest-indexed modes.
        base, extra = divmod(self.n_samples, self.n_components)
        counts = [
            base + 1 if k < extra else base for k in range(self.n_components)
        ]

        start = 0
        for k, count in enumerate(counts):
            if count != 0:
                stop = start + count
                jitter = torch.randn(count, 2, device=device, dtype=dtype) * self.std
                samples[start:stop] = mode_centers[k] + jitter
                start = stop

        # Shuffle so the rows are not grouped by mode.
        shuffle = torch.randperm(self.n_samples, device=device)
        return samples[shuffle]

std instance-attribute

std = std

scale instance-attribute

scale = scale

centers_torch instance-attribute

centers_torch = from_numpy(centers_np)

n_components instance-attribute

n_components = 8

_generate_data

_generate_data() -> torch.Tensor
Source code in torchebm/datasets/generators.py
def _generate_data(self) -> torch.Tensor:
    """
    Draw ``n_samples`` 2D points from the fixed-center Gaussian mixture.

    Each of the ``n_components`` modes receives ``n_samples // n_components``
    points, and the first ``n_samples % n_components`` modes receive one
    additional point. Every point is its mode center plus standard normal
    noise multiplied by ``std``. The rows are shuffled before being returned.

    Returns:
        torch.Tensor: Tensor of shape ``(n_samples, 2)`` on ``self.device``
        with dtype ``self.dtype``.
    """
    device, dtype = self.device, self.dtype
    mode_centers = self.centers_torch.to(device=device, dtype=dtype)

    samples = torch.empty(self.n_samples, 2, device=device, dtype=dtype)

    # Even split; the leftover samples go to the lowest-indexed modes.
    base, extra = divmod(self.n_samples, self.n_components)
    counts = [base + 1 if k < extra else base for k in range(self.n_components)]

    start = 0
    for k, count in enumerate(counts):
        if count != 0:
            stop = start + count
            jitter = torch.randn(count, 2, device=device, dtype=dtype) * self.std
            samples[start:stop] = mode_centers[k] + jitter
            start = stop

    # Shuffle so the rows are not grouped by mode.
    shuffle = torch.randperm(self.n_samples, device=device)
    return samples[shuffle]