# TorchEBM API Reference
Welcome to the TorchEBM API reference documentation. This section provides detailed information about the classes and functions available in TorchEBM.
## Package Structure
TorchEBM is organized into several modules:
- **Core**: Base classes and core functionality for energy functions, samplers, and trainers.
- **Samplers**: Sampling algorithms for energy-based models, including Langevin dynamics and other MCMC methods.
- **Losses**: Loss functions for training energy-based models.
- **Utils**: Utility functions for working with energy-based models.
- **CUDA**: CUDA-accelerated implementations for faster computation.
## Getting Started with the API
If you're new to TorchEBM, we recommend starting with the following classes:
`EnergyFunction`
: Base class for all energy functions

`BaseSampler`
: Base class for all sampling algorithms

`LangevinDynamics`
: Implementation of Langevin dynamics sampling
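A minimal end-to-end sketch tying these classes together is shown below. The import paths, constructor arguments (`mean`, `cov`, `step_size`), and the `sample(...)` call with its parameters are assumptions made for illustration; check the class reference pages for the actual signatures.

```python
import torch
from torchebm.core import GaussianEnergy        # assumed import path
from torchebm.samplers import LangevinDynamics  # assumed import path

# A simple 2D Gaussian energy landscape (argument names are assumptions).
energy = GaussianEnergy(mean=torch.zeros(2), cov=torch.eye(2))

# Langevin dynamics sampler over that energy (parameter names are assumptions).
sampler = LangevinDynamics(energy_function=energy, step_size=0.01)

# Draw a batch of samples (method name and arguments are assumptions).
samples = sampler.sample(dim=2, n_samples=500, n_steps=100)
print(samples.shape)  # expected: (500, 2)
```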
## Core Components

### Energy Functions
TorchEBM provides various built-in energy functions:
| Energy Function | Description |
|---|---|
| `GaussianEnergy` | Multivariate Gaussian energy function |
| `DoubleWellEnergy` | Double well potential energy function |
| `RastriginEnergy` | Rastrigin function for testing optimization algorithms |
| `RosenbrockEnergy` | Rosenbrock function (banana function) |
| `AckleyEnergy` | Ackley function, a multimodal test function |
| `HarmonicEnergy` | Harmonic oscillator energy function |
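For intuition about what these classes represent, the sketch below evaluates the multivariate Gaussian energy, `E(x) = 0.5 * (x - mean)^T cov^{-1} (x - mean)`, on a batch of points in plain PyTorch. `GaussianEnergy` encapsulates this kind of computation; its exact interface and normalization may differ.

```python
import torch

def gaussian_energy(x, mean, cov):
    """Energy of a multivariate Gaussian, up to an additive constant:
    E(x) = 0.5 * (x - mean)^T cov^{-1} (x - mean)."""
    diff = x - mean                                 # (batch, dim)
    solved = torch.linalg.solve(cov, diff.T).T      # cov^{-1} (x - mean), shape (batch, dim)
    return 0.5 * (diff * solved).sum(dim=-1)        # one scalar energy per point

x = torch.randn(5, 2)                               # batch of 5 two-dimensional points
energies = gaussian_energy(x, torch.zeros(2), torch.eye(2))
print(energies.shape)                               # torch.Size([5])
```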
### Samplers
Available sampling algorithms:
| Sampler | Description |
|---|---|
| `LangevinDynamics` | Langevin dynamics sampling algorithm |
| `HamiltonianMonteCarlo` | Hamiltonian Monte Carlo sampling |
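As a rough illustration of the algorithm behind `LangevinDynamics`, the sketch below implements a bare-bones unadjusted Langevin update in plain PyTorch: `x <- x - step_size * grad E(x) + sqrt(2 * step_size) * noise`. This is not TorchEBM's implementation, which adds its own parameterization and conveniences.

```python
import torch

def langevin_step(x, energy_fn, step_size=0.01):
    """One unadjusted Langevin update:
    x <- x - step_size * grad E(x) + sqrt(2 * step_size) * noise."""
    x = x.detach().requires_grad_(True)
    grad = torch.autograd.grad(energy_fn(x).sum(), x)[0]
    noise = torch.randn_like(x)
    return (x - step_size * grad + (2 * step_size) ** 0.5 * noise).detach()

# Example: sample from a standard Gaussian energy E(x) = 0.5 * ||x||^2.
energy_fn = lambda x: 0.5 * (x ** 2).sum(dim=-1)
x = torch.randn(1000, 2)
for _ in range(100):
    x = langevin_step(x, energy_fn)
print(x.mean(dim=0), x.std(dim=0))  # roughly zero mean, unit std
```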
### Loss Functions
TorchEBM implements several loss functions for training EBMs:
| Loss Function | Description |
|---|---|
| `ContrastiveDivergence` | Standard contrastive divergence (CD-k) |
| `PersistentContrastiveDivergence` | Persistent contrastive divergence |
| `ParallelTemperingCD` | Parallel tempering contrastive divergence |
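Conceptually, contrastive divergence trains the energy function to assign low energy to data and higher energy to "negative" samples produced by running the sampler for k steps starting from the data. The sketch below is a plain PyTorch illustration of that CD-k objective, not TorchEBM's implementation; the library's loss classes handle the sampling and bookkeeping for you.

```python
import torch

def langevin_step(x, energy_fn, step_size=0.01):
    """One Langevin update used to generate negative samples."""
    x = x.detach().requires_grad_(True)
    grad = torch.autograd.grad(energy_fn(x).sum(), x)[0]
    return (x - step_size * grad + (2 * step_size) ** 0.5 * torch.randn_like(x)).detach()

def cd_k_loss(energy_fn, x_data, k=10):
    """CD-k surrogate loss: mean energy of the data minus mean energy of
    negatives obtained by running k sampler steps started from the data."""
    x_neg = x_data.clone()
    for _ in range(k):
        x_neg = langevin_step(x_neg, energy_fn)   # chain is detached at each step
    return energy_fn(x_data).mean() - energy_fn(x_neg).mean()

# Example with a tiny learnable energy model (illustrative only).
model = torch.nn.Sequential(torch.nn.Linear(2, 64), torch.nn.SiLU(), torch.nn.Linear(64, 1))
energy_fn = lambda x: model(x).squeeze(-1)
loss = cd_k_loss(energy_fn, torch.randn(128, 2))
loss.backward()  # gradients flow into the model's parameters
```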
## Module Details
For detailed information about each module, follow the links below: