Skip to content

torchebm.utils

Utility functions for TorchEBM.

center_crop_arr(pil_image, image_size)

Center crop and resize image to target size.

Parameters:

Name Type Description Default
pil_image Image

PIL image to crop.

required
image_size int

Target size for square crop.

required

Returns:

Type Description
Image

Center-cropped PIL image.

Source code in torchebm/utils/image.py
def center_crop_arr(pil_image: Image.Image, image_size: int) -> Image.Image:
    r"""Center crop and resize image to target size.

    The image is first shrunk by repeated 2x box-filter halving while its
    shorter side is at least ``2 * image_size``. It is then bicubic-resized so
    that its shorter side is (approximately, after rounding) ``image_size``.
    Finally the central ``image_size x image_size`` window is cut out.

    Args:
        pil_image: PIL image to crop.
        image_size: Target size for square crop.

    Returns:
        Center-cropped PIL image.
    """
    img = pil_image

    # Cheap, alias-resistant downscaling in powers of two before the final resize.
    while min(img.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in img.size)
        img = img.resize(halved, resample=Image.BOX)

    # Scale so the shorter side lands on image_size.
    factor = image_size / min(img.size)
    target = tuple(round(side * factor) for side in img.size)
    img = img.resize(target, resample=Image.BICUBIC)

    # numpy arrays are (height, width, ...), so rows are the vertical axis.
    pixels = np.array(img)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top : top + image_size, left : left + image_size]
    return Image.fromarray(window)

create_npz_from_sample_folder(sample_dir, num=50000)

Build .npz file from folder of PNG samples.

Parameters:

Name Type Description Default
sample_dir str

Directory containing numbered PNG files.

required
num int

Number of samples to include.

50000

Returns:

Type Description
str

Path to created .npz file.

Source code in torchebm/utils/image.py
def create_npz_from_sample_folder(sample_dir: str, num: int = 50000) -> str:
    r"""Build .npz file from folder of PNG samples.

    Reads ``{sample_dir}/000000.png`` through ``{sample_dir}/{num - 1:06d}.png``
    in order. Each image is converted to a ``uint8`` array, the arrays are
    stacked along a new leading axis, and the result is saved under the key
    ``arr_0`` to ``{sample_dir}.npz``. That file sits next to the folder, not
    inside it.

    Args:
        sample_dir: Directory containing numbered PNG files (``000000.png``, ...).
            Trailing path separators are ignored when naming the output file.
        num: Number of samples to include; must be positive.

    Returns:
        Path to created .npz file.

    Raises:
        ValueError: If ``num`` is not positive, or if the images do not all
            have the same shape (raised by ``np.stack``).
        FileNotFoundError: If one of the expected PNG files is missing.
    """
    import os

    if num <= 0:
        raise ValueError(f"num must be positive, got {num}")

    arrays = []
    for i in range(num):
        # The context manager releases the file handle once the pixels are read.
        with Image.open(f"{sample_dir}/{i:06d}.png") as sample_pil:
            arrays.append(np.asarray(sample_pil).astype(np.uint8))

    samples = np.stack(arrays)

    # Without stripping, "samples/" would produce a hidden "samples/.npz"
    # inside the directory instead of "samples.npz" beside it.
    separators = "/" + os.sep + (os.altsep or "")
    base = sample_dir.rstrip(separators) or sample_dir
    npz_path = f"{base}.npz"
    np.savez(npz_path, arr_0=samples)
    print(f"Saved .npz file to {npz_path} [shape={samples.shape}]")
    return npz_path

load_checkpoint(checkpoint_path, model, ema_model=None, optimizer=None, device=None)

Load training checkpoint.

Parameters:

Name Type Description Default
checkpoint_path str

Path to checkpoint file.

required
model Module

Model to load weights into.

required
ema_model Optional[Module]

EMA model to load (optional).

None
optimizer Optional[Optimizer]

Optimizer to load state (optional).

None
device Optional[device]

Device to map tensors to.

None

Returns:

Type Description
Dict[str, Any]

Dictionary with checkpoint contents.

