LangevinDynamics

Methods and Attributes

Bases: BaseSampler

Langevin Dynamics sampler implementing discretized gradient-based MCMC.

This class implements the Langevin Dynamics algorithm, a gradient-based MCMC method that samples from a target distribution defined by an energy function. It uses a stochastic update rule combining gradient descent with Gaussian noise to explore the energy landscape.

Each step updates the state \(x_t\) according to the discretized Langevin equation:

\[x_{t+1} = x_t - \eta \nabla_x U(x_t) + \sqrt{2\eta} \epsilon_t\]

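where \(\epsilon_t \sim \mathcal{N}(0, I)\) and \(\eta\) is the step size. In the implementation the noise term is additionally multiplied by noise_scale, so the equation above corresponds to noise_scale = 1.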

For a sufficiently small step size and a long enough chain, this process generates samples that approximately follow the Boltzmann distribution:

\[p(x) \propto e^{-U(x)}\]

where \(U(x)\) defines the energy landscape.
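
The update rule and its stationary distribution can be checked numerically. The following is a minimal sketch in plain Python (standard library only, not the torchebm API), assuming the one-dimensional quadratic energy U(x) = x²/2, whose Boltzmann distribution is the standard normal. The names `grad_u` and `run_chains` are illustrative and do not come from the library.

```python
import math
import random
import statistics
from typing import List


def grad_u(x: float) -> float:
    """Gradient of the assumed energy U(x) = x**2 / 2."""
    return x


def run_chains(n_chains: int, n_steps: int, step_size: float, seed: int = 0) -> List[float]:
    """Run independent chains with the discretized Langevin update; return final states."""
    rng = random.Random(seed)
    noise_std = math.sqrt(2.0 * step_size)
    finals = []
    for _ in range(n_chains):
        x = rng.gauss(0.0, 1.0)  # Gaussian initialization
        for _ in range(n_steps):
            # x_{t+1} = x_t - eta * grad U(x_t) + sqrt(2 * eta) * eps_t
            x = x - step_size * grad_u(x) + noise_std * rng.gauss(0.0, 1.0)
        finals.append(x)
    return finals


finals = run_chains(n_chains=2000, n_steps=500, step_size=0.01)
sample_mean = statistics.fmean(finals)
sample_var = statistics.pvariance(finals)
# For this energy the values should be close to 0 and 1 respectively.
print(f"mean={sample_mean:.3f} var={sample_var:.3f}")
```

The small remaining gap between the sample variance and 1 comes from the finite step size and the finite number of chains.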

Algorithm Summary

  1. If x is not provided, initialize it with Gaussian noise.
  2. Iteratively update x for n_steps steps using self.langevin_step().
  3. Optionally track the trajectory (return_trajectory=True).
  4. Optionally collect diagnostics such as the mean, variance, and energy of the samples.
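
The first three steps above can be sketched as follows. This is a plain-Python illustration under stated assumptions, not the library's code: `sample_sketch` and its arguments are hypothetical names, states are nested lists rather than tensors, and the trajectory is indexed as `[step][sample][dim]` for simplicity, whereas the listing further down stores it as `(n_samples, n_steps, dim)`. Diagnostics are left out.

```python
import math
import random
from typing import Callable, List, Optional

Vector = List[float]


def sample_sketch(
    grad_u: Callable[[Vector], Vector],
    dim: int,
    n_steps: int,
    n_samples: int,
    step_size: float = 0.01,
    x: Optional[List[Vector]] = None,
    return_trajectory: bool = False,
    seed: int = 0,
):
    rng = random.Random(seed)
    # Step 1: Gaussian initialization when no initial state is given.
    if x is None:
        x = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_samples)]
    noise_std = math.sqrt(2.0 * step_size)
    trajectory = []
    # Step 2: apply the Langevin update to every chain, n_steps times.
    for _ in range(n_steps):
        x = [
            [xi - step_size * gi + noise_std * rng.gauss(0.0, 1.0)
             for xi, gi in zip(row, grad_u(row))]
            for row in x
        ]
        # Step 3: optionally record the state after each step.
        if return_trajectory:
            trajectory.append(x)
    return trajectory if return_trajectory else x


# Assumed energy U(x) = |x|^2 / 2, so the gradient is x itself.
final = sample_sketch(lambda row: list(row), dim=2, n_steps=50, n_samples=4)
traj = sample_sketch(lambda row: list(row), dim=2, n_steps=50, n_samples=4,
                     return_trajectory=True)
print(len(final), len(final[0]))                 # 4 2
print(len(traj), len(traj[0]), len(traj[0][0]))  # 50 4 2
```

Because both calls use the same seed, the last entry of the trajectory equals the final state returned by the first call.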

Parameters:

Name Type Description Default
energy_function BaseEnergyFunction

Energy function to sample from.

required
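step_size float | BaseScheduler

Step size for the Langevin update, given as a positive constant or a scheduler.

0.001
noise_scale float | BaseScheduler

Scale of the Gaussian noise, given as a positive constant or a scheduler.

1.0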
decay float

Damping coefficient (not supported yet).

0.0
dtype dtype

Data type to use for the computations.

float32
device str

Device to run the computations on (e.g., "cpu" or "cuda").

None

Raises:

Type Description
ValueError

If step_size or noise_scale is a number that is not strictly positive.

Methods:

Name Description
langevin_step

Perform a Langevin step.

sample

Run the sampling process.

_setup_diagnostics

Allocate the tensor that stores the diagnostics.

Basic Usage

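import torch
from torchebm.samplers.langevin_dynamics import LangevinDynamics
# QuadraticEnergy must also be imported from torchebm; its module path is not shown on this page.

# Define energy function
energy_fn = QuadraticEnergy(A=torch.eye(2), b=torch.zeros(2))

# Initialize sampler
sampler = LangevinDynamics(
    energy_function=energy_fn,
    step_size=0.01,
    noise_scale=0.1
)

# Run 100 parallel chains for 50 steps each and keep the final states
samples = sampler.sample(
    dim=2,
    n_steps=50,
    n_samples=100
)
# samples has shape (100, 2)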

Parameter Relationships

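With the update used by langevin_step, the noise term is \(\sqrt{2 \cdot \text{step\_size}} \cdot \text{noise\_scale} \cdot \epsilon_t\), so the effective temperature is \(\text{Temperature} = \text{noise\_scale}^2\) and, for small step sizes, the chain targets \(p(x) \propto e^{-U(x) / \text{noise\_scale}^2}\). Changing step_size alters the discretization error and the mixing speed but not the temperature; change noise_scale to change the temperature.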
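
How noise_scale changes the spread of the samples can be checked empirically. The sketch below is plain Python rather than the torchebm API. It assumes the update used by `langevin_step` in the source listing, where the noise is multiplied by `sqrt(2 * step_size) * noise_scale`, together with the quadratic energy U(x) = x²/2. Under those assumptions the stationary variance should be close to `noise_scale ** 2` for small step sizes. `stationary_variance` is an illustrative helper, not a library function.

```python
import math
import random
import statistics


def stationary_variance(noise_scale, step_size=0.01, n_chains=1000, n_steps=400, seed=0):
    """Estimate the variance of the final states over many independent chains."""
    rng = random.Random(seed)
    noise_std = math.sqrt(2.0 * step_size) * noise_scale
    finals = []
    for _ in range(n_chains):
        x = 0.0
        for _ in range(n_steps):
            # grad U(x) = x for the assumed quadratic energy
            x = x - step_size * x + noise_std * rng.gauss(0.0, 1.0)
        finals.append(x)
    return statistics.pvariance(finals)


variances = {scale: stationary_variance(scale) for scale in (0.5, 1.0)}
for scale, var in variances.items():
    print(f"noise_scale={scale}: variance={var:.3f}, noise_scale**2={scale ** 2}")
```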

Source code in torchebm/samplers/langevin_dynamics.py
class LangevinDynamics(BaseSampler):
    r"""
    Langevin Dynamics sampler implementing discretized gradient-based MCMC.

    This class implements the Langevin Dynamics algorithm, a gradient-based MCMC method that samples from a target
    distribution defined by an energy function. It uses a stochastic update rule combining gradient descent with Gaussian noise to explore the energy landscape.

    Each step updates the state $x_t$ according to the discretized Langevin equation:

    $$x_{t+1} = x_t - \eta \nabla_x U(x_t) + \sqrt{2\eta} \epsilon_t$$

    where $\epsilon_t \sim \mathcal{N}(0, I)$ and $\eta$ is the step size.

    This process generates samples that asymptotically follow the Boltzmann distribution:


    $$p(x) \propto e^{-U(x)}$$

    where $U(x)$ defines the energy landscape.

    !!! note "Algorithm Summary"

        1. If `x` is not provided, initialize it with Gaussian noise.
        2. Iteratively update `x` for `n_steps` steps using `self.langevin_step()`.
        3. Optionally track the trajectory (`return_trajectory=True`).
        4. Optionally collect diagnostics such as the mean, variance, and energy of the samples.

    Args:
        energy_function (BaseEnergyFunction): Energy function to sample from.
        step_size (float | BaseScheduler): Step size for the Langevin update, as a positive constant or a scheduler.
        noise_scale (float | BaseScheduler): Scale of the Gaussian noise, as a positive constant or a scheduler.
        decay (float): Damping coefficient (not supported yet).
        dtype (torch.dtype): Data type to use for the computations.
        device (str): Device to run the computations on (e.g., "cpu" or "cuda").

    Raises:
        ValueError: If `step_size` or `noise_scale` is a number that is not strictly positive.

    Methods:
        langevin_step(prev_x, noise): Perform a Langevin step.
        sample(x, dim, n_steps, n_samples, thin, return_trajectory, return_diagnostics): Run the sampling process.
        _setup_diagnostics(dim, n_steps, n_samples): Allocate the tensor that stores the diagnostics.

    !!! example "Basic Usage"
        ```python
        # Define energy function
        energy_fn = QuadraticEnergy(A=torch.eye(2), b=torch.zeros(2))

        # Initialize sampler
        sampler = LangevinDynamics(
            energy_function=energy_fn,
            step_size=0.01,
            noise_scale=0.1
        )

        # Run 100 parallel chains for 50 steps each and keep the final states
        samples = sampler.sample(
            dim=2,
            n_steps=50,
            n_samples=100
        )
        ```
    !!! warning "Parameter Relationships"
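        With the update used by `langevin_step`, the noise term is
        \(\sqrt{2 \cdot \text{step\_size}} \cdot \text{noise\_scale} \cdot \epsilon_t\),
        so the effective temperature is \(\text{Temperature} = \text{noise\_scale}^2\).
        Changing `step_size` alters the discretization error and mixing speed, not the temperature.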
    """

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        step_size: Union[float, BaseScheduler] = 1e-3,
        noise_scale: Union[float, BaseScheduler] = 1.0,
        decay: float = 0.0,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__(energy_function, dtype, device)

        # Register schedulers for step_size and noise_scale
        if isinstance(step_size, BaseScheduler):
            self.register_scheduler("step_size", step_size)
        else:
            if step_size <= 0:
                raise ValueError("step_size must be positive")
            self.register_scheduler(
                "step_size", ConstantScheduler(step_size)
            )

        if isinstance(noise_scale, BaseScheduler):
            self.register_scheduler("noise_scale", noise_scale)
        else:
            if noise_scale <= 0:
                raise ValueError("noise_scale must be positive")
            self.register_scheduler(
                "noise_scale", ConstantScheduler(noise_scale)
            )

        if device is not None:
            self.device = torch.device(device)
            energy_function = energy_function.to(self.device)
        else:
            self.device = torch.device("cpu")
        # Compare the device type: `self.device == "cuda"` is never true for a torch.device.
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.energy_function = energy_function
        self.step_size = step_size
        self.noise_scale = noise_scale
        self.decay = decay

    def langevin_step(self, prev_x: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        r"""
        Perform a single Langevin dynamics update step.

        Implements the discrete Langevin equation:

        $$x_{t+1} = x_t - \eta \nabla_x U(x_t) + \sqrt{2\eta} \epsilon_t$$

        Args:
            prev_x (torch.Tensor): Current state tensor of shape (batch_size, dim)
            noise (torch.Tensor): Gaussian noise tensor of shape (batch_size, dim)

        Returns:
            torch.Tensor: Updated state tensor of the same shape as prev_x

        Example:
            ```python
            # Single step for 10 particles in 2D space
            current_state = torch.randn(10, 2)
            noise = torch.randn_like(current_state)
            next_state = langevin.langevin_step(current_state, noise)
            ```
        """

        step_size = self.get_scheduled_value("step_size")
        noise_scale = self.get_scheduled_value("noise_scale")

        gradient = self.energy_function.gradient(prev_x)

        # Apply noise scaling
        scaled_noise = noise_scale * noise

        # Apply proper step size and noise scaling
        new_x = (
            prev_x
            - step_size * gradient
            + torch.sqrt(torch.tensor(2.0 * step_size, device=prev_x.device))
            * scaled_noise
        )
        return new_x

    @torch.no_grad()
    def sample(
        self,
        x: Optional[torch.Tensor] = None,
        dim: int = 10,
        n_steps: int = 100,
        n_samples: int = 1,
        thin: int = 1,
        return_trajectory: bool = False,
        return_diagnostics: bool = False,
        *args,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Generate Markov chain samples using Langevin dynamics.

        Args:
            x: Initial state to start the sampling from. If `None`, it is drawn from a standard normal.
            dim: Dimension of the state space.
            n_steps: Number of Langevin steps to run.
            n_samples: Number of parallel chains, which is also the number of samples returned.
            thin: Thinning interval; accepted but not used by the code shown here.
            return_trajectory: Whether to return the state after every step instead of only the final state.
            return_diagnostics: Whether to also return the per-step diagnostics tensor.

        Returns:
            Final samples:

                - If `return_trajectory=False` and `return_diagnostics=False`, returns the final
                  samples of shape `(n_samples, dim)`.
                - If `return_trajectory=True`, returns a tensor of shape `(n_samples, n_steps, dim)`
                  containing the sampled trajectory.
                - If `return_diagnostics=True`, returns a tuple `(samples, diagnostics)`, where
                  `diagnostics` is a tensor of shape `(n_steps, 3, n_samples, dim)` storing the
                  per-step mean, variance, and energy.

        Raises:
            ValueError: If input dimensions mismatch

        Note:
            - Moves the initial state `x` to the sampler's device (CPU/GPU)
            - Runs the sampling loop under `torch.amp.autocast`
            - Diagnostics include:
                * Mean and variance across the batch of chains
                * Energy of each sample

        Example:
            ```python
            # Run 100 parallel chains for 500 steps and collect diagnostics
            samples, diagnostics = sampler.sample(
                dim=32,
                n_steps=500,
                n_samples=100,
                return_diagnostics=True
            )
            ```

        """

        self.reset_schedulers()

        if x is None:
            x = torch.randn(n_samples, dim, dtype=self.dtype, device=self.device)
        else:
            x = x.to(self.device)  # Initial batch

        if return_trajectory:
            trajectory = torch.empty(
                (n_samples, n_steps, dim), dtype=self.dtype, device=self.device
            )

        if return_diagnostics:
            diagnostics = self._setup_diagnostics(dim, n_steps, n_samples=n_samples)

        with torch.amp.autocast(
            device_type="cuda" if self.device.type == "cuda" else "cpu"
        ):
            for i in range(n_steps):
                # todo: Add decay logic
                # Generate fresh noise for each step
                noise = torch.randn_like(x, device=self.device)

                # Step all schedulers before each MCMC step
                scheduler_values = self.step_schedulers()

                x = self.langevin_step(x, noise)

                if return_trajectory:
                    trajectory[:, i, :] = x

                if return_diagnostics:
                    # Handle mean and variance safely regardless of batch size
                    if n_samples > 1:
                        mean_x = x.mean(dim=0, keepdim=True)
                        var_x = x.var(dim=0, unbiased=False, keepdim=True)
                        var_x = torch.clamp(var_x, min=1e-10, max=1e10)
                    else:
                        # For single sample, just use the value and zeros for variance
                        mean_x = x.clone()
                        var_x = torch.zeros_like(x)

                    # Compute energy values
                    energy = self.energy_function(x)

                    # Store the diagnostics safely
                    for b in range(n_samples):
                        diagnostics[i, 0, b, :] = mean_x[b if n_samples > 1 else 0]
                        diagnostics[i, 1, b, :] = var_x[b if n_samples > 1 else 0]
                        diagnostics[i, 2, b, :] = energy[b].reshape(-1)

        if return_trajectory:
            if return_diagnostics:
                return trajectory, diagnostics
            return trajectory
        if return_diagnostics:
            return x, diagnostics
        return x

    def _setup_diagnostics(
        self, dim: int, n_steps: int, n_samples: int = None
    ) -> torch.Tensor:
        if n_samples is not None:
            return torch.empty(
                (n_steps, 3, n_samples, dim), device=self.device, dtype=self.dtype
            )
        else:
            return torch.empty((n_steps, 3, dim), device=self.device, dtype=self.dtype)

device instance-attribute

device = device(device)

dtype instance-attribute

dtype = float16 if device.type == 'cuda' else float32

energy_function instance-attribute

energy_function = energy_function

step_size instance-attribute

step_size = step_size

noise_scale instance-attribute

noise_scale = noise_scale

decay instance-attribute

decay = decay

langevin_step

langevin_step(prev_x: Tensor, noise: Tensor) -> torch.Tensor

Perform a single Langevin dynamics update step.

Implements the discrete Langevin equation:

\[x_{t+1} = x_t - \eta \nabla_x U(x_t) + \sqrt{2\eta} \epsilon_t\]

Parameters:

Name Type Description Default
prev_x Tensor

Current state tensor of shape (batch_size, dim)

required
noise Tensor

Gaussian noise tensor of shape (batch_size, dim)

required

Returns:

Type Description
Tensor

Updated state tensor of the same shape as prev_x

Example
# Single step for 10 particles in 2D space
current_state = torch.randn(10, 2)
noise = torch.randn_like(current_state)
next_state = langevin.langevin_step(current_state, noise)
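
To make the arithmetic of a single step concrete, here is a small plain-Python sketch with fixed noise, so the result can be verified by hand. It mirrors the update in the source listing (gradient step plus `sqrt(2 * step_size) * noise_scale * noise`) but is not the library method: `langevin_step_sketch` is a hypothetical name, and the gradient is passed in directly under the assumed energy U(x) = |x|²/2.

```python
import math
from typing import List


def langevin_step_sketch(
    prev_x: List[float],
    noise: List[float],
    grad: List[float],
    step_size: float,
    noise_scale: float = 1.0,
) -> List[float]:
    """One update: x - step_size * grad + sqrt(2 * step_size) * noise_scale * noise."""
    scale = math.sqrt(2.0 * step_size) * noise_scale
    return [x - step_size * g + scale * n for x, g, n in zip(prev_x, grad, noise)]


prev_x = [1.0, -2.0]
grad = list(prev_x)   # assumed energy U(x) = |x|^2 / 2, so grad U(x) = x
noise = [2.0, 0.0]    # fixed noise so the step is reproducible
next_x = langevin_step_sketch(prev_x, noise, grad, step_size=0.5, noise_scale=0.1)
print(next_x)  # approximately [0.7, -1.0]
```

With step_size = 0.5 the factor sqrt(2 * step_size) equals 1, so the first coordinate is 1 - 0.5 + 0.1 * 2 = 0.7 and the second is -2 + 1 + 0 = -1.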

sample

sample(x: Optional[Tensor] = None, dim: int = 10, n_steps: int = 100, n_samples: int = 1, thin: int = 1, return_trajectory: bool = False, return_diagnostics: bool = False, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Generate Markov chain samples using Langevin dynamics.

Parameters:

Name Type Description Default
x Optional[Tensor]

Initial state to start the sampling from.

None
dim int

Dimension of the state space.

10
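n_steps int

Number of Langevin steps to run.

100
n_samples int

Number of parallel chains, which is also the number of samples returned.

1
thin int

Thinning interval; accepted but not used by the code shown here.

1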
return_trajectory bool

Whether to return the trajectory of the samples.

False
return_diagnostics bool

Whether to return the diagnostics of the sampling process.

False

Returns:

Type Description
Union[Tensor, Tuple[Tensor, Tensor]]

Final samples:

  • If return_trajectory=False and return_diagnostics=False, returns the final samples of shape (n_samples, dim).
  • If return_trajectory=True, returns a tensor of shape (n_samples, n_steps, dim) containing the sampled trajectory.
  • If return_diagnostics=True, returns a tuple (samples, diagnostics), where diagnostics is a tensor of shape (n_steps, 3, n_samples, dim) storing the per-step mean, variance, and energy.

Raises:

Type Description
ValueError

If input dimensions mismatch

Note
  • Moves the initial state x to the sampler's device (CPU/GPU)
  • Runs the sampling loop under torch.amp.autocast
  • Diagnostics include:
    • Mean and variance across the batch of chains
    • Energy of each sample
Example
# Run 100 parallel chains for 500 steps and collect diagnostics
samples, diagnostics = sampler.sample(
    dim=32,
    n_steps=500,
    n_samples=100,
    return_diagnostics=True
)
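
The per-step diagnostics recorded by the sampling loop in the source listing (batch mean, population variance clamped to [1e-10, 1e10], and per-sample energy) can be illustrated with a small plain-Python sketch. `batch_diagnostics` is a hypothetical helper operating on nested lists, and the energy U(x) = |x|²/2 is an assumption made only for this example.

```python
import statistics
from typing import List, Tuple


def batch_diagnostics(x: List[List[float]]) -> Tuple[List[float], List[float], List[float]]:
    """Return (mean per dimension, variance per dimension, energy per sample)."""
    columns = list(zip(*x))  # one tuple of values per dimension
    if len(x) > 1:
        mean = [statistics.fmean(col) for col in columns]
        # Population variance (the unbiased=False case), clamped to [1e-10, 1e10].
        var = [min(max(statistics.pvariance(col), 1e-10), 1e10) for col in columns]
    else:
        # A single chain has no spread: use its value as the mean and zero variance.
        mean = list(x[0])
        var = [0.0 for _ in x[0]]
    energy = [0.5 * sum(v * v for v in row) for row in x]  # assumed U(x) = |x|^2 / 2
    return mean, var, energy


mean, var, energy = batch_diagnostics([[1.0, 2.0], [3.0, 2.0]])
print(mean)    # [2.0, 2.0]
print(var)     # [1.0, 1e-10]
print(energy)  # [2.5, 6.5]
```

The second dimension has identical values in both chains, so its variance of zero is raised to the lower clamp of 1e-10.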

_setup_diagnostics

_setup_diagnostics(dim: int, n_steps: int, n_samples: int = None) -> torch.Tensor