
Getting Started with TorchEBM

Welcome to the TorchEBM guides section! These comprehensive guides will help you understand how to use TorchEBM effectively for your energy-based modeling tasks.

Core Concepts

  • Energy Functions: Learn about the foundation of energy-based models and how to work with energy functions in TorchEBM. See the Energy Functions Guide.
  • Samplers: Discover how to generate samples from energy landscapes using various sampling algorithms. See the Samplers Guide.
  • Loss Functions: Explore different loss functions for training energy-based models. See the Loss Functions Guide.
  • Custom Neural Networks: Learn how to create and use neural networks as energy functions. See the Custom Neural Networks Guide.
  • Training EBMs: Master the techniques for effectively training energy-based models. See the Training Guide.
  • Visualization: Visualize energy landscapes and sampling results to gain insights. See the Visualization Guide.

Quick Start

If you're new to energy-based models, we recommend the following learning path:

  1. Start with the Introduction to understand basic concepts
  2. Follow the Installation guide to set up TorchEBM
  3. Read the Energy Functions guide to understand the fundamentals
  4. Explore the Samplers guide to learn how to generate samples
  5. Study the Training guide to learn how to train your models

Basic Example

Here's a simple example to get you started with TorchEBM:

import torch
from torchebm.core import GaussianEnergy
from torchebm.samplers.langevin_dynamics import LangevinDynamics

# Create an energy function (2D Gaussian)
energy_fn = GaussianEnergy(
    mean=torch.zeros(2),
    cov=torch.eye(2)
)

# Create a sampler
sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01
)

# Generate samples
samples = sampler.sample_chain(
    dim=2, n_steps=100, n_samples=1000
)

# Print sample statistics
print(f"Sample mean: {samples.mean(0)}")
print(f"Sample std: {samples.std(0)}")
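For reference, here is the standard energy of a Gaussian with mean μ and covariance Σ, along with the density it induces under the usual EBM convention p(x) ∝ exp(−E(x)). We assume, without having checked the library source, that GaussianEnergy computes this energy, possibly up to an additive constant. With mean=0 and cov=I, well-mixed chains should therefore give a sample mean near 0 and a standard deviation near 1 in each coordinate. Short chains (here n_steps=100 with step_size=0.01) may not have fully mixed, so the printed statistics can deviate noticeably from these values.

```latex
E(x) = \tfrac{1}{2}(x - \mu)^\top \Sigma^{-1} (x - \mu),
\qquad
p(x) = \frac{\exp(-E(x))}{Z},
\qquad
Z = \int \exp(-E(x))\,dx = (2\pi)^{d/2}\,\lvert \Sigma \rvert^{1/2},
\qquad
\text{so } p(x) = \mathcal{N}(x;\, \mu, \Sigma).
```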

Common Patterns

Here are some common patterns you'll encounter throughout the guides:

Energy Function Definition

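import torch
from torchebm.core import EnergyFunction

class MyEnergyFunction(EnergyFunction):
    """Quadratic energy E(x) = sum_i x_i^2, computed per sample."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # x has shape (..., dim); summing over the last dimension gives one energy per sample
        return torch.sum(x**2, dim=-1)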
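Gradient-based samplers such as Langevin dynamics use the gradient ∇E(x), not just the energy value. For the quadratic energy above, that gradient is 2x. The plain-Python sketch below uses neither torch nor TorchEBM, and its function names are hypothetical. It checks the analytic gradient against a central finite-difference estimate. This kind of sanity test is useful whenever you write a custom energy by hand.

```python
def energy(x):
    """Quadratic energy E(x) = sum_i x_i**2 for a single sample x (a list of floats)."""
    return sum(xi * xi for xi in x)


def grad_energy(x):
    """Analytic gradient of E: dE/dx_i = 2 * x_i."""
    return [2.0 * xi for xi in x]


def finite_diff_grad(f, x, h=1e-5):
    """Central finite-difference estimate of the gradient of f at x."""
    grad = []
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += h
        x_minus[i] -= h
        grad.append((f(x_plus) - f(x_minus)) / (2.0 * h))
    return grad


x = [0.3, -1.2]
analytic = grad_energy(x)
numeric = finite_diff_grad(energy, x)
# Largest absolute disagreement between the two gradient estimates
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
print(f"analytic={analytic}, numeric={numeric}, max_err={max_err:.2e}")
```

For a quadratic energy, the central difference is exact up to floating-point rounding, so any visible mismatch points to a bug in the analytic gradient.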

Sampler Usage

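from torchebm.samplers.langevin_dynamics import LangevinDynamics

# Any energy function works here, e.g. the custom one defined above
energy_fn = MyEnergyFunction()

sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01
)

# Draw 1000 two-dimensional samples using 100 Langevin steps
samples = sampler.sample_chain(
    dim=2, n_steps=100, n_samples=1000
)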
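To illustrate conceptually what a call like sample_chain does, here is a from-scratch sketch of unadjusted Langevin dynamics in plain Python. It is applied to the quadratic energy E(x) = Σ x_i² from MyEnergyFunction. Each step applies x ← x − ε∇E(x) + √(2ε)·ξ, where ξ ~ N(0, I). This is the textbook update, not TorchEBM's source. The library's LangevinDynamics may differ in details such as initialization or noise scaling, and the helper names here are hypothetical. Under p(x) ∝ exp(−E(x)), the target is a Gaussian with mean 0 and variance 1/2 per coordinate. The printed statistics should therefore land close to those values, although the finite step size biases the variance slightly upward.

```python
import math
import random


def grad_quadratic_energy(x):
    """Gradient of the quadratic energy E(x) = sum(x_i**2), i.e. 2 * x."""
    return [2.0 * xi for xi in x]


def langevin_sample(grad_fn, dim, n_steps, n_samples, step_size, seed=0):
    """Run n_samples independent unadjusted Langevin chains in `dim` dimensions.

    Each step applies x <- x - step_size * grad E(x) + sqrt(2 * step_size) * noise,
    with noise drawn from a standard normal distribution.
    """
    rng = random.Random(seed)
    noise_std = math.sqrt(2.0 * step_size)
    samples = []
    for _ in range(n_samples):
        # Initialise each chain from a standard normal (an assumption of this sketch)
        x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        for _ in range(n_steps):
            g = grad_fn(x)
            x = [xi - step_size * gi + noise_std * rng.gauss(0.0, 1.0)
                 for xi, gi in zip(x, g)]
        samples.append(x)
    return samples


samples = langevin_sample(grad_quadratic_energy, dim=2, n_steps=300,
                          n_samples=1000, step_size=0.05)

# Per-coordinate mean and sample variance; the target N(0, I/2) has mean 0, variance 0.5
for d in range(2):
    coord = [s[d] for s in samples]
    mean = sum(coord) / len(coord)
    var = sum((c - mean) ** 2 for c in coord) / (len(coord) - 1)
    print(f"dim {d}: mean={mean:.3f}, var={var:.3f}")
```

A smaller step_size reduces the discretization bias but slows mixing, so the chains then need more steps. This is the same trade-off you face when choosing step_size and n_steps for LangevinDynamics.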

Next Steps

Once you're familiar with the basics, you can:

Remember that all examples in these guides are tested with the latest version of TorchEBM, and you can run them in your own environment to gain hands-on experience.