Skip to content

SwissRollDataset

Methods and Attributes

Bases: BaseSyntheticDataset

Generates a 2D Swiss roll dataset.

Parameters:

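| Name | Type | Description | Default |
|------|------|-------------|---------|
| `n_samples` | `int` | Number of samples. | `2000` |
| `noise` | `float` | Standard deviation of the Gaussian noise added to each coordinate (applied before centering and rescaling). | `0.05` |
| `arclength` | `float` | Controls how many turns the roll makes: angles are drawn uniformly from [π·arclength, 3π·arclength). | `3.0` |
| `device` | `Optional[Union[str, torch.device]]` | Device for the tensor. | `None` |
| `dtype` | `torch.dtype` | Data type for the tensor. | `torch.float32` |
| `seed` | `Optional[int]` | Random seed for reproducibility. | `None` |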
Source code in torchebm/datasets/generators.py
class SwissRollDataset(BaseSyntheticDataset):
    """
    Generates a 2D Swiss roll dataset.

    Angles ``t`` are drawn uniformly from ``[pi * arclength, 3 * pi * arclength)``
    and mapped to points ``(t * cos(t), t * sin(t))``. Gaussian noise is added
    in those raw coordinates. The cloud is then centered, and divided by twice
    the mean per-axis standard deviation.

    Args:
        n_samples (int): Number of samples. Default: 2000.
        noise (float): Standard deviation of Gaussian noise added (in the
            un-normalized coordinates). Default: 0.05.
        arclength (float): Controls how many rolls (pi*arclength). Default: 3.0.
        device (Optional[Union[str, torch.device]]): Device for the tensor.
        dtype (torch.dtype): Data type for the tensor. Default: torch.float32.
        seed (Optional[int]): Random seed for reproducibility.

    NOTE(review): sampling uses NumPy's global RNG (angles) and torch's global
    RNG (noise). Reproducibility from ``seed`` therefore relies on the base
    class seeding both — confirm.
    """

    def __init__(
        self,
        n_samples: int = 2000,
        noise: float = 0.05,
        arclength: float = 3.0,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        seed: Optional[int] = None,
    ):
        # Set before super().__init__(): presumably the base constructor calls
        # _generate_data(), which reads these attributes.
        self.noise = noise
        self.arclength = arclength
        super().__init__(n_samples=n_samples, device=device, dtype=dtype, seed=seed)

    def _generate_data(self) -> torch.Tensor:
        """
        Sample the Swiss roll and return it as a float32 CPU tensor.

        Returns:
            torch.Tensor: shape ``(n_samples, 2)``. The points are centered, and
            rescaled when at least two samples with non-degenerate spread exist.
            Device/dtype conversion is left to the base class.
        """
        # Logic from make_swiss_roll
        t = self.arclength * np.pi * (1 + 2 * np.random.rand(self.n_samples))
        x = t * np.cos(t)
        y = t * np.sin(t)
        X = np.vstack((x, y)).T.astype(np.float32)  # (n_samples, 2)

        tensor_data = torch.from_numpy(X)  # CPU tensor initially
        tensor_data += torch.randn_like(tensor_data) * self.noise

        # Center and scale slightly (optional, can be done outside)
        centered = tensor_data - tensor_data.mean(dim=0)

        # The unbiased std is NaN for a single sample, which previously turned
        # every output into NaN. With zero samples there is nothing to scale.
        if tensor_data.shape[0] < 2:
            return centered

        scale = tensor_data.std(dim=0).mean() * 2.0
        # Guard against division by ~0 (and NaN, since NaN > eps is False):
        # in that case leave the centered data unscaled.
        if scale > torch.finfo(tensor_data.dtype).eps:
            centered = centered / scale

        return centered  # Return tensor, base class handles device/dtype

`noise` (instance attribute)

    self.noise = noise

`arclength` (instance attribute)

    self.arclength = arclength

_generate_data

_generate_data() -> torch.Tensor
Source code in torchebm/datasets/generators.py
def _generate_data(self) -> torch.Tensor:
    """
    Sample a 2D Swiss roll and return it as a float32 CPU tensor.

    Reads ``self.n_samples``, ``self.noise`` and ``self.arclength``. Angles are
    drawn uniformly from ``[pi * arclength, 3 * pi * arclength)`` and mapped to
    ``(t * cos(t), t * sin(t))``. Gaussian noise is added, then the cloud is
    centered and divided by twice the mean per-axis standard deviation.

    Returns:
        torch.Tensor: shape ``(n_samples, 2)``. Rescaling is skipped when fewer
        than two samples exist or their spread is degenerate, so the output
        never contains NaN from the normalization.
    """
    # Logic from make_swiss_roll
    t = self.arclength * np.pi * (1 + 2 * np.random.rand(self.n_samples))
    x = t * np.cos(t)
    y = t * np.sin(t)
    X = np.vstack((x, y)).T.astype(np.float32)  # (n_samples, 2)

    tensor_data = torch.from_numpy(X)  # CPU tensor initially
    tensor_data += torch.randn_like(tensor_data) * self.noise

    # Center and scale slightly (optional, can be done outside)
    centered = tensor_data - tensor_data.mean(dim=0)

    # The unbiased std is NaN for a single sample, which would turn every
    # output into NaN. With zero samples there is nothing to scale.
    if tensor_data.shape[0] < 2:
        return centered

    scale = tensor_data.std(dim=0).mean() * 2.0
    # Guard against division by ~0 (and NaN, since NaN > eps is False).
    if scale > torch.finfo(tensor_data.dtype).eps:
        centered = centered / scale

    return centered  # Return tensor, base class handles device/dtype