Torchebm > Datasets > Generators¶
Contents¶
Classes¶
- BaseSyntheticDataset: Abstract base class for generating 2D synthetic datasets.
- CheckerboardDataset: Generates points in a 2D checkerboard pattern using rejection sampling.
- CircleDataset: Generates points sampled uniformly on a circle with noise.
- EightGaussiansDataset: Generates samples from the specific '8 Gaussians' mixture.
- GaussianMixtureDataset: Generates samples from a 2D Gaussian mixture arranged uniformly in a circle.
- GridDataset: Generates points on a 2D grid within [-range_limit, range_limit].
- PinwheelDataset: Generates the pinwheel dataset with curved blades.
- SwissRollDataset: Generates a 2D Swiss roll dataset.
- TwoMoonsDataset: Generates the 'two moons' dataset.
API Reference¶
torchebm.datasets.generators¶
Dataset Generators Module.
This module provides a collection of classes for generating synthetic datasets commonly used in testing and evaluating energy-based models. These generators create various 2D distributions with different characteristics, making them ideal for visualization and demonstration purposes. They are implemented as PyTorch Datasets for easy integration with DataLoaders.
Key Features
- Diverse collection of 2D synthetic distributions via classes.
- Configurable sample sizes, noise levels, and distribution parameters.
- Direct compatibility with torch.utils.data.Dataset and DataLoader.
- Device and dtype support for tensor outputs.
- Reproducibility through random seeds.
- Visualization support for generated datasets.
Module Components¶
Classes:
Name | Description |
---|---|
BaseSyntheticDataset | Abstract base class for synthetic dataset generators. |
GaussianMixtureDataset | Generates samples from a 2D Gaussian mixture arranged in a circle. |
EightGaussiansDataset | Generates samples from a specific 8-component Gaussian mixture. |
TwoMoonsDataset | Generates the classic "two moons" dataset. |
SwissRollDataset | Generates a 2D Swiss roll dataset. |
CircleDataset | Generates points sampled uniformly on a circle with noise. |
CheckerboardDataset | Generates points in a 2D checkerboard pattern. |
PinwheelDataset | Generates the pinwheel dataset with a specified number of "blades". |
GridDataset | Generates points on a regular 2D grid with optional noise. |
Usage Example¶
Generating and Using Datasets
from torchebm.datasets.generators import TwoMoonsDataset, GaussianMixtureDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch
# Instantiate dataset objects
moons_dataset = TwoMoonsDataset(n_samples=1000, noise=0.05, seed=42)
mixture_dataset = GaussianMixtureDataset(n_samples=500, n_components=4, std=0.1, seed=123)
# Access the full dataset tensor
moons_data = moons_dataset.get_data()
mixture_data = mixture_dataset.get_data()
print(f"Two Moons data shape: {moons_data.shape}")
print(f"Mixture data shape: {mixture_data.shape}")
# Use with DataLoader
dataloader = DataLoader(moons_dataset, batch_size=32, shuffle=True)
first_batch = next(iter(dataloader))
print(f"First batch shape: {first_batch.shape}")
# Visualize the datasets
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.scatter(moons_data[:, 0], moons_data[:, 1], s=5)
plt.title("Two Moons")
plt.subplot(1, 2, 2)
plt.scatter(mixture_data[:, 0], mixture_data[:, 1], s=5)
plt.title("Gaussian Mixture")
plt.show()
Mathematical Background¶
Distribution Characteristics
Each dataset generator creates points from a different probability distribution:
- Gaussian Mixtures: Weighted combinations of Gaussian distributions, often arranged in specific patterns like circles.
- Two Moons: Two interlocking half-circles with added noise, creating a challenging bimodal distribution that's not linearly separable.
- Checkerboard: Alternating high and low density regions in a grid pattern, testing an EBM's ability to capture multiple modes in a regular structure.
- Swiss Roll: A 2D manifold with spiral structure, testing the model's ability to learn curved manifolds.
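To make the first item concrete, here is a minimal sketch of sampling from an equally weighted Gaussian mixture whose centres lie on a circle. It is written in plain PyTorch and is not the library's implementation: the function name `sample_ring_mixture` and its parameters (`radius`, `std`, `seed`) are assumptions chosen for illustration.

```python
import math
import torch


def sample_ring_mixture(n_samples, n_components=8, radius=2.0, std=0.1, seed=0):
    """Draw 2D points from equally weighted Gaussians centred on a circle.

    Hypothetical helper for illustration; not part of torchebm.
    """
    gen = torch.Generator().manual_seed(seed)
    # Component centres, evenly spaced by angle on a circle of the given radius.
    angles = torch.arange(n_components) * (2 * math.pi / n_components)
    centers = radius * torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
    # Pick one component uniformly at random for every sample.
    idx = torch.randint(0, n_components, (n_samples,), generator=gen)
    # Add isotropic Gaussian noise around the chosen centre.
    noise = std * torch.randn(n_samples, 2, generator=gen)
    return centers[idx] + noise


points = sample_ring_mixture(1000)
print(points.shape)
```

With a small `std` relative to `radius`, the modes are well separated, which is what makes this kind of mixture a common check for whether a model covers every mode.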
Choosing a Dataset
- For testing basic density estimation: use GaussianMixtureDataset.
- For evaluating mode-seeking behavior: use EightGaussiansDataset or CheckerboardDataset.
- For testing separation of entangled distributions: use TwoMoonsDataset.
- For manifold learning: use SwissRollDataset or CircleDataset.
Implementation Details¶
Device Handling
The device parameter determines where the generated dataset tensor resides. Set this appropriately (e.g., 'cuda') for efficient GPU usage.
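As a rough illustration of what device placement means for downstream code, the sketch below builds a stand-in tensor directly on the selected device and iterates over it with a DataLoader. The tensor `data` is a hypothetical placeholder created with `torch.randn`, not output from any torchebm generator; only standard PyTorch calls are used.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Fall back to the CPU when no GPU is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for a generated dataset tensor (hypothetical; not produced by torchebm).
data = torch.randn(256, 2, device=device)

# The data already lives on the target device, so the default single-process
# loader is used and batches need no extra host-to-device copy.
loader = DataLoader(TensorDataset(data), batch_size=64, shuffle=True)

# TensorDataset yields one-element tuples, so unpack the single tensor.
(batch,) = next(iter(loader))
print(batch.device.type == device.type)
```

Keeping small synthetic datasets on the GPU avoids a transfer per batch; for data held on the CPU, the usual pattern is to move each batch with `.to(device)` inside the training loop instead.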
Random Number Generation
The generators use PyTorch and NumPy random functions. For reproducible results, provide a seed argument during instantiation.
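The effect of seeding can be checked with a small sketch. It uses a local `torch.Generator`, which is one common way to get repeatable draws in PyTorch; how the dataset classes apply their seed internally is not described here, so treat this as an assumption-laden illustration. The helper `draw` is hypothetical.

```python
import torch


def draw(seed, n=5):
    """Return n standard-normal 2D points from a generator seeded with `seed`."""
    gen = torch.Generator().manual_seed(seed)
    return torch.randn(n, 2, generator=gen)


a = draw(42)
b = draw(42)
c = draw(7)

print(torch.equal(a, b))  # same seed, identical samples
print(torch.equal(a, c))  # different seed, different samples
```

A local generator leaves PyTorch's global random state untouched, so seeding one dataset does not change the random numbers used elsewhere in a script.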