Skip to content

PinwheelDataset

Methods and Attributes

Bases: BaseSyntheticDataset

Generates the pinwheel dataset with curved blades.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `n_samples` | `int` | Total number of samples. Default: 2000. | `2000` |
| `n_classes` | `int` | Number of 'blades'. Default: 5. | `5` |
| `noise` | `float` | Standard deviation of the final additive Cartesian noise. Default: 0.05. | `0.05` |
| `radial_scale` | `float` | Controls the maximum radius (length) of the blades. Default: 2.0. | `2.0` |
| `angular_scale` | `float` | Controls the standard deviation of the angle noise (blade thickness). Default: 0.1. | `0.1` |
| `spiral_scale` | `float` | Controls spiral tightness. Default: 5.0. | `5.0` |
| `device` | `Optional[Union[str, torch.device]]` | Device for the tensor. | `None` |
| `dtype` | `torch.dtype` | Data type for the tensor. Default: torch.float32. | `torch.float32` |
| `seed` | `Optional[int]` | Random seed for reproducibility. | `None` |
Source code in torchebm/datasets/generators.py
class PinwheelDataset(BaseSyntheticDataset):
    """
    Generates the pinwheel dataset with curved blades.

    Points are split as evenly as possible across ``n_classes`` blades; the
    first ``n_samples % n_classes`` blades receive one extra point. For every
    point a value ``t = sqrt(u)`` with ``u ~ U[0, 1)`` drives both the radius
    (``t * radial_scale``) and the spiral offset of the angle
    (``spiral_scale * t``). Gaussian angle jitter with std ``angular_scale``
    gives each blade its thickness. The points are shuffled and, when
    ``noise > 0``, isotropic Gaussian noise with std ``noise`` is added in
    Cartesian coordinates.

    Args:
        n_samples (int): Total number of samples. Default: 2000.
        n_classes (int): Number of 'blades'. Default: 5.
        noise (float): Std dev of final additive Cartesian noise. Default: 0.05.
        radial_scale (float): Controls max radius/length of blades. Default: 2.0.
        angular_scale (float): Controls std dev of angle noise (thickness). Default: 0.1.
        spiral_scale (float): Controls spiral tightness. Default: 5.0.
        device (Optional[Union[str, torch.device]]): Device for the tensor.
        dtype (torch.dtype): Data type for the tensor. Default: torch.float32.
        seed (Optional[int]): Random seed for reproducibility.

    Raises:
        ValueError: If ``n_classes`` is not positive.
    """

    def __init__(
        self,
        n_samples: int = 2000,
        n_classes: int = 5,
        noise: float = 0.05,
        radial_scale: float = 2.0,
        angular_scale: float = 0.1,
        spiral_scale: float = 5.0,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        seed: Optional[int] = None,
    ):
        if n_classes <= 0:
            raise ValueError("n_classes must be positive")
        self.n_classes = n_classes
        self.noise = noise
        self.radial_scale = radial_scale
        self.angular_scale = angular_scale
        self.spiral_scale = spiral_scale
        # Subclass attributes are set before the base initializer runs;
        # NOTE(review): presumably the base class may call _generate_data()
        # from its __init__, which needs these — confirm.
        super().__init__(n_samples=n_samples, device=device, dtype=dtype, seed=seed)

    def _generate_data(self) -> torch.Tensor:
        """Sample the 2-D pinwheel points.

        Returns:
            torch.Tensor: Float32 CPU tensor of shape ``(n_samples, 2)`` in
            shuffled order. When ``n_samples`` is 0 an empty ``(0, 2)``
            tensor is returned instead of letting ``np.concatenate`` fail on
            an empty list.

        Note:
            Randomness comes from the global NumPy and torch RNGs
            (``np.random`` and ``torch.randn_like``). Presumably the base
            class seeds them from ``seed`` — confirm.
        """
        # Logic from make_pinwheel
        all_points_np = []
        samples_per_class = self.n_samples // self.n_classes
        remainder = self.n_samples % self.n_classes

        for class_idx in range(self.n_classes):
            # Spread the remainder over the first blades so counts differ by at most one.
            n_class_samples = samples_per_class + (1 if class_idx < remainder else 0)
            if n_class_samples == 0:
                continue

            # sqrt of a uniform draw pushes mass outward along each blade.
            t = np.sqrt(np.random.rand(n_class_samples))  # Radial density control
            radii = t * self.radial_scale
            base_angle = class_idx * (2 * np.pi / self.n_classes)
            # The angle grows with the radius parameter, which curves the blade.
            spiral_angle = self.spiral_scale * t
            angle_noise = np.random.randn(n_class_samples) * self.angular_scale
            thetas = base_angle + spiral_angle + angle_noise

            x = radii * np.cos(thetas)
            y = radii * np.sin(thetas)
            all_points_np.append(np.stack([x, y], axis=1))

        if not all_points_np:
            # Only reachable when n_samples == 0; np.concatenate rejects an empty list.
            return torch.empty((0, 2), dtype=torch.float32)

        data_np = np.concatenate(all_points_np, axis=0).astype(np.float32)
        np.random.shuffle(data_np)  # Shuffle before converting to tensor

        # from_numpy shares memory with data_np; the in-place add below is safe
        # because data_np is a fresh local array.
        tensor_data = torch.from_numpy(data_np)

        if self.noise > 0:
            tensor_data += torch.randn_like(tensor_data) * self.noise

        return tensor_data

n_classes instance-attribute

n_classes = n_classes

noise instance-attribute

noise = noise

radial_scale instance-attribute

radial_scale = radial_scale

angular_scale instance-attribute

angular_scale = angular_scale

spiral_scale instance-attribute

spiral_scale = spiral_scale

_generate_data

_generate_data() -> torch.Tensor
Source code in torchebm/datasets/generators.py
def _generate_data(self) -> torch.Tensor:
    # Logic from make_pinwheel
    all_points_np = []
    samples_per_class = self.n_samples // self.n_classes
    remainder = self.n_samples % self.n_classes

    for class_idx in range(self.n_classes):
        n_class_samples = samples_per_class + (1 if class_idx < remainder else 0)
        if n_class_samples == 0:
            continue

        t = np.sqrt(np.random.rand(n_class_samples))  # Radial density control
        radii = t * self.radial_scale
        base_angle = class_idx * (2 * np.pi / self.n_classes)
        spiral_angle = self.spiral_scale * t
        angle_noise = np.random.randn(n_class_samples) * self.angular_scale
        thetas = base_angle + spiral_angle + angle_noise

        x = radii * np.cos(thetas)
        y = radii * np.sin(thetas)
        all_points_np.append(np.stack([x, y], axis=1))

    data_np = np.concatenate(all_points_np, axis=0).astype(np.float32)
    np.random.shuffle(data_np)  # Shuffle before converting to tensor

    tensor_data = torch.from_numpy(data_np)

    if self.noise > 0:
        tensor_data += torch.randn_like(tensor_data) * self.noise

    return tensor_data