class GridDataset(BaseSyntheticDataset):
    """
    Synthetic dataset whose points sit on a regular 2D lattice.

    Each axis is sampled at `n_samples_per_dim` evenly spaced coordinates
    between `-range_limit` and `range_limit` (both ends included), and every
    (x, y) combination becomes one point, so the dataset holds
    `n_samples_per_dim * n_samples_per_dim` points. Gaussian jitter can
    optionally be added on top of the lattice.

    Args:
        n_samples_per_dim (int): Number of lattice coordinates along each axis.
        range_limit (float): Half-width of the covered square
            `[-range_limit, range_limit] x [-range_limit, range_limit]`.
        noise (float): Standard deviation of the jitter added to every
            coordinate. Jitter is applied only when this is greater than zero.
        device (Optional[Union[str, torch.device]]): Passed through unchanged
            to the base class constructor.
        dtype (torch.dtype): Passed through unchanged to the base class
            constructor.
        seed (Optional[int]): Passed through unchanged to the base class
            constructor. The lattice itself is deterministic, so randomness
            only enters through the jitter.

    Raises:
        ValueError: If `n_samples_per_dim` is zero or negative.
    """

    def __init__(
        self,
        n_samples_per_dim: int = 10,
        range_limit: float = 1.0,
        noise: float = 0.01,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        seed: Optional[int] = None,
    ):
        if n_samples_per_dim <= 0:
            raise ValueError("n_samples_per_dim must be positive")

        # These attributes are read by `_generate_data`, so they are stored
        # before control is handed to the base constructor.
        # NOTE(review): presumably the base constructor triggers data
        # generation -- confirm against BaseSyntheticDataset.
        self.n_samples_per_dim = n_samples_per_dim
        self.range_limit = range_limit
        self.noise = noise

        # The base class is told the total point count (one point per
        # lattice cell), not the per-axis count.
        super().__init__(
            n_samples=n_samples_per_dim * n_samples_per_dim,
            device=device,
            dtype=dtype,
            seed=seed,
        )

    def _generate_data(self) -> torch.Tensor:
        """
        Build the lattice points and, if requested, perturb them.

        Returns:
            torch.Tensor: Tensor of shape `(n_samples_per_dim ** 2, 2)` whose
            rows are (x, y) pairs, ordered with x varying slowest.
        """
        # Both axes share the same evenly spaced coordinates.
        axis = torch.linspace(
            -self.range_limit, self.range_limit, self.n_samples_per_dim
        )
        xs, ys = torch.meshgrid(axis, axis, indexing="ij")

        # Pair the coordinates cell by cell, then collapse the two lattice
        # dimensions into a single list of (x, y) rows.
        lattice = torch.stack([xs, ys], dim=-1).reshape(-1, 2)

        # Written as a negated `>` so that a NaN noise level also skips the
        # jitter, exactly as a plain `noise > 0` test would.
        if not (self.noise > 0):
            return lattice

        if hasattr(self, "rng"):
            # `rng` is not created in this class; when one is attached, its
            # NumPy-style `normal(loc, scale, size=...)` supplies the jitter.
            jitter = torch.tensor(
                self.rng.normal(0, self.noise, size=lattice.shape),
                dtype=lattice.dtype,
                device=lattice.device,
            )
        else:
            # Otherwise draw standard-normal samples from torch's global
            # generator and scale them to the requested deviation.
            jitter = torch.randn_like(lattice) * self.noise

        return lattice + jitter