
ContrastiveDivergence

Bases: BaseContrastiveDivergence

Implementation of the standard Contrastive Divergence (CD-k) algorithm.

Contrastive Divergence approximates the gradient of the log-likelihood by comparing the energy between real data samples and samples generated after k steps of MCMC initialized from the data samples (or from noise in persistent mode).

The CD loss is defined as:

\[\mathcal{L}_{CD} = \mathbb{E}_{p_{data}}[E_\theta(x)] - \mathbb{E}_{p_k}[E_\theta(x')]\]

where:

  • \(E_\theta(x)\) is the energy function with parameters \(\theta\)
  • \(p_{data}\) is the data distribution
  • \(p_k\) is the distribution after \(k\) steps of MCMC
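
In code, this objective reduces to a difference of mean energies between data samples and MCMC negatives. A minimal sketch (not the library implementation; `energy_fn` is assumed to map a batch of samples to per-sample energies):

```python
import torch

def cd_loss_sketch(energy_fn, x_data, x_neg):
    # Mean energy of real data minus mean energy of MCMC negatives
    e_data = energy_fn(x_data)  # shape: (batch_size,)
    e_neg = energy_fn(x_neg)    # shape: (batch_size,)
    return e_data.mean() - e_neg.mean()
```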

Algorithm Overview

  1. For non-persistent CD:
     a. Start MCMC chains from real data samples
     b. Run MCMC for k steps to generate negative samples
     c. Compute the gradient by comparing real and negative samples

  2. For persistent CD:
     a. Maintain a set of persistent chains between updates
     b. Continue chains from their previous state for k steps
     c. Update the persistent state for the next iteration

Both variants are sketched in the snippet below.
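
This is a simplified sketch, not the library code; it assumes a sampler exposing the `sample(x=..., n_steps=...)` interface used in the usage example further down:

```python
import torch

def cd_k_step(energy_fn, sampler, x_data, k_steps=10, persistent_state=None):
    # Choose chain starting points
    if persistent_state is None:
        start = x_data.detach()        # standard CD: chains start at the data
    else:
        start = persistent_state       # persistent CD: continue previous chains

    # Run k MCMC steps to obtain negative samples
    x_neg = sampler.sample(x=start, n_steps=k_steps).detach()

    # CD loss: mean data energy minus mean negative-sample energy
    loss = energy_fn(x_data).mean() - energy_fn(x_neg).mean()

    # x_neg becomes the persistent state for the next update (PCD only)
    return loss, x_neg
```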

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `energy_function` | `BaseEnergyFunction` | Energy function to train | *required* |
| `sampler` | `BaseSampler` | MCMC sampler for generating negative samples | *required* |
| `k_steps` | `int` | Number of MCMC steps (k in CD-k) | `10` |
| `persistent` | `bool` | Whether to use persistent Contrastive Divergence | `False` |
| `buffer_size` | `int` | Size of the buffer for PCD | `10000` |
| `init_steps` | `int` | Number of initial MCMC steps to warm up the buffer | `100` |
| `new_sample_ratio` | `float` | Fraction of new random samples to introduce | `0.05` |
| `energy_reg_weight` | `float` | Weight for energy regularization | `0.001` |
| `use_temperature_annealing` | `bool` | Whether to use temperature annealing for the sampler | `False` |
| `min_temp` | `float` | Minimum temperature for annealing | `0.01` |
| `max_temp` | `float` | Maximum temperature for annealing | `2.0` |
| `temp_decay` | `float` | Decay rate for temperature annealing | `0.999` |
| `dtype` | `dtype` | Data type for computations | `float32` |
| `device` | `device` | Device to run computations on | `device('cpu')` |

Basic Usage

# Setup energy function, sampler and CD loss
energy_fn = MLPEnergyFunction(input_dim=2, hidden_dim=64)
sampler = LangevinDynamics(energy_fn, step_size=0.1)
cd_loss = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,
    k_steps=10,
    persistent=False
)

# In training loop
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=0.001)

for batch in dataloader:
    optimizer.zero_grad()
    loss, _ = cd_loss(batch)
    loss.backward()
    optimizer.step()

Persistent vs Standard CD

  • Standard CD (persistent=False) is more stable but can struggle with complex distributions
  • Persistent CD (persistent=True) can explore better but may require careful initialization; a persistent configuration is sketched below
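
For example, a persistent configuration might look like the following (reusing `energy_fn` and `sampler` from the basic usage example above; the keyword arguments are those listed in the parameter table):

```python
# Persistent CD with a replay buffer and optional temperature annealing
pcd_loss = ContrastiveDivergence(
    energy_function=energy_fn,
    sampler=sampler,
    k_steps=10,
    persistent=True,            # keep MCMC chains alive between updates
    buffer_size=10000,          # number of persistent samples to store
    init_steps=100,             # MCMC steps used to warm up the buffer
    new_sample_ratio=0.05,      # fraction of chains restarted from fresh random samples
    use_temperature_annealing=True,
    max_temp=2.0,               # starting temperature
    min_temp=0.01,              # floor reached as training progresses
    temp_decay=0.999,           # per-update multiplicative decay
)
# The training loop is identical to the one shown above.
```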
Source code in torchebm/losses/contrastive_divergence.py
class ContrastiveDivergence(BaseContrastiveDivergence):
    r"""
    Implementation of the standard Contrastive Divergence (CD-k) algorithm.

    Contrastive Divergence approximates the gradient of the log-likelihood by comparing
    the energy between real data samples and samples generated after k steps of MCMC
    initialized from the data samples (or from noise in persistent mode).

    The CD loss is defined as:

    $$\mathcal{L}_{CD} = \mathbb{E}_{p_{data}}[E_\theta(x)] - \mathbb{E}_{p_k}[E_\theta(x')]$$

    where:

    - $E_\theta(x)$ is the energy function with parameters $\theta$
    - $p_{data}$ is the data distribution
    - $p_k$ is the distribution after $k$ steps of MCMC

    !!! note "Algorithm Overview"
        1. For non-persistent CD:
           a. Start MCMC chains from real data samples
           b. Run MCMC for k steps to generate negative samples
           c. Compute gradient comparing real and negative samples

        2. For persistent CD:
           a. Maintain a set of persistent chains between updates
           b. Continue chains from previous state for k steps
           c. Update the persistent state for next iteration

    Args:
        energy_function (BaseEnergyFunction): Energy function to train
        sampler (BaseSampler): MCMC sampler for generating negative samples
        k_steps (int): Number of MCMC steps (k in CD-k)
        persistent (bool): Whether to use persistent Contrastive Divergence
        buffer_size (int, optional): Size of buffer for PCD. Defaults to 10000.
        init_steps (int, optional): Number of initial MCMC steps to warm up buffer. Defaults to 100.
        new_sample_ratio (float, optional): Fraction of new random samples to introduce. Defaults to 0.05.
        energy_reg_weight (float, optional): Weight for energy regularization. Defaults to 0.001.
        use_temperature_annealing (bool, optional): Whether to use temperature annealing for sampler. Defaults to False.
        min_temp (float, optional): Minimum temperature for annealing. Defaults to 0.01.
        max_temp (float, optional): Maximum temperature for annealing. Defaults to 2.0.
        temp_decay (float, optional): Decay rate for temperature annealing. Defaults to 0.999.
        dtype (torch.dtype): Data type for computations
        device (torch.device): Device to run computations on

    !!! example "Basic Usage"
        ```python
        # Setup energy function, sampler and CD loss
        energy_fn = MLPEnergyFunction(input_dim=2, hidden_dim=64)
        sampler = LangevinDynamics(energy_fn, step_size=0.1)
        cd_loss = ContrastiveDivergence(
            energy_function=energy_fn,
            sampler=sampler,
            k_steps=10,
            persistent=False
        )

        # In training loop
        optimizer = torch.optim.Adam(energy_fn.parameters(), lr=0.001)

        for batch in dataloader:
            optimizer.zero_grad()
            loss, _ = cd_loss(batch)
            loss.backward()
            optimizer.step()
        ```

    !!! tip "Persistent vs Standard CD"
        - Standard CD (`persistent=False`) is more stable but can struggle with complex distributions
        - Persistent CD (`persistent=True`) can explore better but may require careful initialization
    """

    def __init__(
        self,
        energy_function,
        sampler,
        k_steps=10,
        persistent=False,
        buffer_size=10000,
        init_steps=100,
        new_sample_ratio=0.05,
        energy_reg_weight=0.001,
        use_temperature_annealing=False,
        min_temp=0.01,
        max_temp=2.0,
        temp_decay=0.999,
        dtype=torch.float32,
        device=torch.device("cpu"),
        *args,
        **kwargs,
    ):
        super().__init__(
            energy_function=energy_function,
            sampler=sampler,
            k_steps=k_steps,
            persistent=persistent,
            buffer_size=buffer_size,
            new_sample_ratio=new_sample_ratio,
            init_steps=init_steps,
            dtype=dtype,
            device=device,
            *args,
            **kwargs,
        )
        # Additional parameters for improved stability
        self.energy_reg_weight = energy_reg_weight
        self.use_temperature_annealing = use_temperature_annealing
        self.min_temp = min_temp
        self.max_temp = max_temp
        self.temp_decay = temp_decay
        self.current_temp = max_temp

        # Register temperature as buffer for persistence
        self.register_buffer(
            "temperature", torch.tensor(max_temp, dtype=self.dtype, device=self.device)
        )

    def forward(
        self, x: torch.Tensor, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the Contrastive Divergence loss and generate negative samples.

        This method implements the CD algorithm by:

        1. Initializing MCMC chains (either from data or persistent state)
        2. Running the sampler for k_steps to generate negative samples
        3. Computing the CD loss using the energy difference

        Args:
            x (torch.Tensor): Batch of real data samples (positive samples)
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - loss: The CD loss value (scalar)
                - pred_samples: Generated negative samples

        !!! note "Shape Information"
            - Input `x`: (batch_size, feature_dimensions)
            - Output loss: scalar
            - Output pred_samples: (batch_size, feature_dimensions)
        """

        batch_size = x.shape[0]
        data_shape = x.shape[1:]

        # Update temperature if annealing is enabled
        if self.use_temperature_annealing and self.training:
            self.current_temp = max(self.min_temp, self.current_temp * self.temp_decay)
            self.temperature[...] = self.current_temp  # Use ellipsis instead of index

            # If sampler has a temperature parameter, update it
            if hasattr(self.sampler, "temperature"):
                self.sampler.temperature = self.current_temp
            elif hasattr(self.sampler, "noise_scale"):
                # For samplers like Langevin, adjust noise scale based on temperature
                original_noise = getattr(self.sampler, "_original_noise_scale", None)
                if original_noise is None:
                    setattr(
                        self.sampler, "_original_noise_scale", self.sampler.noise_scale
                    )
                    original_noise = self.sampler.noise_scale

                self.sampler.noise_scale = original_noise * math.sqrt(self.current_temp)

        # Get starting points for chains (either from buffer or data)
        start_points = self.get_start_points(x)

        # Run MCMC chains to get negative samples
        pred_samples = self.sampler.sample(
            x=start_points,
            n_steps=self.k_steps,
        )

        # Update persistent buffer if using PCD
        if self.persistent:
            with torch.no_grad():
                self.update_buffer(pred_samples.detach())

        # Add energy regularization to kwargs for compute_loss
        kwargs["energy_reg_weight"] = kwargs.get(
            "energy_reg_weight", self.energy_reg_weight
        )

        # Compute contrastive divergence loss
        loss = self.compute_loss(x, pred_samples, *args, **kwargs)

        return loss, pred_samples

    def compute_loss(
        self, x: torch.Tensor, pred_x: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """
        Compute the Contrastive Divergence loss given positive and negative samples.

        The CD loss is defined as the difference between the average energy of positive samples
        (from the data distribution) and the average energy of negative samples (from the model).

        Args:
            x (torch.Tensor): Real data samples (positive samples)
            pred_x (torch.Tensor): Generated negative samples
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            Scalar loss value

        !!! warning "Gradient Direction"
            Note that this implementation returns `E(x) - E(x')`, so during optimization
            we *minimize* this value. This is equivalent to maximizing `E(x') - E(x)`,
            the form in which the objective is sometimes written.
        """
        # Ensure inputs are on the correct device and dtype
        x = x.to(self.device, self.dtype)
        pred_x = pred_x.to(self.device, self.dtype)

        # Compute energy of real and generated samples
        with torch.set_grad_enabled(True):
            # Add small noise to real data for stability (optional)
            if kwargs.get("add_noise_to_real", False):
                noise_scale = kwargs.get("noise_scale", 1e-4)
                x_noisy = x + noise_scale * torch.randn_like(x)
                x_energy = self.energy_function(x_noisy)
            else:
                x_energy = self.energy_function(x)

            pred_x_energy = self.energy_function(pred_x)

        # Compute mean energies with improved numerical stability
        mean_x_energy = torch.mean(x_energy)
        mean_pred_energy = torch.mean(pred_x_energy)

        # Basic contrastive divergence loss: E[data] - E[model]
        loss = mean_x_energy - mean_pred_energy

        # Optional: Regularization to prevent energies from becoming too large
        # This helps with stability especially in the early phases of training
        energy_reg_weight = kwargs.get("energy_reg_weight", 0.001)
        if energy_reg_weight > 0:
            energy_reg = energy_reg_weight * (
                torch.mean(x_energy**2) + torch.mean(pred_x_energy**2)
            )
            loss = loss + energy_reg

        # Prevent extremely large gradients with a safety check
        if torch.isnan(loss) or torch.isinf(loss):
            warnings.warn(
                f"NaN or Inf detected in CD loss. x_energy: {mean_x_energy}, pred_energy: {mean_pred_energy}",
                RuntimeWarning,
            )
            # Return a small positive constant instead of NaN/Inf to prevent training collapse
            return torch.tensor(0.1, device=self.device, dtype=self.dtype)

        return loss

energy_reg_weight instance-attribute

energy_reg_weight = energy_reg_weight

use_temperature_annealing instance-attribute

use_temperature_annealing = use_temperature_annealing

min_temp instance-attribute

min_temp = min_temp

max_temp instance-attribute

max_temp = max_temp

temp_decay instance-attribute

temp_decay = temp_decay

current_temp instance-attribute

current_temp = max_temp

forward

forward(x: Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]

Compute the Contrastive Divergence loss and generate negative samples.

This method implements the CD algorithm by:

  1. Initializing MCMC chains (either from data or persistent state)
  2. Running the sampler for k_steps to generate negative samples
  3. Computing the CD loss using the energy difference

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Tensor` | Batch of real data samples (positive samples) | *required* |
| `*args` | | Additional positional arguments | `()` |
| `**kwargs` | | Additional keyword arguments | `{}` |

Returns:

| Type | Description |
|------|-------------|
| `Tuple[Tensor, Tensor]` | `loss`: the CD loss value (scalar); `pred_samples`: the generated negative samples |

Shape Information

  • Input x: (batch_size, feature_dimensions)
  • Output loss: scalar
  • Output pred_samples: (batch_size, feature_dimensions)
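
For instance, continuing the basic usage example above (a 2-D energy function and a batch of 128 points), the forward pass returns a scalar loss and a batch of negatives with the same shape as the input:

```python
x = torch.randn(128, 2)      # (batch_size, feature_dimensions)
loss, pred_samples = cd_loss(x)

print(loss.shape)            # torch.Size([]) -- a scalar
print(pred_samples.shape)    # torch.Size([128, 2]) -- same shape as the input batch
```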
Source code in torchebm/losses/contrastive_divergence.py
def forward(
    self, x: torch.Tensor, *args, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the Contrastive Divergence loss and generate negative samples.

    This method implements the CD algorithm by:

    1. Initializing MCMC chains (either from data or persistent state)
    2. Running the sampler for k_steps to generate negative samples
    3. Computing the CD loss using the energy difference

    Args:
        x (torch.Tensor): Batch of real data samples (positive samples)
        *args: Additional positional arguments
        **kwargs: Additional keyword arguments

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - loss: The CD loss value (scalar)
            - pred_samples: Generated negative samples

    !!! note "Shape Information"
        - Input `x`: (batch_size, feature_dimensions)
        - Output loss: scalar
        - Output pred_samples: (batch_size, feature_dimensions)
    """

    batch_size = x.shape[0]
    data_shape = x.shape[1:]

    # Update temperature if annealing is enabled
    if self.use_temperature_annealing and self.training:
        self.current_temp = max(self.min_temp, self.current_temp * self.temp_decay)
        self.temperature[...] = self.current_temp  # Use ellipsis instead of index

        # If sampler has a temperature parameter, update it
        if hasattr(self.sampler, "temperature"):
            self.sampler.temperature = self.current_temp
        elif hasattr(self.sampler, "noise_scale"):
            # For samplers like Langevin, adjust noise scale based on temperature
            original_noise = getattr(self.sampler, "_original_noise_scale", None)
            if original_noise is None:
                setattr(
                    self.sampler, "_original_noise_scale", self.sampler.noise_scale
                )
                original_noise = self.sampler.noise_scale

            self.sampler.noise_scale = original_noise * math.sqrt(self.current_temp)

    # Get starting points for chains (either from buffer or data)
    start_points = self.get_start_points(x)

    # Run MCMC chains to get negative samples
    pred_samples = self.sampler.sample(
        x=start_points,
        n_steps=self.k_steps,
    )

    # Update persistent buffer if using PCD
    if self.persistent:
        with torch.no_grad():
            self.update_buffer(pred_samples.detach())

    # Add energy regularization to kwargs for compute_loss
    kwargs["energy_reg_weight"] = kwargs.get(
        "energy_reg_weight", self.energy_reg_weight
    )

    # Compute contrastive divergence loss
    loss = self.compute_loss(x, pred_samples, *args, **kwargs)

    return loss, pred_samples

compute_loss

compute_loss(x: Tensor, pred_x: Tensor, *args, **kwargs) -> torch.Tensor

Compute the Contrastive Divergence loss given positive and negative samples.

The CD loss is defined as the difference between the average energy of positive samples (from the data distribution) and the average energy of negative samples (from the model).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Tensor` | Real data samples (positive samples) | *required* |
| `pred_x` | `Tensor` | Generated negative samples | *required* |
| `*args` | | Additional positional arguments | `()` |
| `**kwargs` | | Additional keyword arguments | `{}` |

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | Scalar loss value |

Gradient Direction

Note that this implementation returns E(x) - E(x'), so during optimization we minimize this value. This is equivalent to maximizing E(x') - E(x), the form in which the objective is sometimes written. A toy numeric illustration is given below.
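
The following toy example (using an ad-hoc quadratic energy rather than one of the library's energy classes) shows the sign convention: minimizing the loss pushes the energy of data down and the energy of negatives up.

```python
import torch

def energy(x):
    # Toy quadratic energy: low near the origin, higher far from it
    return (x ** 2).sum(dim=-1)

x_data = torch.zeros(4, 2)        # "data" sits at the energy minimum
pred_x = 3.0 * torch.ones(4, 2)   # negatives sit in a high-energy region

loss = energy(x_data).mean() - energy(pred_x).mean()
print(loss)                       # tensor(-18.) -- negative when the data already has lower energy
```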

Source code in torchebm/losses/contrastive_divergence.py
def compute_loss(
    self, x: torch.Tensor, pred_x: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
    """
    Compute the Contrastive Divergence loss given positive and negative samples.

    The CD loss is defined as the difference between the average energy of positive samples
    (from the data distribution) and the average energy of negative samples (from the model).

    Args:
        x (torch.Tensor): Real data samples (positive samples)
        pred_x (torch.Tensor): Generated negative samples
        *args: Additional positional arguments
        **kwargs: Additional keyword arguments

    Returns:
        Scalar loss value

    !!! warning "Gradient Direction"
        Note that this implementation returns `E(x) - E(x')`, so during optimization
        we *minimize* this value. This is equivalent to maximizing `E(x') - E(x)`,
        the form in which the objective is sometimes written.
    """
    # Ensure inputs are on the correct device and dtype
    x = x.to(self.device, self.dtype)
    pred_x = pred_x.to(self.device, self.dtype)

    # Compute energy of real and generated samples
    with torch.set_grad_enabled(True):
        # Add small noise to real data for stability (optional)
        if kwargs.get("add_noise_to_real", False):
            noise_scale = kwargs.get("noise_scale", 1e-4)
            x_noisy = x + noise_scale * torch.randn_like(x)
            x_energy = self.energy_function(x_noisy)
        else:
            x_energy = self.energy_function(x)

        pred_x_energy = self.energy_function(pred_x)

    # Compute mean energies with improved numerical stability
    mean_x_energy = torch.mean(x_energy)
    mean_pred_energy = torch.mean(pred_x_energy)

    # Basic contrastive divergence loss: E[data] - E[model]
    loss = mean_x_energy - mean_pred_energy

    # Optional: Regularization to prevent energies from becoming too large
    # This helps with stability especially in the early phases of training
    energy_reg_weight = kwargs.get("energy_reg_weight", 0.001)
    if energy_reg_weight > 0:
        energy_reg = energy_reg_weight * (
            torch.mean(x_energy**2) + torch.mean(pred_x_energy**2)
        )
        loss = loss + energy_reg

    # Prevent extremely large gradients with a safety check
    if torch.isnan(loss) or torch.isinf(loss):
        warnings.warn(
            f"NaN or Inf detected in CD loss. x_energy: {mean_x_energy}, pred_energy: {mean_pred_energy}",
            RuntimeWarning,
        )
        # Return a small positive constant instead of NaN/Inf to prevent training collapse
        return torch.tensor(0.1, device=self.device, dtype=self.dtype)

    return loss