Skip to content

BaseSyntheticDataset

Methods and Attributes

Bases: Dataset, ABC

Abstract Base Class for generating 2D synthetic datasets.

Provides common functionality for handling sample size, device, dtype, seeding, and PyTorch Dataset integration. Subclasses must implement the _generate_data method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n_samples` | `int` | Total number of samples to generate. | *required* |
| `device` | `Optional[Union[str, device]]` | Device to place the tensor on. | `None` |
| `dtype` | `dtype` | Data type for the output tensor. Default: `torch.float32`. | `float32` |
| `seed` | `Optional[int]` | Random seed for reproducibility. If `None`, generation is non-deterministic. | `None` |
Source code in torchebm/datasets/generators.py
class BaseSyntheticDataset(Dataset, ABC):
    """
    Abstract Base Class for generating 2D synthetic datasets.

    Provides common functionality for handling sample size, device, dtype,
    seeding, and PyTorch Dataset integration. Subclasses must implement
    the `_generate_data` method.

    Data is generated eagerly in ``__init__`` and cached in ``self.data``;
    call ``regenerate`` to draw a fresh dataset.

    Args:
        n_samples (int): Total number of samples to generate. Must be positive.
        device (Optional[Union[str, torch.device]]): Device to place the tensor on.
        dtype (torch.dtype): Data type for the output tensor. Default: torch.float32.
        seed (Optional[int]): Random seed for reproducibility. If None, generation
            will be non-deterministic.

    Raises:
        ValueError: If ``n_samples`` is not positive.

    Note:
        Seeding uses the *global* NumPy and PyTorch RNGs (``np.random.seed`` /
        ``torch.manual_seed``), so constructing a seeded dataset also resets the
        random state observed by any other code using those generators.
    """

    def __init__(
        self,
        n_samples: int,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        seed: Optional[int] = None,
    ):
        if n_samples <= 0:
            raise ValueError("n_samples must be positive")

        self.n_samples = n_samples
        self.device = device
        self.dtype = dtype
        self.seed = seed
        self.data: Optional[torch.Tensor] = None  # Filled in by _generate()
        self._generate()  # Generate data eagerly upon initialization

    def _targets_cuda(self) -> bool:
        """
        Return True if ``self.device`` names a CUDA device.

        Parsing through ``torch.device`` also recognises indexed forms such as
        ``"cuda:0"`` or ``torch.device("cuda", 1)``, which a plain comparison
        against the string ``"cuda"`` would miss.
        """
        if self.device is None:
            return False
        return torch.device(self.device).type == "cuda"

    def _seed_generators(self):
        """Sets the random seeds for numpy and torch if a seed is provided."""
        if self.seed is None:
            return
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        # If the data targets a CUDA device, also seed every CUDA generator.
        if torch.cuda.is_available() and self._targets_cuda():
            torch.cuda.manual_seed_all(self.seed)  # Seed all GPUs

    @abstractmethod
    def _generate_data(self) -> Union[np.ndarray, torch.Tensor]:
        """
        Core data generation logic to be implemented by subclasses.

        This method should perform the actual sampling based on the dataset's
        specific parameters and return the generated data as a NumPy array
        or directly as a PyTorch tensor (without final device/dtype conversion,
        as that's handled by the base class). The first dimension is expected
        to index samples.
        """
        pass

    def _generate(self):
        """
        Seed the RNGs, run `_generate_data`, and store the result in ``self.data``.

        The subclass output is converted to ``self.dtype`` and placed on
        ``self.device``. A sample-count mismatch with ``n_samples`` is reported
        as a ``RuntimeWarning`` rather than an error.

        Raises:
            TypeError: If `_generate_data` returns neither a NumPy array nor a
                PyTorch tensor.
            ValueError: If the generated data is 0-dimensional, i.e. has no
                sample axis to index.
        """
        self._seed_generators()
        # Generate data using the subclass implementation
        generated_output = self._generate_data()

        # Ensure it's a tensor and on the correct device/dtype
        if isinstance(generated_output, np.ndarray):
            data = _to_tensor(generated_output, dtype=self.dtype, device=self.device)
        elif isinstance(generated_output, torch.Tensor):
            data = generated_output.to(dtype=self.dtype, device=self.device)
        else:
            raise TypeError(
                f"_generate_data must return a NumPy array or PyTorch Tensor, got {type(generated_output)}"
            )

        # A 0-d result has no sample axis: fail with a clear message instead of
        # an opaque IndexError from ``shape[0]`` below. Validating before the
        # assignment keeps a failed generation from clobbering ``self.data``.
        if data.ndim == 0:
            raise ValueError(
                "_generate_data must return data with a leading sample dimension, "
                "got a 0-dimensional result"
            )
        self.data = data

        # Verify the sample count; a mismatch is surfaced but tolerated.
        if self.data.shape[0] != self.n_samples:
            warnings.warn(
                f"Generated data has {self.data.shape[0]} samples, but {self.n_samples} were requested. Check generation logic.",
                RuntimeWarning,
            )

    def regenerate(self, seed: Optional[int] = None):
        """
        Re-generates the dataset, optionally with a new seed.

        Args:
            seed (Optional[int]): New random seed. If None, the currently stored
                seed is reused (reproducing the same data when one is set);
                if no seed was ever set, generation is non-deterministic.
        """
        if seed is not None:
            self.seed = seed  # Update the seed if a new one is provided
        self._generate()

    def get_data(self) -> torch.Tensor:
        """
        Returns the entire generated dataset as a single tensor.

        Returns:
            torch.Tensor: The generated data tensor.
        """
        if self.data is None:
            # Defensive: __init__ already calls _generate().
            self._generate()
        return self.data

    def __len__(self) -> int:
        """Returns the number of requested samples (``n_samples``)."""
        return self.n_samples

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Returns the sample at the specified index.

        Negative indices count from the end, as for Python sequences
        (``ds[-1]`` is the last sample).

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            torch.Tensor: The data sample at the given index.

        Raises:
            IndexError: If ``idx`` is outside ``[-n_samples, n_samples)``.
        """
        if self.data is None:
            self._generate()  # Ensure data exists

        # Map in-range negative indices onto their positive counterpart.
        if -self.n_samples <= idx < 0:
            idx += self.n_samples
        if not 0 <= idx < self.n_samples:
            raise IndexError(
                f"Index {idx} out of bounds for dataset with size {self.n_samples}"
            )
        return self.data[idx]

    def __repr__(self) -> str:
        """String representation of the dataset object."""
        params = [f"n_samples={self.n_samples}"]
        # Subclasses may override __repr__ to include their specific parameters.
        return f"{self.__class__.__name__}({', '.join(params)}, device={self.device}, dtype={self.dtype})"

n_samples instance-attribute

n_samples = n_samples

device instance-attribute

device = device

dtype instance-attribute

dtype = dtype

seed instance-attribute

seed = seed

data instance-attribute

data: Optional[Tensor] = None

_seed_generators

_seed_generators()

Sets the random seeds for numpy and torch if a seed is provided.

Source code in torchebm/datasets/generators.py
def _seed_generators(self):
    """
    Sets the random seeds for numpy and torch if a seed is provided.

    Seeds the *global* NumPy and PyTorch RNGs. When the configured device is a
    CUDA device (including indexed forms such as ``"cuda:0"``) and CUDA is
    available, every CUDA generator is seeded as well. Does nothing when
    ``self.seed`` is None.
    """
    if self.seed is None:
        return
    np.random.seed(self.seed)
    torch.manual_seed(self.seed)
    # Parse through torch.device so "cuda:N" and torch.device("cuda", N) are
    # recognised; a plain comparison against "cuda" would miss them.
    if torch.cuda.is_available() and self.device is not None:
        if torch.device(self.device).type == "cuda":
            torch.cuda.manual_seed_all(self.seed)  # Seed all GPUs

_generate_data abstractmethod

_generate_data() -> torch.Tensor

Core data generation logic to be implemented by subclasses.

This method should perform the actual sampling based on the dataset's specific parameters and return the generated data as a NumPy array or directly as a PyTorch tensor (without final device/dtype conversion, as that's handled by the base class).

Source code in torchebm/datasets/generators.py
@abstractmethod
def _generate_data(self) -> Union[np.ndarray, torch.Tensor]:
    """
    Core data generation logic to be implemented by subclasses.

    This method should perform the actual sampling based on the dataset's
    specific parameters and return the generated data as a NumPy array
    or directly as a PyTorch tensor (without final device/dtype conversion,
    as that's handled by the base class). The first dimension is expected
    to index samples.
    """
    pass

_generate

_generate()

Internal method to handle seeding and call the generation logic.

Source code in torchebm/datasets/generators.py
def _generate(self):
    """
    Seed the RNGs, run `_generate_data`, and store the result in ``self.data``.

    The subclass output is converted to ``self.dtype`` and placed on
    ``self.device``. A sample-count mismatch with ``n_samples`` is reported as
    a ``RuntimeWarning`` rather than an error.

    Raises:
        TypeError: If `_generate_data` returns neither a NumPy array nor a
            PyTorch tensor.
        ValueError: If the generated data is 0-dimensional, i.e. has no sample
            axis to index.
    """
    self._seed_generators()
    # Generate data using the subclass implementation
    generated_output = self._generate_data()

    # Ensure it's a tensor and on the correct device/dtype
    if isinstance(generated_output, np.ndarray):
        data = _to_tensor(generated_output, dtype=self.dtype, device=self.device)
    elif isinstance(generated_output, torch.Tensor):
        data = generated_output.to(dtype=self.dtype, device=self.device)
    else:
        raise TypeError(
            f"_generate_data must return a NumPy array or PyTorch Tensor, got {type(generated_output)}"
        )

    # A 0-d result has no sample axis: fail with a clear message instead of an
    # opaque IndexError from ``shape[0]`` below. Validating before the
    # assignment keeps a failed generation from clobbering ``self.data``.
    if data.ndim == 0:
        raise ValueError(
            "_generate_data must return data with a leading sample dimension, "
            "got a 0-dimensional result"
        )
    self.data = data

    # Verify the sample count; a mismatch is surfaced but tolerated.
    if self.data.shape[0] != self.n_samples:
        warnings.warn(
            f"Generated data has {self.data.shape[0]} samples, but {self.n_samples} were requested. Check generation logic.",
            RuntimeWarning,
        )

regenerate

regenerate(seed: Optional[int] = None)

Re-generates the dataset, optionally with a new seed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `seed` | `Optional[int]` | New random seed. If `None`, the currently stored seed (if any) is reused; if no seed was ever set, generation is non-deterministic. | `None` |
Source code in torchebm/datasets/generators.py
def regenerate(self, seed: Optional[int] = None):
    """
    Draw a fresh copy of the dataset, optionally switching to a new seed.

    Args:
        seed (Optional[int]): Seed to adopt before regenerating. Passing None
            leaves the stored seed untouched, so a previously seeded dataset
            is reproduced exactly, while an unseeded one stays random.
    """
    replacement_seed = seed
    # Only overwrite the stored seed when the caller actually supplied one;
    # 0 is a valid seed, hence the explicit None test.
    if replacement_seed is not None:
        self.seed = replacement_seed
    self._generate()

get_data

get_data() -> torch.Tensor

Returns the entire generated dataset as a single tensor.

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | `torch.Tensor`: The generated data tensor. |

Source code in torchebm/datasets/generators.py
def get_data(self) -> torch.Tensor:
    """
    Hand back the full dataset as one tensor.

    If nothing has been generated yet, generation is triggered first.

    Returns:
        torch.Tensor: The generated data tensor.
    """
    if self.data is not None:
        return self.data
    # Defensive fallback: construction normally generates data already.
    self._generate()
    return self.data