Working with Datasets
TorchEBM includes a suite of synthetic datasets, mostly 2D distributions, that are valuable for testing models and algorithms. They are available in the torchebm.datasets module and are implemented as standard PyTorch Dataset classes, so they are fully compatible with DataLoader.
Synthetic 2D Datasets
These datasets are useful for visualizing how an EBM learns complex, multimodal distributions.
- Gaussian Mixture: A mixture of eight distinct Gaussian distributions arranged in a circle.
- Eight Gaussians: A classic synthetic dataset with eight Gaussian modes in a circular pattern.
- Two Moons: Two interleaving half-circles, a common benchmark for nonlinear distributions.
- Swiss Roll: A spiral-shaped manifold, useful for testing manifold learning algorithms.
- Circle: Data points distributed on the circumference of a circle.
- Checkerboard: A grid-like pattern of clusters, challenging for models to capture.
- Pinwheel: A dataset with swirling arms, testing a model's ability to learn rotational structures.
- 2D Grid: A uniform grid of points, useful for evaluating coverage and mode detection.
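The sketch below shows how an eight-Gaussians dataset can be built as a torch.utils.data.Dataset subclass. The class name EightGaussiansSketch and its parameters (radius, std, seed) are hypothetical illustrations, not TorchEBM's actual class or signature. It places eight modes evenly on a circle and pre-samples all points once.

```python
import math

import torch
from torch.utils.data import Dataset


class EightGaussiansSketch(Dataset):
    """Hypothetical stand-in for an eight-Gaussians dataset (not TorchEBM's class).

    Eight isotropic Gaussian modes sit evenly spaced on a circle of the given
    radius; all points are sampled once in __init__, so __getitem__ is a lookup.
    """

    def __init__(self, n_samples=2000, radius=2.0, std=0.02, seed=0):
        generator = torch.Generator().manual_seed(seed)
        # Mode centres at angles 0, pi/4, ..., 7*pi/4.
        angles = torch.arange(8, dtype=torch.float32) * (2 * math.pi / 8)
        centers = radius * torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
        # Assign each sample to a random mode, then add isotropic Gaussian noise.
        idx = torch.randint(0, 8, (n_samples,), generator=generator)
        noise = std * torch.randn(n_samples, 2, generator=generator)
        self.data = centers[idx] + noise

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, i):
        # Each item is a single 2D point of shape (2,).
        return self.data[i]


dataset = EightGaussiansSketch(n_samples=1000)
print(len(dataset), dataset[0].shape)  # 1000 torch.Size([2])
```

Pre-sampling in __init__ keeps indexing cheap and makes the dataset deterministic for a given seed; the library's own classes may make different choices.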
Using with DataLoader
Because these classes subclass torch.utils.data.Dataset, they can be passed directly to DataLoader for batching and shuffling during training.
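As a sketch of the batching behavior, the example below wraps a hypothetical CircleSketch dataset in a DataLoader. CircleSketch is written here for illustration and is not TorchEBM's circle dataset; any Dataset that returns one 2D tensor per index behaves the same way.

```python
import math

import torch
from torch.utils.data import DataLoader, Dataset


class CircleSketch(Dataset):
    """Hypothetical circle dataset: points on a circle's circumference plus small noise."""

    def __init__(self, n_samples=1000, radius=1.0, noise=0.05, seed=0):
        g = torch.Generator().manual_seed(seed)
        # Uniform angles in [0, 2*pi), mapped onto the circle.
        theta = 2 * math.pi * torch.rand(n_samples, generator=g)
        points = radius * torch.stack([torch.cos(theta), torch.sin(theta)], dim=1)
        self.data = points + noise * torch.randn(n_samples, 2, generator=g)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]


loader = DataLoader(CircleSketch(n_samples=1000), batch_size=256, shuffle=True)
for batch in loader:
    # Default collation stacks the per-item (2,) tensors into a (batch, 2) tensor.
    print(batch.shape)
```

With 1,000 points and batch_size=256, the loader yields three batches of shape (256, 2) and a final batch of 232 points, because DataLoader keeps the incomplete last batch unless drop_last=True.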
Training Example
Here’s a brief example of how to use a dataset to train an EBM. This is a condensed version of the full training process covered in the next chapter.
The example covers creating a dataset, wrapping it in a DataLoader, and using the batches to train a model with a sampler and a loss function such as ContrastiveDivergence.
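The following is a minimal, self-contained sketch of that workflow in plain PyTorch. It does not use TorchEBM's sampler or ContrastiveDivergence classes: the sampler is a hand-written unadjusted Langevin loop and the loss is the basic contrastive divergence objective (mean energy of data minus mean energy of samples). The two-blob toy data, network size, step size, and epoch count are arbitrary assumptions chosen to keep it small.

```python
import math

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: two Gaussian blobs, wrapped in a TensorDataset.
centers = torch.tensor([[-2.0, 0.0], [2.0, 0.0]])
data = centers[torch.randint(0, 2, (2000,))] + 0.3 * torch.randn(2000, 2)
loader = DataLoader(TensorDataset(data), batch_size=128, shuffle=True)

# Energy network: maps a 2D point to a scalar energy E(x).
energy = nn.Sequential(
    nn.Linear(2, 64), nn.SiLU(), nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 1)
)
optimizer = torch.optim.Adam(energy.parameters(), lr=1e-3)


def langevin(x, n_steps=10, step_size=0.01):
    """Unadjusted Langevin dynamics: x <- x - step * dE/dx + sqrt(2 * step) * noise."""
    x = x.clone().detach()
    for _ in range(n_steps):
        x.requires_grad_(True)
        grad = torch.autograd.grad(energy(x).sum(), x)[0]
        x = (x - step_size * grad + math.sqrt(2 * step_size) * torch.randn_like(x)).detach()
    return x


for epoch in range(5):
    for (x_pos,) in loader:
        # CD-k: start the chains at the data batch and run k Langevin steps.
        x_neg = langevin(x_pos)
        # Lower the energy of data and raise the energy of model samples.
        loss = energy(x_pos).mean() - energy(x_neg).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}")
```

Starting the chains at the data batch is what distinguishes contrastive divergence from sampling from scratch; practical setups often add an energy regularization term or a replay buffer for stability, which this sketch omits.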