Skip to content

GridDataset

Methods and Attributes

Bases: BaseSyntheticDataset

Generates points on a regular 2D grid covering the square [-range_limit, range_limit] × [-range_limit, range_limit], optionally perturbed by Gaussian noise.

The total number of samples will be n_samples_per_dim ** 2.

The base class's n_samples is not user-configurable here: it is set to n_samples_per_dim ** 2.

Parameters:

Name Type Description Default
n_samples_per_dim int

Number of points along each dimension. Default: 10.

10
range_limit float

Defines the square region [-lim, lim] x [-lim, lim]. Default: 1.0.

1.0
noise float

Standard deviation of Gaussian noise added. Default: 0.01.

0.01
device Optional[Union[str, device]]

Device for the tensor.

None
dtype dtype

Data type for the tensor. Default: torch.float32.

float32
seed Optional[int]

Random seed (primarily affects noise).

None
Source code in torchebm/datasets/generators.py
class GridDataset(BaseSyntheticDataset):
    """Generates points on a regular 2D grid within [-range_limit, range_limit]^2.

    Note: The total number of samples will be n_samples_per_dim ** 2.
          The `n_samples` parameter in the base class will be overridden.

    Args:
        n_samples_per_dim (int): Number of points along each dimension. Default: 10.
        range_limit (float): Defines the square region [-lim, lim] x [-lim, lim]. Default: 1.0.
        noise (float): Standard deviation of Gaussian noise added. Values <= 0
            disable noise entirely. Default: 0.01.
        device (Optional[Union[str, torch.device]]): Device for the tensor.
        dtype (torch.dtype): Data type for the tensor. Default: torch.float32.
        seed (Optional[int]): Random seed (primarily affects noise).

    Raises:
        ValueError: If ``n_samples_per_dim`` is not positive.
    """

    def __init__(
        self,
        n_samples_per_dim: int = 10,
        range_limit: float = 1.0,
        noise: float = 0.01,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        seed: Optional[int] = None,  # Seed mainly affects noise here
    ):
        if n_samples_per_dim <= 0:
            raise ValueError("n_samples_per_dim must be positive")
        # These must be set before super().__init__, which may generate data
        # via _generate_data (TODO confirm against BaseSyntheticDataset).
        self.n_samples_per_dim = n_samples_per_dim
        self.range_limit = range_limit
        self.noise = noise
        # The grid fixes the sample count, so the base-class n_samples is derived.
        total_samples = n_samples_per_dim * n_samples_per_dim
        super().__init__(n_samples=total_samples, device=device, dtype=dtype, seed=seed)

    def _generate_data(self) -> torch.Tensor:
        """Build the grid points, optionally perturbed with Gaussian noise.

        Points are ordered with the first coordinate varying slowest
        (``torch.meshgrid(..., indexing="ij")``).

        Returns:
            torch.Tensor: Tensor of shape ``(n_samples_per_dim ** 2, 2)`` created
            with torch's default device and dtype; presumably moved/cast by the
            base class afterwards (TODO confirm).
        """
        # Both axes use the same evenly spaced coordinates, so build them once.
        coords = torch.linspace(
            -self.range_limit, self.range_limit, self.n_samples_per_dim
        )
        grid_x, grid_y = torch.meshgrid(coords, coords, indexing="ij")
        points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)

        if self.noise > 0:
            # An ``rng`` attribute that exists but is None must not be used.
            rng = getattr(self, "rng", None)
            if rng is not None:
                # Instance RNG (presumably seeded from `seed` by the base class);
                # assumes a NumPy-style ``normal(loc, scale, size)`` method.
                noise = torch.tensor(
                    rng.normal(0, self.noise, size=points.shape),
                    dtype=points.dtype,
                    device=points.device,
                )
            else:
                # No usable instance RNG: fall back to torch's global generator,
                # in which case `seed` does not affect the noise here.
                noise = torch.randn_like(points) * self.noise

            points = points + noise

        return points

n_samples_per_dim instance-attribute

n_samples_per_dim = n_samples_per_dim

range_limit instance-attribute

range_limit = range_limit

noise instance-attribute

noise = noise

_generate_data

_generate_data() -> torch.Tensor
Source code in torchebm/datasets/generators.py
def _generate_data(self) -> torch.Tensor:
    # Create a more uniform grid spacing
    x_coords = torch.linspace(
        -self.range_limit, self.range_limit, self.n_samples_per_dim
    )
    y_coords = torch.linspace(
        -self.range_limit, self.range_limit, self.n_samples_per_dim
    )

    # Create the grid points
    grid_x, grid_y = torch.meshgrid(x_coords, y_coords, indexing="ij")

    # Stack the coordinates to form the 2D points
    points = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)

    # Apply noise if specified
    if self.noise > 0:
        # Set the random seed if provided
        if hasattr(self, "rng"):
            # Use the RNG from the base class
            noise = torch.tensor(
                self.rng.normal(0, self.noise, size=points.shape),
                dtype=points.dtype,
                device=points.device,
            )
        else:
            # Fall back to torch's random generator
            noise = torch.randn_like(points) * self.noise

        points = points + noise

    return points