Skip to content

GaussianMixtureDataset

Methods and Attributes

Bases: BaseSyntheticDataset

Generates samples from a 2D Gaussian mixture whose components are arranged uniformly on a circle.

Creates a mixture of Gaussian distributions with centers equally spaced on a circle. This distribution is useful for testing mode-seeking behavior in energy-based models.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `n_samples` | `int` | Total number of samples to generate. | `2000` |
| `n_components` | `int` | Number of Gaussian components (modes). Must be positive. | `8` |
| `std` | `float` | Standard deviation of each Gaussian component. Must be non-negative. | `0.05` |
| `radius` | `float` | Radius of the circle on which the centers lie. | `1.0` |
| `device` | `Optional[Union[str, torch.device]]` | Device for the tensor. | `None` |
| `dtype` | `torch.dtype` | Data type for the tensor. | `torch.float32` |
| `seed` | `Optional[int]` | Random seed for reproducibility. | `None` |
Source code in torchebm/datasets/generators.py
class GaussianMixtureDataset(BaseSyntheticDataset):
    """
    Generates samples from a 2D Gaussian mixture arranged uniformly on a circle.

    Creates a mixture of Gaussian distributions with centers equally spaced on a circle.
    This distribution is useful for testing mode-seeking behavior in energy-based models.

    Samples are split as evenly as possible across components: every component
    receives ``n_samples // n_components`` points and the first
    ``n_samples % n_components`` components receive one extra. The generated
    points are shuffled so that components are interleaved.

    Args:
        n_samples (int): Total number of samples to generate. Default: 2000.
        n_components (int): Number of Gaussian components (modes). Must be positive.
            Default: 8.
        std (float): Standard deviation of each Gaussian component. Must be
            non-negative. Default: 0.05.
        radius (float): Radius of the circle on which the centers lie. Default: 1.0.
        device (Optional[Union[str, torch.device]]): Device for the tensor.
        dtype (torch.dtype): Data type for the tensor. Default: torch.float32.
        seed (Optional[int]): Random seed for reproducibility.

    Raises:
        ValueError: If ``n_components <= 0`` or ``std < 0``.
    """

    def __init__(
        self,
        n_samples: int = 2000,
        n_components: int = 8,
        std: float = 0.05,
        radius: float = 1.0,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        seed: Optional[int] = None,
    ):
        if n_components <= 0:
            raise ValueError("n_components must be positive")
        if std < 0:
            raise ValueError("std must be non-negative")
        # These are assigned before super().__init__(), presumably because the
        # base class calls _generate_data() during construction — TODO confirm.
        self.n_components = n_components  # number of mixture modes
        self.std = std  # per-component isotropic noise scale
        self.radius = radius  # radius of the circle holding the centers
        super().__init__(n_samples=n_samples, device=device, dtype=dtype, seed=seed)

    def _generate_data(self) -> torch.Tensor:
        """Draw ``n_samples`` 2D points from the circular Gaussian mixture.

        Returns:
            torch.Tensor: Shuffled tensor of shape ``(n_samples, 2)`` on
            ``self.device`` with dtype ``self.dtype``.
        """
        # Centers are computed in float64 and cast once to the target dtype, so
        # a float64 dataset gets float64-accurate centers instead of values
        # pre-rounded to float32. For float32 the resulting values are the same.
        thetas = np.linspace(0.0, 2.0 * np.pi, self.n_components, endpoint=False)
        centers_np = self.radius * np.stack([np.cos(thetas), np.sin(thetas)], axis=1)
        centers = torch.as_tensor(centers_np, device=self.device, dtype=self.dtype)

        data = torch.empty(self.n_samples, 2, device=self.device, dtype=self.dtype)
        base_count, remainder = divmod(self.n_samples, self.n_components)

        start = 0
        for i in range(self.n_components):
            count = base_count + (1 if i < remainder else 0)
            if count == 0:
                continue
            stop = start + count
            # Noise is drawn with one randn call per component, in component
            # order; batching it differently may change the values produced
            # for a given seed.
            noise = torch.randn(count, 2, device=self.device, dtype=self.dtype) * self.std
            data[start:stop] = centers[i] + noise
            start = stop

        # Shuffle so that samples from different components are interleaved.
        perm = torch.randperm(self.n_samples, device=self.device)
        return data[perm]

n_components instance-attribute

n_components = n_components

std instance-attribute

std = std

radius instance-attribute

radius = radius

_generate_data

_generate_data() -> torch.Tensor
Source code in torchebm/datasets/generators.py
def _generate_data(self) -> torch.Tensor:
    """Draw ``n_samples`` 2D points from the circular Gaussian mixture.

    Reads ``n_samples``, ``n_components``, ``std``, ``radius``, ``device`` and
    ``dtype`` from ``self``. Samples are split as evenly as possible across
    components (the first ``n_samples % n_components`` components get one
    extra point), perturbed with Gaussian noise of scale ``std``, and shuffled.

    Returns:
        torch.Tensor: Shuffled tensor of shape ``(n_samples, 2)`` on
        ``self.device`` with dtype ``self.dtype``.
    """
    # Centers are computed in float64 and cast once to the target dtype, so a
    # float64 dataset gets float64-accurate centers instead of values
    # pre-rounded to float32. For float32 the resulting values are the same.
    thetas = np.linspace(0.0, 2.0 * np.pi, self.n_components, endpoint=False)
    centers_np = self.radius * np.stack([np.cos(thetas), np.sin(thetas)], axis=1)
    centers = torch.as_tensor(centers_np, device=self.device, dtype=self.dtype)

    data = torch.empty(self.n_samples, 2, device=self.device, dtype=self.dtype)
    base_count, remainder = divmod(self.n_samples, self.n_components)

    start = 0
    for i in range(self.n_components):
        count = base_count + (1 if i < remainder else 0)
        if count == 0:
            continue
        stop = start + count
        # Noise is drawn with one randn call per component, in component order;
        # batching it differently may change the values produced for a seed.
        noise = torch.randn(count, 2, device=self.device, dtype=self.dtype) * self.std
        data[start:stop] = centers[i] + noise
        start = stop

    # Shuffle so that samples from different components are interleaved.
    perm = torch.randperm(self.n_samples, device=self.device)
    return data[perm]