API Reference

torchebm.datasets.generators

Dataset Generators Module.

This module provides a collection of classes for generating synthetic datasets commonly used in testing and evaluating energy-based models. These generators create various 2D distributions with different characteristics, making them ideal for visualization and demonstration purposes. They are implemented as PyTorch Datasets for easy integration with DataLoaders.

Key Features

  • Diverse collection of 2D synthetic distributions, each exposed as its own class.
  • Configurable sample sizes, noise levels, and distribution parameters.
  • Direct compatibility with torch.utils.data.Dataset and DataLoader.
  • Device and dtype support for tensor outputs.
  • Reproducibility through random seeds.
  • Visualization support for generated datasets.

Module Components

Classes:

Name                     Description
BaseSyntheticDataset     Abstract base class for synthetic dataset generators.
GaussianMixtureDataset   Generates samples from a 2D Gaussian mixture arranged in a circle.
EightGaussiansDataset    Generates samples from a specific 8-component Gaussian mixture.
TwoMoonsDataset          Generates the classic "two moons" dataset.
SwissRollDataset         Generates a 2D Swiss roll dataset.
CircleDataset            Generates points sampled uniformly on a circle with noise.
CheckerboardDataset      Generates points in a 2D checkerboard pattern.
PinwheelDataset          Generates the pinwheel dataset with specified number of "blades".
GridDataset              Generates points on a regular 2D grid with optional noise.


Usage Example

Generating and Using Datasets

from torchebm.datasets.generators import TwoMoonsDataset, GaussianMixtureDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch

# Instantiate dataset objects
moons_dataset = TwoMoonsDataset(n_samples=1000, noise=0.05, seed=42)
mixture_dataset = GaussianMixtureDataset(n_samples=500, n_components=4, std=0.1, seed=123)

# Access the full dataset tensor
moons_data = moons_dataset.get_data()
mixture_data = mixture_dataset.get_data()
print(f"Two Moons data shape: {moons_data.shape}")
print(f"Mixture data shape: {mixture_data.shape}")

# Use with DataLoader
dataloader = DataLoader(moons_dataset, batch_size=32, shuffle=True)
first_batch = next(iter(dataloader))
print(f"First batch shape: {first_batch.shape}")

# Visualize the datasets
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.scatter(moons_data[:, 0], moons_data[:, 1], s=5)
plt.title("Two Moons")
plt.subplot(1, 2, 2)
plt.scatter(mixture_data[:, 0], mixture_data[:, 1], s=5)
plt.title("Gaussian Mixture")
plt.show()

Mathematical Background

Distribution Characteristics

Each dataset generator creates points from a different probability distribution:

  • Gaussian Mixtures: Weighted combinations of Gaussian distributions, often arranged in specific patterns like circles.
  • Two Moons: Two interlocking half-circles with added noise, creating a challenging bimodal distribution that's not linearly separable.
  • Checkerboard: Alternating high and low density regions in a grid pattern, testing an EBM's ability to capture multiple modes in a regular structure.
  • Swiss Roll: A spiral-shaped curve (a 1D manifold embedded in 2D), testing the model's ability to learn curved manifolds.
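
To make these descriptions concrete, here is a minimal sketch of how two of the distributions could be sampled with plain PyTorch. It is an illustration only, not torchebm's implementation: the function names `sample_ring_mixture` and `sample_two_moons`, their parameters, and the exact parameterization (equal mixture weights, unit radius, random angles along each arc) are assumptions made for this example.

```python
import math
import torch


def _make_generator(seed):
    # A local generator keeps the example reproducible without touching global RNG state.
    return torch.Generator().manual_seed(seed) if seed is not None else None


def sample_ring_mixture(n_samples, n_components=8, radius=1.0, std=0.1, seed=None):
    """Equal-weight Gaussian mixture whose means are evenly spaced on a circle."""
    gen = _make_generator(seed)
    # Pick one component per sample (uniform mixture weights).
    idx = torch.randint(0, n_components, (n_samples,), generator=gen)
    angles = 2 * math.pi * idx.float() / n_components
    means = radius * torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
    # Add isotropic Gaussian noise around the selected mean.
    return means + std * torch.randn(n_samples, 2, generator=gen)


def sample_two_moons(n_samples, noise=0.05, seed=None):
    """Two interlocking half-circles with additive Gaussian noise."""
    gen = _make_generator(seed)
    n_upper = n_samples // 2
    n_lower = n_samples - n_upper
    # Random angles in [0, pi) along each half-circle.
    t_upper = math.pi * torch.rand(n_upper, generator=gen)
    t_lower = math.pi * torch.rand(n_lower, generator=gen)
    upper = torch.stack([torch.cos(t_upper), torch.sin(t_upper)], dim=1)
    # The lower moon is shifted and flipped so the two arcs interlock.
    lower = torch.stack([1 - torch.cos(t_lower), 0.5 - torch.sin(t_lower)], dim=1)
    points = torch.cat([upper, lower], dim=0)
    return points + noise * torch.randn(n_samples, 2, generator=gen)


ring = sample_ring_mixture(500, n_components=4, std=0.1, seed=123)
moons = sample_two_moons(1000, noise=0.05, seed=42)
print(ring.shape, moons.shape)
```

Both functions return an `(n_samples, 2)` tensor, which is the same layout the usage example above plots with `plt.scatter`.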

Choosing a Dataset

  • For testing basic density estimation: use GaussianMixtureDataset
  • For evaluating mode-seeking behavior: use EightGaussiansDataset or CheckerboardDataset
  • For testing separation of entangled distributions: use TwoMoonsDataset
  • For manifold learning: use SwissRollDataset or CircleDataset

Implementation Details

Device Handling

The device parameter determines where the generated dataset tensor resides. Set it to the device your model runs on (e.g., 'cuda') so that batches do not have to be copied from host to GPU memory at every training step.
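
The sketch below shows one common way to pick a device with a CPU fallback and to iterate over data that already lives on it. It uses a plain `TensorDataset` as a stand-in for a generated dataset, so nothing here depends on torchebm's constructor signatures; the variable names are illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Fall back to the CPU when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for a generated dataset tensor that already resides on `device`.
points = torch.randn(1000, 2, device=device, dtype=torch.float32)
dataset = TensorDataset(points)

# Data that already sits on the GPU is best loaded in the main process
# (num_workers=0); pinned memory is only useful for CPU tensors.
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

# TensorDataset yields one-element tuples, so unpack the single tensor.
(batch,) = next(iter(loader))
print(batch.shape, batch.device)
```

Each batch comes out on the same device as the underlying tensor, so no extra `.to(device)` call is needed inside the training loop.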

Random Number Generation

The generators use PyTorch and NumPy random functions. For reproducible results, provide a seed argument during instantiation.
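
As a rough illustration of what seeding buys you, the sketch below draws noise from locally seeded PyTorch and NumPy generators and shows that the same seed reproduces the same values. How torchebm applies the seed internally is not specified here; the helper `seeded_noise` is hypothetical and exists only for this example.

```python
import numpy as np
import torch


def seeded_noise(n_samples, seed):
    """Draw 2D Gaussian noise from locally seeded PyTorch and NumPy generators."""
    torch_gen = torch.Generator().manual_seed(seed)
    np_rng = np.random.default_rng(seed)
    from_torch = torch.randn(n_samples, 2, generator=torch_gen)
    # NumPy produces float64 by default; convert to float32 to match PyTorch.
    from_numpy = torch.from_numpy(np_rng.standard_normal((n_samples, 2))).float()
    return from_torch, from_numpy


a1, b1 = seeded_noise(10, seed=42)
a2, b2 = seeded_noise(10, seed=42)
a3, _ = seeded_noise(10, seed=43)

print(torch.equal(a1, a2), torch.equal(b1, b2))  # same seed, same draws
print(torch.equal(a1, a3))                       # different seed, different draws
```

Local generators like these avoid changing the global random state, which keeps other parts of a program unaffected by the seed.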