torchebm.integrators.euler_maruyama

Euler-Maruyama integrator.

EulerMaruyamaIntegrator

Bases: BaseSDERungeKuttaIntegrator

Euler-Maruyama integrator for Itô SDEs and ODEs.

The SDE form is:

\[ \mathrm{d}x = f(x,t)\,\mathrm{d}t + \sqrt{2D(x,t)}\,\mathrm{d}W_t \]

When diffusion is omitted, this reduces to the Euler method for ODEs.

Update rule:

\[ x_{n+1} = x_n + f(x_n, t_n)\Delta t + \sqrt{2D(x_n,t_n)}\,\Delta W_n \]
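The update rule can be sketched directly in plain PyTorch, independent of the integrator class (a minimal illustration; `em_step` is not part of the library API):

```python
import torch

def em_step(x, t, dt, drift, diffusion=None):
    """One Euler-Maruyama step: x + f(x,t)*dt + sqrt(2*D(x,t))*dW."""
    x_next = x + drift(x, t) * dt
    if diffusion is not None:
        # dW ~ N(0, dt), so sqrt(2*D) * dW is sampled as sqrt(2*D*dt) * z
        z = torch.randn_like(x)
        x_next = x_next + torch.sqrt(2.0 * diffusion(x, t) * dt) * z
    return x_next

# With diffusion omitted this reduces to a plain Euler step:
# each entry of x1 is 1 + (-1) * 0.1 = 0.9.
x0 = torch.ones(3)
x1 = em_step(x0, 0.0, 0.1, lambda x, t: -x)
```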

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `device` | `Optional[device]` | Device for computations. | `None` |
| `dtype` | `Optional[dtype]` | Data type for computations. | `None` |
| `atol` | `float` | Absolute tolerance for adaptive stepping. | `1e-06` |
| `rtol` | `float` | Relative tolerance for adaptive stepping. | `0.001` |
| `max_steps` | `int` | Maximum number of steps before raising `RuntimeError`. | `10000` |
| `safety` | `float` | Safety factor for step-size adjustment (< 1). | `0.9` |
| `min_factor` | `float` | Minimum step-size shrink factor. | `0.2` |
| `max_factor` | `float` | Maximum step-size growth factor. | `10.0` |
| `max_step_size` | `float` | Maximum absolute step size during adaptive integration. | `float('inf')` |
| `norm` | `Optional[Callable[[Tensor], Tensor]]` | Callable `norm(tensor) -> scalar` for local error measurement. | `None` |
Example
from torchebm.integrators import EulerMaruyamaIntegrator
import torch

integrator = EulerMaruyamaIntegrator()
state = {"x": torch.randn(100, 2)}
drift = lambda x, t: -x  # simple mean-reverting drift
result = integrator.step(
    state, step_size=0.01, drift=drift, noise_scale=1.0
)
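Iterating the same update drives the samples toward the drift's fixed point; with `drift = lambda x, t: -x` and unit diffusion this is an Ornstein-Uhlenbeck process whose stationary distribution is standard normal. A plain-torch sketch of what repeated stepping (or the `integrate` method) does over many steps:

```python
import torch

torch.manual_seed(0)
x = torch.randn(100, 2)
dt = 0.01
for _ in range(1000):
    # drift pulls samples toward the origin; the noise keeps variance finite
    x = x + (-x) * dt + (2.0 * dt) ** 0.5 * torch.randn_like(x)
# After ~10 time units the empirical variance is close to 1,
# the stationary variance of dx = -x dt + sqrt(2) dW.
```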
Source code in torchebm/integrators/euler_maruyama.py
class EulerMaruyamaIntegrator(BaseSDERungeKuttaIntegrator):
    r"""
    Euler-Maruyama integrator for Itô SDEs and ODEs.

    The SDE form is:

    \[
    \mathrm{d}x = f(x,t)\,\mathrm{d}t + \sqrt{2D(x,t)}\,\mathrm{d}W_t
    \]

    When `diffusion` is omitted, this reduces to the Euler method for ODEs.

    Update rule:

    \[
    x_{n+1} = x_n + f(x_n, t_n)\Delta t + \sqrt{2D(x_n,t_n)}\,\Delta W_n
    \]

    Args:
        device: Device for computations.
        dtype: Data type for computations.
        atol: Absolute tolerance for adaptive stepping.
        rtol: Relative tolerance for adaptive stepping.
        max_steps: Maximum number of steps before raising ``RuntimeError``.
        safety: Safety factor for step-size adjustment (< 1).
        min_factor: Minimum step-size shrink factor.
        max_factor: Maximum step-size growth factor.
        max_step_size: Maximum absolute step size during adaptive integration.
        norm: Callable ``norm(tensor) -> scalar`` for local error measurement.

    Example:
        ```python
        from torchebm.integrators import EulerMaruyamaIntegrator
        import torch

        integrator = EulerMaruyamaIntegrator()
        state = {"x": torch.randn(100, 2)}
        drift = lambda x, t: -x  # simple mean-reverting drift
        result = integrator.step(
            state, step_size=0.01, drift=drift, noise_scale=1.0
        )
        ```
    """

    @property
    def tableau_a(self):
        return ((),)

    @property
    def tableau_b(self):
        return (1.0,)

    @property
    def tableau_c(self):
        return (0.0,)

    # -- backward-compat shims for deprecated ``model`` kwarg ----------------

    @staticmethod
    def _resolve_model_to_drift(model, drift):
        """Convert deprecated ``model`` to a ``drift`` callable."""
        if model is not None:
            warnings.warn(
                "Passing 'model' to EulerMaruyamaIntegrator is deprecated. "
                "Use drift=lambda x, t: -model.gradient(x) instead.",
                DeprecationWarning,
                stacklevel=3,
            )
            if drift is None:
                drift = lambda x_, t_: -model.gradient(x_)
        return drift

    def step(
        self,
        state: Dict[str, torch.Tensor],
        step_size: torch.Tensor = None,
        *,
        model=None,
        drift: Optional[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        ] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        drift = self._resolve_model_to_drift(model, drift)
        return super().step(state, step_size, drift=drift, **kwargs)

    def integrate(
        self,
        state: Dict[str, torch.Tensor],
        step_size: torch.Tensor = None,
        n_steps: int = None,
        *,
        model=None,
        drift: Optional[
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
        ] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        drift = self._resolve_model_to_drift(model, drift)
        return super().integrate(state, step_size, n_steps, drift=drift, **kwargs)
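The `_resolve_model_to_drift` shim above maps the deprecated `model` kwarg onto a `drift` callable via `drift = lambda x, t: -model.gradient(x)`. The mapping can be illustrated standalone (`ToyModel` is a hypothetical stand-in exposing the `.gradient(x)` interface the shim expects; `resolve_model_to_drift` mirrors the private method rather than calling it):

```python
import warnings

class ToyModel:
    """Hypothetical model with the .gradient(x) interface the shim expects."""
    def gradient(self, x):
        return 2.0 * x  # e.g. the gradient of the energy ||x||^2

def resolve_model_to_drift(model, drift):
    # mirrors EulerMaruyamaIntegrator._resolve_model_to_drift
    if model is not None:
        warnings.warn(
            "Passing 'model' is deprecated; use drift=lambda x, t: -model.gradient(x).",
            DeprecationWarning,
        )
        if drift is None:
            drift = lambda x_, t_: -model.gradient(x_)
    return drift

# The resolved drift is the negative model gradient: at x = 3.0 it is -6.0.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    drift = resolve_model_to_drift(ToyModel(), None)
value = drift(3.0, 0.0)
```

An explicitly passed `drift` always wins: the shim only synthesizes one when `drift is None`, so migrated call sites are unaffected by also passing `model`.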