Source code in torchebm/utils/training.py
def load_checkpoint(
    checkpoint_path: str,
    model: nn.Module,
    ema_model: Optional[nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    device: Optional[torch.device] = None,
) -> Dict[str, Any]:
    r"""Load training checkpoint.

    Restores whichever of the ``"model"``, ``"ema"`` and ``"opt"`` entries are
    present in the file into the corresponding objects. Missing entries, and
    objects passed as ``None``, are skipped.

    Args:
        checkpoint_path: Path to checkpoint file.
        model: Model to load weights into. If it exposes a ``.module``
            attribute (e.g. a DataParallel/DDP wrapper), the weights are loaded
            into that inner module instead.
        ema_model: EMA model to load (optional).
        optimizer: Optimizer to load state (optional).
        device: Device to map tensors to (passed as ``map_location``).

    Returns:
        Dictionary with checkpoint contents (the full loaded object).
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=device)

    if "model" in state:
        # save_checkpoint stores the unwrapped module's state_dict, so mirror that here.
        target = model.module if hasattr(model, "module") else model
        target.load_state_dict(state["model"])

    # Optional components: restored only when both the object and its entry exist.
    for component, key in ((ema_model, "ema"), (optimizer, "opt")):
        if component is not None and key in state:
            component.load_state_dict(state[key])

    return state

plot_2d_energy_landscape(model, x_range=(-5, 5), y_range=(-5, 5), resolution=100, log_scale=False, cmap='viridis', title=None, show_colorbar=True, save_path=None, fig_size=(8, 6), contour=True, contour_levels=20, device=None)

Plots a 2D energy landscape of a model.

Parameters:

Name Type Description Default
model BaseModel

The model to visualize.

required
x_range Tuple[float, float]

The range for the x-axis.

(-5, 5)
y_range Tuple[float, float]

The range for the y-axis.

(-5, 5)
resolution int

The number of points in each dimension.

100
log_scale bool

Whether to use a log scale for the energy values.

False
cmap str

The colormap to use.

'viridis'
title Optional[str]

The title of the plot.

None
show_colorbar bool

Whether to show a colorbar.

True
save_path Optional[str]

The path to save the figure.

None
fig_size Tuple[int, int]

The size of the figure.

(8, 6)
contour bool

Whether to overlay contour lines.

True
contour_levels int

The number of contour levels.

20
device Optional[str]

The device to use for computation.

None

Returns:

Type Description
Figure

The matplotlib figure object.

Source code in torchebm/utils/visualization.py
def plot_2d_energy_landscape(
    model: BaseModel,
    x_range: Tuple[float, float] = (-5, 5),
    y_range: Tuple[float, float] = (-5, 5),
    resolution: int = 100,
    log_scale: bool = False,
    cmap: str = "viridis",
    title: Optional[str] = None,
    show_colorbar: bool = True,
    save_path: Optional[str] = None,
    fig_size: Tuple[int, int] = (8, 6),
    contour: bool = True,
    contour_levels: int = 20,
    device: Optional[str] = None,
) -> plt.Figure:
    """
    Plots a 2D energy landscape of a model.

    The model is evaluated (without gradients) on a ``resolution x resolution``
    grid spanning ``x_range`` x ``y_range``. The resulting energies are drawn
    as a color mesh, with optional contour lines and colorbar.

    Args:
        model (BaseModel): The model to visualize. It is called on a float32
            tensor of shape ``(resolution**2, 2)``; its output must contain
            ``resolution**2`` values.
        x_range (Tuple[float, float]): The range for the x-axis.
        y_range (Tuple[float, float]): The range for the y-axis.
        resolution (int): The number of points in each dimension.
        log_scale (bool): Whether to use a log scale for the energy values.
            Uses ``log(E + 1e-10)``, so non-positive energies become NaN/-inf.
        cmap (str): The colormap to use.
        title (Optional[str]): The title of the plot.
        show_colorbar (bool): Whether to show a colorbar.
        save_path (Optional[str]): The path to save the figure.
        fig_size (Tuple[int, int]): The size of the figure.
        contour (bool): Whether to overlay contour lines.
        contour_levels (int): The number of contour levels.
        device (Optional[str]): The device to use for computation. If given,
            ``model.to(device)`` is called, which moves an ``nn.Module`` in place.

    Returns:
        plt.Figure: The matplotlib figure object.
    """
    # Build the evaluation grid.
    x = np.linspace(x_range[0], x_range[1], resolution)
    y = np.linspace(y_range[0], y_range[1], resolution)
    X, Y = np.meshgrid(x, y)

    # Flatten the grid into (N, 2) points for the model.
    grid = torch.tensor(
        np.stack([X.flatten(), Y.flatten()], axis=1), dtype=torch.float32
    )
    if device is not None:
        grid = grid.to(device)
        model = model.to(device)

    with torch.no_grad():
        Z = model(grid).cpu().numpy()
    Z = Z.reshape(X.shape)

    if log_scale:
        # Small constant avoids log(0); negative energies still give NaN.
        Z = np.log(Z + 1e-10)

    fig, ax = plt.subplots(figsize=fig_size)
    im = ax.pcolormesh(X, Y, Z, cmap=cmap, shading="auto")

    if contour:
        ax.contour(
            X, Y, Z, levels=contour_levels, colors="white", alpha=0.5, linewidths=0.5
        )

    if show_colorbar:
        fig.colorbar(im, ax=ax, label="Energy" if not log_scale else "Log Energy")

    if title:
        ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    # Save this specific figure rather than whatever pyplot considers current.
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig

plot_3d_energy_landscape(model, x_range=(-5, 5), y_range=(-5, 5), resolution=50, log_scale=False, cmap='viridis', title=None, show_colorbar=True, save_path=None, fig_size=(10, 8), alpha=0.9, elev=30, azim=-45, device=None)

Plots a 3D surface visualization of a 2D energy landscape.

Parameters:

Name Type Description Default
model BaseModel

The model to visualize.

required
x_range Tuple[float, float]

The range for the x-axis.

(-5, 5)
y_range Tuple[float, float]

The range for the y-axis.

(-5, 5)
resolution int

The number of points in each dimension.

50
log_scale bool

Whether to use a log scale for the energy values.

False
cmap str

The colormap to use.

'viridis'
title Optional[str]

The title of the plot.

None
show_colorbar bool

Whether to show a colorbar.

True
save_path Optional[str]

The path to save the figure.

None
fig_size Tuple[int, int]

The size of the figure.

(10, 8)
alpha float

The transparency of the surface.

0.9
elev float

The elevation angle for the 3D view.

30
azim float

The azimuth angle for the 3D view.

-45
device Optional[str]

The device to use for computation.

None

Returns:

Type Description
Figure

The matplotlib figure object.

Source code in torchebm/utils/visualization.py
def plot_3d_energy_landscape(
    model: BaseModel,
    x_range: Tuple[float, float] = (-5, 5),
    y_range: Tuple[float, float] = (-5, 5),
    resolution: int = 50,
    log_scale: bool = False,
    cmap: str = "viridis",
    title: Optional[str] = None,
    show_colorbar: bool = True,
    save_path: Optional[str] = None,
    fig_size: Tuple[int, int] = (10, 8),
    alpha: float = 0.9,
    elev: float = 30,
    azim: float = -45,
    device: Optional[str] = None,
) -> plt.Figure:
    """
    Plots a 3D surface visualization of a 2D energy landscape.

    The model is evaluated (without gradients) on a ``resolution x resolution``
    grid spanning ``x_range`` x ``y_range``. The energies are drawn as a 3D
    surface with the z-axis showing energy (or log energy).

    Args:
        model (BaseModel): The model to visualize. It is called on a float32
            tensor of shape ``(resolution**2, 2)``; its output must contain
            ``resolution**2`` values.
        x_range (Tuple[float, float]): The range for the x-axis.
        y_range (Tuple[float, float]): The range for the y-axis.
        resolution (int): The number of points in each dimension.
        log_scale (bool): Whether to use a log scale for the energy values.
            Uses ``log(E + 1e-10)``, so non-positive energies become NaN/-inf.
        cmap (str): The colormap to use.
        title (Optional[str]): The title of the plot.
        show_colorbar (bool): Whether to show a colorbar.
        save_path (Optional[str]): The path to save the figure.
        fig_size (Tuple[int, int]): The size of the figure.
        alpha (float): The transparency of the surface.
        elev (float): The elevation angle for the 3D view.
        azim (float): The azimuth angle for the 3D view.
        device (Optional[str]): The device to use for computation. If given,
            ``model.to(device)`` is called, which moves an ``nn.Module`` in place.

    Returns:
        plt.Figure: The matplotlib figure object.
    """
    # Build the evaluation grid.
    x = np.linspace(x_range[0], x_range[1], resolution)
    y = np.linspace(y_range[0], y_range[1], resolution)
    X, Y = np.meshgrid(x, y)

    # Flatten the grid into (N, 2) points for the model.
    grid = torch.tensor(
        np.stack([X.flatten(), Y.flatten()], axis=1), dtype=torch.float32
    )
    if device is not None:
        grid = grid.to(device)
        model = model.to(device)

    with torch.no_grad():
        Z = model(grid).cpu().numpy()
    Z = Z.reshape(X.shape)

    if log_scale:
        # Small constant avoids log(0); negative energies still give NaN.
        Z = np.log(Z + 1e-10)

    energy_label = "Energy" if not log_scale else "Log Energy"

    fig = plt.figure(figsize=fig_size)
    ax = fig.add_subplot(111, projection="3d")

    surf = ax.plot_surface(
        X, Y, Z, cmap=cmap, alpha=alpha, linewidth=0, antialiased=True
    )

    if show_colorbar:
        fig.colorbar(
            surf,
            ax=ax,
            shrink=0.5,
            aspect=5,
            label=energy_label,
        )

    if title:
        ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel(energy_label)

    ax.view_init(elev=elev, azim=azim)

    # Save this specific figure rather than whatever pyplot considers current.
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig

plot_sample_trajectories(trajectories, model=None, x_range=None, y_range=None, resolution=100, log_scale=False, cmap='viridis', title=None, show_colorbar=True, save_path=None, fig_size=(8, 6), trajectory_colors=None, trajectory_alpha=0.7, line_width=1.0, device=None)

Plots sample trajectories, optionally on an energy landscape background.

Parameters:

Name Type Description Default
trajectories Tensor

A tensor of trajectories of shape (n_chains, n_steps, 2).

required
model Optional[BaseModel]

The model to visualize as a background.

None
x_range Optional[Tuple[float, float]]

The range for the x-axis. If None, it is inferred from the data.

None
y_range Optional[Tuple[float, float]]

The range for the y-axis. If None, it is inferred from the data.

None
resolution int

The number of points in each dimension for the energy grid.

100
log_scale bool

Whether to use a log scale for the energy values.

False
cmap str

The colormap to use for the energy background.

'viridis'
title Optional[str]

The title of the plot.

None
show_colorbar bool

Whether to show a colorbar.

True
save_path Optional[str]

The path to save the figure.

None
fig_size Tuple[int, int]

The size of the figure.

(8, 6)
trajectory_colors Optional[List[str]]

A list of colors for the trajectories.

None
trajectory_alpha float

The transparency of the trajectory lines.

0.7
line_width float

The width of the trajectory lines.

1.0
device Optional[str]

The device to use for computation.

None

Returns:

Type Description
Figure

The matplotlib figure object.

Source code in torchebm/utils/visualization.py
def plot_sample_trajectories(
    trajectories: torch.Tensor,
    model: Optional[BaseModel] = None,
    x_range: Optional[Tuple[float, float]] = None,
    y_range: Optional[Tuple[float, float]] = None,
    resolution: int = 100,
    log_scale: bool = False,
    cmap: str = "viridis",
    title: Optional[str] = None,
    show_colorbar: bool = True,
    save_path: Optional[str] = None,
    fig_size: Tuple[int, int] = (8, 6),
    trajectory_colors: Optional[List[str]] = None,
    trajectory_alpha: float = 0.7,
    line_width: float = 1.0,
    device: Optional[str] = None,
) -> plt.Figure:
    """
    Plots sample trajectories, optionally on an energy landscape background.

    Each chain is drawn as a line. Its first point is marked with ``o`` and
    its last point with ``x``. The legend only lists the first chain's
    start/end markers to avoid clutter.

    Args:
        trajectories (torch.Tensor): A tensor of trajectories of shape `(n_chains, n_steps, 2)`.
        model (Optional[BaseModel]): The model to visualize as a background.
        x_range (Optional[Tuple[float, float]]): The range for the x-axis. If `None`, it is
            inferred from the data with 10% padding, or a unit margin if the data
            has zero extent along x.
        y_range (Optional[Tuple[float, float]]): The range for the y-axis. If `None`, it is
            inferred from the data with 10% padding, or a unit margin if the data
            has zero extent along y.
        resolution (int): The number of points in each dimension for the energy grid.
        log_scale (bool): Whether to use a log scale for the energy values.
        cmap (str): The colormap to use for the energy background.
        title (Optional[str]): The title of the plot.
        show_colorbar (bool): Whether to show a colorbar.
        save_path (Optional[str]): The path to save the figure.
        fig_size (Tuple[int, int]): The size of the figure.
        trajectory_colors (Optional[List[str]]): A list of colors for the trajectories.
            Chains beyond the length of this list are drawn in gray.
        trajectory_alpha (float): The transparency of the trajectory lines.
        line_width (float): The width of the trajectory lines.
        device (Optional[str]): The device to use for computation.

    Returns:
        plt.Figure: The matplotlib figure object.
    """
    # Convert once; used both for range inference and for drawing.
    trajectories_np = trajectories.detach().cpu().numpy()

    # Determine plotting ranges if not provided.
    if x_range is None or y_range is None:
        all_data = trajectories_np.reshape(-1, 2)
        data_min = all_data.min(axis=0)
        data_max = all_data.max(axis=0)
        span = data_max - data_min
        # 10% padding; a zero-extent axis (e.g. a single point) would otherwise
        # collapse to identical limits, so fall back to a unit margin there.
        padding = np.where(span > 0, span * 0.1, 1.0)

        if x_range is None:
            x_range = (data_min[0] - padding[0], data_max[0] + padding[0])
        if y_range is None:
            y_range = (data_min[1] - padding[1], data_max[1] + padding[1])

    fig, ax = plt.subplots(figsize=fig_size)

    # Plot energy landscape if provided.
    if model is not None:
        x = np.linspace(x_range[0], x_range[1], resolution)
        y = np.linspace(y_range[0], y_range[1], resolution)
        X, Y = np.meshgrid(x, y)

        grid = torch.tensor(
            np.stack([X.flatten(), Y.flatten()], axis=1), dtype=torch.float32
        )
        if device is not None:
            grid = grid.to(device)
            model = model.to(device)

        with torch.no_grad():
            Z = model(grid).cpu().numpy()
        Z = Z.reshape(X.shape)

        if log_scale:
            # Small constant avoids log(0); negative energies still give NaN.
            Z = np.log(Z + 1e-10)

        im = ax.pcolormesh(X, Y, Z, cmap=cmap, shading="auto")

        if show_colorbar:
            fig.colorbar(im, ax=ax, label="Energy" if not log_scale else "Log Energy")

    # Plot trajectories.
    n_chains = trajectories_np.shape[0]
    if trajectory_colors is None:
        trajectory_colors = plt.cm.tab10(np.linspace(0, 1, n_chains))

    for i, trajectory in enumerate(trajectories_np):
        color = trajectory_colors[i] if i < len(trajectory_colors) else "gray"
        ax.plot(
            trajectory[:, 0],
            trajectory[:, 1],
            color=color,
            alpha=trajectory_alpha,
            linewidth=line_width,
        )
        ax.scatter(
            trajectory[0, 0],
            trajectory[0, 1],
            color=color,
            marker="o",
            s=30,
            label=f"Start {i+1}",
        )
        ax.scatter(
            trajectory[-1, 0],
            trajectory[-1, 1],
            color=color,
            marker="x",
            s=50,
            label=f"End {i+1}",
        )

    if title:
        ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    ax.set_xlim(x_range)
    ax.set_ylim(y_range)

    # Lines carry no label, so the first two labelled handles are chain 1's
    # start and end markers.
    handles, labels = ax.get_legend_handles_labels()
    if len(handles) > 0:
        ax.legend(handles[:2], labels[:2], loc="best")

    # Save this specific figure rather than whatever pyplot considers current.
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig

plot_samples_on_energy(model, samples, x_range=(-5, 5), y_range=(-5, 5), resolution=100, log_scale=False, cmap='viridis', title=None, show_colorbar=True, save_path=None, fig_size=(8, 6), contour=True, contour_levels=20, sample_color='red', sample_alpha=0.5, sample_size=5, device=None)

Plots samples on a 2D energy landscape.

Parameters:

Name Type Description Default
model BaseModel

The model to visualize.

required
samples Tensor

A tensor of samples of shape (n_samples, 2).

required
x_range Tuple[float, float]

The range for the x-axis.

(-5, 5)
y_range Tuple[float, float]

The range for the y-axis.

(-5, 5)
resolution int

The number of points in each dimension.

100
log_scale bool

Whether to use a log scale for the energy values.

False
cmap str

The colormap to use.

'viridis'
title Optional[str]

The title of the plot.

None
show_colorbar bool

Whether to show a colorbar.

True
save_path Optional[str]

The path to save the figure.

None
fig_size Tuple[int, int]

The size of the figure.

(8, 6)
contour bool

Whether to overlay contour lines.

True
contour_levels int

The number of contour levels.

20
sample_color str

The color of the samples.

'red'
sample_alpha float

The transparency of the samples.

0.5
sample_size float

The size of the sample markers.

5
device Optional[str]

The device to use for computation.

None

Returns:

Type Description
Figure

The matplotlib figure object.

Source code in torchebm/utils/visualization.py
def plot_samples_on_energy(
    model: BaseModel,
    samples: torch.Tensor,
    x_range: Tuple[float, float] = (-5, 5),
    y_range: Tuple[float, float] = (-5, 5),
    resolution: int = 100,
    log_scale: bool = False,
    cmap: str = "viridis",
    title: Optional[str] = None,
    show_colorbar: bool = True,
    save_path: Optional[str] = None,
    fig_size: Tuple[int, int] = (8, 6),
    contour: bool = True,
    contour_levels: int = 20,
    sample_color: str = "red",
    sample_alpha: float = 0.5,
    sample_size: float = 5,
    device: Optional[str] = None,
) -> plt.Figure:
    """
    Plots samples on a 2D energy landscape.

    The background is drawn by ``plot_2d_energy_landscape``, and the samples
    are then scattered on top of its main axes.

    Args:
        model (BaseModel): The model to visualize.
        samples (torch.Tensor): A tensor of samples of shape `(n_samples, 2)`.
        x_range (Tuple[float, float]): The range for the x-axis.
        y_range (Tuple[float, float]): The range for the y-axis.
        resolution (int): The number of points in each dimension.
        log_scale (bool): Whether to use a log scale for the energy values.
        cmap (str): The colormap to use.
        title (Optional[str]): The title of the plot.
        show_colorbar (bool): Whether to show a colorbar.
        save_path (Optional[str]): The path to save the figure.
        fig_size (Tuple[int, int]): The size of the figure.
        contour (bool): Whether to overlay contour lines.
        contour_levels (int): The number of contour levels.
        sample_color (str): The color of the samples.
        sample_alpha (float): The transparency of the samples.
        sample_size (float): The size of the sample markers.
        device (Optional[str]): The device to use for computation.

    Returns:
        plt.Figure: The matplotlib figure object.
    """
    # save_path is deliberately not forwarded: the figure is saved only after
    # the samples have been drawn.
    fig = plot_2d_energy_landscape(
        model=model,
        x_range=x_range,
        y_range=y_range,
        resolution=resolution,
        log_scale=log_scale,
        cmap=cmap,
        title=title,
        show_colorbar=show_colorbar,
        fig_size=fig_size,
        contour=contour,
        contour_levels=contour_levels,
        device=device,
    )

    # The landscape axes is the first one added to the figure (any colorbar
    # axes comes after it). Indexing the figure avoids relying on pyplot's
    # global "current axes".
    ax = fig.axes[0]

    samples_np = samples.detach().cpu().numpy()
    ax.scatter(
        samples_np[:, 0],
        samples_np[:, 1],
        color=sample_color,
        alpha=sample_alpha,
        s=sample_size,
    )

    # Save this specific figure rather than whatever pyplot considers current.
    if save_path is not None:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")

    return fig

requires_grad(model, flag=True)

Set requires_grad flag for all model parameters.

Parameters:

Name Type Description Default
model Module

Model to modify.

required
flag bool

Whether parameters require gradients.

True
Source code in torchebm/utils/training.py
def requires_grad(model: nn.Module, flag: bool = True) -> None:
    r"""Set requires_grad flag for all model parameters.

    Commonly used to freeze (``flag=False``) or unfreeze (``flag=True``) a
    model, e.g. an EMA copy that should never receive gradients.

    Args:
        model: Model to modify; every tensor yielded by ``model.parameters()``
            is updated in place.
        flag: Whether parameters require gradients.
    """
    for _, weight in model.named_parameters():
        weight.requires_grad = flag

save_checkpoint(model, optimizer, step, checkpoint_dir, ema_model=None, args=None)

Save training checkpoint.

Parameters:

Name Type Description Default
model Module

Model to save.

required
optimizer Optimizer

Optimizer state.

required
step int

Current training step.

required
checkpoint_dir str

Directory for checkpoints.

required
ema_model Optional[Module]

EMA model (optional).

None
args Optional[Dict[str, Any]]

Additional arguments to save.

None

Returns:

Type Description
str

Path to saved checkpoint.

Source code in torchebm/utils/training.py
def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    step: int,
    checkpoint_dir: str,
    ema_model: Optional[nn.Module] = None,
    args: Optional[Dict[str, Any]] = None,
) -> str:
    r"""Save training checkpoint.

    Writes a dict with keys ``"model"``, ``"opt"``, ``"step"`` and, when
    provided, ``"ema"`` and ``"args"`` to ``{checkpoint_dir}/{step:07d}.pt``.
    The directory (and parents) is created if needed.

    Args:
        model: Model to save. If it exposes a ``.module`` attribute (e.g. a
            DataParallel/DDP wrapper), the inner module's weights are stored.
        optimizer: Optimizer state.
        step: Current training step; zero-padded to 7 digits in the filename.
        checkpoint_dir: Directory for checkpoints.
        ema_model: EMA model (optional).
        args: Additional arguments to save.

    Returns:
        Path to saved checkpoint.
    """
    # Store the unwrapped network so the file loads into a plain module.
    network = model.module if hasattr(model, "module") else model

    payload: Dict[str, Any] = {
        "model": network.state_dict(),
        "opt": optimizer.state_dict(),
        "step": step,
    }

    extras = {
        "ema": None if ema_model is None else ema_model.state_dict(),
        "args": args,
    }
    payload.update({key: value for key, value in extras.items() if value is not None})

    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
    destination = f"{checkpoint_dir}/{step:07d}.pt"
    torch.save(payload, destination)
    return destination

update_ema(ema_model, model, decay=0.9999)

Update EMA model parameters.

Parameters:

Name Type Description Default
ema_model Module

Exponential moving average model.

required
model Module

Current model.

required
decay float

EMA decay rate.

0.9999
Source code in torchebm/utils/training.py
@torch.no_grad()
def update_ema(ema_model: nn.Module, model: nn.Module, decay: float = 0.9999) -> None:
    r"""Update EMA model parameters.

    Each EMA parameter is updated in place as
    ``ema = decay * ema + (1 - decay) * param``. Parameters are matched by
    name; parameters of ``model`` with no same-named EMA parameter are skipped.
    Only parameters are averaged, not buffers. If an EMA parameter lives on a
    different device than its source, it is first moved to the source's device.

    Models wrapped in ``nn.DataParallel`` or
    ``nn.parallel.DistributedDataParallel`` are unwrapped on either side, so
    the wrapper's ``module.`` name prefix does not prevent matching. This
    mirrors how ``save_checkpoint``/``load_checkpoint`` unwrap ``.module``.

    Args:
        ema_model: Exponential moving average model.
        model: Current model.
        decay: EMA decay rate.
    """
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    # A wrapped model's parameter names are prefixed with "module.". Without
    # unwrapping, no name would match and the update would silently be a no-op.
    if isinstance(model, wrappers):
        model = model.module
    if isinstance(ema_model, wrappers):
        ema_model = ema_model.module

    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        ema_param = ema_params.get(name)
        if ema_param is None:
            continue
        if ema_param.device != param.device:
            ema_param.data = ema_param.data.to(param.device)
        ema_param.mul_(decay).add_(param.data, alpha=1 - decay)