
torchebm.losses.contrastive_divergence

Contrastive Divergence Loss Module.

ContrastiveDivergence

Bases: BaseContrastiveDivergence

Standard Contrastive Divergence (CD-k) loss.

CD approximates the log-likelihood gradient by running an MCMC sampler for `k_steps` to generate negative samples.
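
In sketch form (the standard CD-k approximation, stated here for orientation), the intractable model expectation in the log-likelihood gradient is replaced by an expectation over samples obtained after k sampler steps started from the data:

$$
\nabla_\theta \mathcal{L}_{\text{CD-}k} \approx \mathbb{E}_{x \sim p_{\text{data}}}\left[\nabla_\theta E_\theta(x)\right] - \mathbb{E}_{\tilde{x} \sim q_k}\left[\nabla_\theta E_\theta(\tilde{x})\right]
$$

where $q_k$ denotes the distribution of the MCMC chain after `k_steps` transitions.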

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| `model` | The energy-based model to train. | *required* |
| `sampler` | The MCMC sampler for generating negative samples. | *required* |
| `k_steps` | The number of MCMC steps (k in CD-k). | `10` |
| `persistent` | If True, uses Persistent CD with a replay buffer. | `False` |
| `buffer_size` | Size of the replay buffer for PCD. | `10000` |
| `init_steps` | Number of MCMC steps to warm up the buffer. | `100` |
| `new_sample_ratio` | Fraction of new random samples for PCD chains. | `0.05` |
| `energy_reg_weight` | Weight for energy regularization term. | `0.001` |
| `use_temperature_annealing` | Whether to use temperature annealing. | `False` |
| `min_temp` | Minimum temperature for annealing. | `0.01` |
| `max_temp` | Maximum temperature for annealing. | `2.0` |
| `temp_decay` | Decay rate for temperature annealing. | `0.999` |
| `dtype` | Data type for computations. | `float32` |
| `device` | Device for computations. | `device('cpu')` |
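
When `use_temperature_annealing=True`, each training-mode forward pass decays the temperature as $T \leftarrow \max(\texttt{min\_temp},\ T \cdot \texttt{temp\_decay})$, starting from `max_temp`. Samplers that expose a `temperature` attribute receive $T$ directly; samplers that expose a `noise_scale` instead have it rescaled to $\sigma_0 \sqrt{T}$, where $\sigma_0$ is the sampler's original noise scale (see the `forward` source below).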
Example
import torch

from torchebm.losses import ContrastiveDivergence
from torchebm.samplers import LangevinDynamics
from torchebm.core import DoubleWellEnergy

energy = DoubleWellEnergy()
sampler = LangevinDynamics(energy, step_size=0.01)
cd_loss = ContrastiveDivergence(model=energy, sampler=sampler, k_steps=10)
x = torch.randn(32, 2)
loss, neg_samples = cd_loss(x)
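
To enable Persistent CD, pass the buffer options documented above to the same constructor. A minimal sketch; the `buffer_size` and `init_steps` values here are illustrative, not recommendations:

```python
# Sketch: Persistent CD with a replay buffer (values are illustrative).
pcd_loss = ContrastiveDivergence(
    model=energy,
    sampler=sampler,
    k_steps=10,
    persistent=True,    # keep MCMC chains alive across batches
    buffer_size=10000,  # replay buffer of past negative samples
    init_steps=100,     # MCMC steps used to warm up the buffer
)
loss, neg_samples = pcd_loss(x)
```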
Source code in torchebm/losses/contrastive_divergence.py
class ContrastiveDivergence(BaseContrastiveDivergence):
    r"""
    Standard Contrastive Divergence (CD-k) loss.

    CD approximates the log-likelihood gradient by running an MCMC sampler
    for `k_steps` to generate negative samples.

    Args:
        model: The energy-based model to train.
        sampler: The MCMC sampler for generating negative samples.
        k_steps: The number of MCMC steps (k in CD-k).
        persistent: If True, uses Persistent CD with a replay buffer.
        buffer_size: Size of the replay buffer for PCD.
        init_steps: Number of MCMC steps to warm up the buffer.
        new_sample_ratio: Fraction of new random samples for PCD chains.
        energy_reg_weight: Weight for energy regularization term.
        use_temperature_annealing: Whether to use temperature annealing.
        min_temp: Minimum temperature for annealing.
        max_temp: Maximum temperature for annealing.
        temp_decay: Decay rate for temperature annealing.
        dtype: Data type for computations.
        device: Device for computations.

    Example:
        ```python
        import torch

        from torchebm.losses import ContrastiveDivergence
        from torchebm.samplers import LangevinDynamics
        from torchebm.core import DoubleWellEnergy

        energy = DoubleWellEnergy()
        sampler = LangevinDynamics(energy, step_size=0.01)
        cd_loss = ContrastiveDivergence(model=energy, sampler=sampler, k_steps=10)
        x = torch.randn(32, 2)
        loss, neg_samples = cd_loss(x)
        ```
    """

    def __init__(
        self,
        model,
        sampler,
        k_steps=10,
        persistent=False,
        buffer_size=10000,
        init_steps=100,
        new_sample_ratio=0.05,
        energy_reg_weight=0.001,
        use_temperature_annealing=False,
        min_temp=0.01,
        max_temp=2.0,
        temp_decay=0.999,
        dtype=torch.float32,
        device=torch.device("cpu"),
        *args,
        **kwargs,
    ):
        super().__init__(
            model=model,
            sampler=sampler,
            k_steps=k_steps,
            persistent=persistent,
            buffer_size=buffer_size,
            new_sample_ratio=new_sample_ratio,
            init_steps=init_steps,
            dtype=dtype,
            device=device,
            *args,
            **kwargs,
        )
        # Additional parameters for improved stability
        self.energy_reg_weight = energy_reg_weight
        self.use_temperature_annealing = use_temperature_annealing
        self.min_temp = min_temp
        self.max_temp = max_temp
        self.temp_decay = temp_decay
        self.current_temp = max_temp

        # Register temperature as buffer for persistence
        self.register_buffer(
            "temperature", torch.tensor(max_temp, dtype=self.dtype, device=self.device)
        )

    def forward(
        self, x: torch.Tensor, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the Contrastive Divergence loss and generates negative samples.

        Args:
            x (torch.Tensor): A batch of real data samples (positive samples).
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - The scalar CD loss value.
                - The generated negative samples.
        """

        batch_size = x.shape[0]
        data_shape = x.shape[1:]

        # Update temperature if annealing is enabled
        if self.use_temperature_annealing and self.training:
            self.current_temp = max(self.min_temp, self.current_temp * self.temp_decay)
            self.temperature[...] = self.current_temp  # in-place update of the registered buffer

            # If sampler has a temperature parameter, update it
            if hasattr(self.sampler, "temperature"):
                self.sampler.temperature = self.current_temp
            elif hasattr(self.sampler, "noise_scale"):
                # For samplers like Langevin, adjust noise scale based on temperature
                original_noise = getattr(self.sampler, "_original_noise_scale", None)
                if original_noise is None:
                    setattr(
                        self.sampler, "_original_noise_scale", self.sampler.noise_scale
                    )
                    original_noise = self.sampler.noise_scale

                self.sampler.noise_scale = original_noise * math.sqrt(self.current_temp)

        # Get starting points for chains (either from buffer or data)
        start_points = self.get_start_points(x)

        # Run MCMC chains to get negative samples
        pred_samples = self.sampler.sample(
            x=start_points,
            n_steps=self.k_steps,
        )

        # Update persistent buffer if using PCD
        if self.persistent:
            with torch.no_grad():
                self.update_buffer(pred_samples.detach())

        # Add energy regularization to kwargs for compute_loss
        kwargs["energy_reg_weight"] = kwargs.get(
            "energy_reg_weight", self.energy_reg_weight
        )

        # Compute contrastive divergence loss
        loss = self.compute_loss(x, pred_samples, *args, **kwargs)

        return loss, pred_samples

    def compute_loss(
        self, x: torch.Tensor, pred_x: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """
        Computes the Contrastive Divergence loss from positive and negative samples.

        The loss is the difference between the mean energy of positive samples
        and the mean energy of negative samples.

        Args:
            x (torch.Tensor): Real data samples (positive samples).
            pred_x (torch.Tensor): Generated negative samples.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The scalar loss value.
        """
        # Ensure inputs are on the correct device and dtype
        x = x.to(self.device, self.dtype)
        pred_x = pred_x.to(self.device, self.dtype)

        # Compute energy of real and generated samples
        with torch.set_grad_enabled(True):
            # Add small noise to real data for stability (optional)
            if kwargs.get("add_noise_to_real", False):
                noise_scale = kwargs.get("noise_scale", 1e-4)
                x_noisy = x + noise_scale * torch.randn_like(x)
                x_energy = self.model(x_noisy)
            else:
                x_energy = self.model(x)

            pred_x_energy = self.model(pred_x)

        # Compute mean energies with improved numerical stability
        mean_x_energy = torch.mean(x_energy)
        mean_pred_energy = torch.mean(pred_x_energy)

        # Basic contrastive divergence loss: E[data] - E[model]
        loss = mean_x_energy - mean_pred_energy

        # Optional: Regularization to prevent energies from becoming too large
        # This helps with stability especially in the early phases of training
        energy_reg_weight = kwargs.get("energy_reg_weight", 0.001)
        if energy_reg_weight > 0:
            energy_reg = energy_reg_weight * (
                torch.mean(x_energy**2) + torch.mean(pred_x_energy**2)
            )
            loss = loss + energy_reg

        # Prevent extremely large gradients with a safety check
        if torch.isnan(loss) or torch.isinf(loss):
            warnings.warn(
                f"NaN or Inf detected in CD loss. x_energy: {mean_x_energy}, pred_energy: {mean_pred_energy}",
                RuntimeWarning,
            )
            # Return a small positive constant instead of NaN/Inf to prevent training collapse
            return torch.tensor(0.1, device=self.device, dtype=self.dtype)

        return loss

compute_loss(x, pred_x, *args, **kwargs)

Computes the Contrastive Divergence loss from positive and negative samples.

The loss is the difference between the mean energy of positive samples and the mean energy of negative samples.
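
Concretely, with $N$ positive samples $x_i$ and $M$ negative samples $\tilde{x}_j$, the base loss computed in the source below is

$$
\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N} E_\theta(x_i) - \frac{1}{M}\sum_{j=1}^{M} E_\theta(\tilde{x}_j),
$$

optionally plus the regularizer $\lambda \left( \overline{E_\theta(x)^2} + \overline{E_\theta(\tilde{x})^2} \right)$, where $\lambda$ is `energy_reg_weight` and the bars denote batch means.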

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | Real data samples (positive samples). | *required* |
| `pred_x` | `Tensor` | Generated negative samples. | *required* |
| `*args` | | Additional positional arguments. | `()` |
| `**kwargs` | | Additional keyword arguments. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The scalar loss value. |

Source code in torchebm/losses/contrastive_divergence.py
def compute_loss(
    self, x: torch.Tensor, pred_x: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
    """
    Computes the Contrastive Divergence loss from positive and negative samples.

    The loss is the difference between the mean energy of positive samples
    and the mean energy of negative samples.

    Args:
        x (torch.Tensor): Real data samples (positive samples).
        pred_x (torch.Tensor): Generated negative samples.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        torch.Tensor: The scalar loss value.
    """
    # Ensure inputs are on the correct device and dtype
    x = x.to(self.device, self.dtype)
    pred_x = pred_x.to(self.device, self.dtype)

    # Compute energy of real and generated samples
    with torch.set_grad_enabled(True):
        # Add small noise to real data for stability (optional)
        if kwargs.get("add_noise_to_real", False):
            noise_scale = kwargs.get("noise_scale", 1e-4)
            x_noisy = x + noise_scale * torch.randn_like(x)
            x_energy = self.model(x_noisy)
        else:
            x_energy = self.model(x)

        pred_x_energy = self.model(pred_x)

    # Compute mean energies with improved numerical stability
    mean_x_energy = torch.mean(x_energy)
    mean_pred_energy = torch.mean(pred_x_energy)

    # Basic contrastive divergence loss: E[data] - E[model]
    loss = mean_x_energy - mean_pred_energy

    # Optional: Regularization to prevent energies from becoming too large
    # This helps with stability especially in the early phases of training
    energy_reg_weight = kwargs.get("energy_reg_weight", 0.001)
    if energy_reg_weight > 0:
        energy_reg = energy_reg_weight * (
            torch.mean(x_energy**2) + torch.mean(pred_x_energy**2)
        )
        loss = loss + energy_reg

    # Prevent extremely large gradients with a safety check
    if torch.isnan(loss) or torch.isinf(loss):
        warnings.warn(
            f"NaN or Inf detected in CD loss. x_energy: {mean_x_energy}, pred_energy: {mean_pred_energy}",
            RuntimeWarning,
        )
        # Return a small positive constant instead of NaN/Inf to prevent training collapse
        return torch.tensor(0.1, device=self.device, dtype=self.dtype)

    return loss

forward(x, *args, **kwargs)

Computes the Contrastive Divergence loss and generates negative samples.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | A batch of real data samples (positive samples). | *required* |
| `*args` | | Additional positional arguments. | `()` |
| `**kwargs` | | Additional keyword arguments. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Tuple[Tensor, Tensor]` | The scalar CD loss value and the generated negative samples. |
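
A minimal training-step sketch. The names `energy_net` and `data_loader` are hypothetical stand-ins (the analytic `DoubleWellEnergy` used in the class example may expose no learnable parameters), and the optimizer and learning rate are illustrative choices:

```python
import torch
from torchebm.losses import ContrastiveDivergence
from torchebm.samplers import LangevinDynamics

# `energy_net` is a hypothetical trainable nn.Module mapping (B, D) inputs to
# (B,) energies; `data_loader` is a hypothetical iterable of (B, D) batches.
sampler = LangevinDynamics(energy_net, step_size=0.01)
cd_loss = ContrastiveDivergence(model=energy_net, sampler=sampler, k_steps=10)
optimizer = torch.optim.Adam(energy_net.parameters(), lr=1e-3)  # illustrative

for batch in data_loader:
    optimizer.zero_grad()
    loss, neg_samples = cd_loss(batch)  # runs k_steps of MCMC internally
    loss.backward()                     # gradients flow through the energies
    optimizer.step()
```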

Source code in torchebm/losses/contrastive_divergence.py
def forward(
    self, x: torch.Tensor, *args, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes the Contrastive Divergence loss and generates negative samples.

    Args:
        x (torch.Tensor): A batch of real data samples (positive samples).
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - The scalar CD loss value.
            - The generated negative samples.
    """

    batch_size = x.shape[0]
    data_shape = x.shape[1:]

    # Update temperature if annealing is enabled
    if self.use_temperature_annealing and self.training:
        self.current_temp = max(self.min_temp, self.current_temp * self.temp_decay)
        self.temperature[...] = self.current_temp  # in-place update of the registered buffer

        # If sampler has a temperature parameter, update it
        if hasattr(self.sampler, "temperature"):
            self.sampler.temperature = self.current_temp
        elif hasattr(self.sampler, "noise_scale"):
            # For samplers like Langevin, adjust noise scale based on temperature
            original_noise = getattr(self.sampler, "_original_noise_scale", None)
            if original_noise is None:
                setattr(
                    self.sampler, "_original_noise_scale", self.sampler.noise_scale
                )
                original_noise = self.sampler.noise_scale

            self.sampler.noise_scale = original_noise * math.sqrt(self.current_temp)

    # Get starting points for chains (either from buffer or data)
    start_points = self.get_start_points(x)

    # Run MCMC chains to get negative samples
    pred_samples = self.sampler.sample(
        x=start_points,
        n_steps=self.k_steps,
    )

    # Update persistent buffer if using PCD
    if self.persistent:
        with torch.no_grad():
            self.update_buffer(pred_samples.detach())

    # Add energy regularization to kwargs for compute_loss
    kwargs["energy_reg_weight"] = kwargs.get(
        "energy_reg_weight", self.energy_reg_weight
    )

    # Compute contrastive divergence loss
    loss = self.compute_loss(x, pred_samples, *args, **kwargs)

    return loss, pred_samples