Skip to content

ContrastiveDivergenceBase

Methods and Attributes

Bases: Loss

Source code in torchebm/losses/contrastive_divergence.py
class ContrastiveDivergenceBase(Loss):
    """Base class for contrastive-divergence (CD) style losses.

    Subclasses implement :meth:`sample` to draw negative samples from the
    energy model. :meth:`forward` then contrasts the mean energy of the data
    (positive samples) with the mean energy of those negative samples.

    Args:
        k: Number of sampling steps, stored as ``self.k`` for use by
            :meth:`sample` implementations (default ``1``).
    """

    def __init__(self, k=1):
        super().__init__()
        self.k = k  # Number of sampling steps

    @abstractmethod
    def sample(self, energy_model, x_pos):
        """Generate negative samples from the energy model.

        Must be implemented by subclasses.

        Args:
            energy_model: Energy-based model (e.g., RBM).
            x_pos: Positive samples (data).

        Returns:
            x_neg: Negative samples (model samples).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # NOTE(review): @abstractmethod only blocks instantiation if Loss uses
        # ABCMeta (not visible here); the explicit raise covers the case where
        # it does not.
        raise NotImplementedError

    def forward(self, energy_model, x_pos):
        """Compute the CD loss: mean E(x_pos) - mean E(x_neg).

        Negative samples come from :meth:`sample` and are detached before
        being scored. Gradients therefore reach the energy model's parameters
        through both terms but are not propagated back through the sampling
        procedure; the CD gradient treats negative samples as constants.

        Args:
            energy_model: Callable energy model; ``energy_model(x)`` is
                presumably a tensor of per-sample energies (it must support
                ``.mean()``).
            x_pos: Positive samples (data).

        Returns:
            Scalar loss: mean energy of ``x_pos`` minus mean energy of the
            negative samples.
        """
        # Detach: the sampler may have built an autograd graph (e.g. iterative
        # gradient-based sampling). Backpropagating through it would add terms
        # that are not part of the CD update and keep the chain in memory.
        x_neg = self.sample(energy_model, x_pos).detach()
        pos_energy = energy_model(x_pos).mean()
        neg_energy = energy_model(x_neg).mean()
        return pos_energy - neg_energy

k (instance attribute)

`k = k`: the number of sampling steps, set from the constructor argument `k` (default `1`).

sample (abstract method)

sample(energy_model, x_pos)

Abstract method: generate negative samples from the energy model.

Args:
- `energy_model`: Energy-based model (e.g., RBM).
- `x_pos`: Positive samples (data).

Returns:
- `x_neg`: Negative samples (model samples).

Source code in torchebm/losses/contrastive_divergence.py
@abstractmethod
def sample(self, energy_model, x_pos):
    """Draw negative (model) samples for a contrastive-divergence loss.

    This is an abstract hook: concrete losses must override it.

    Args:
        energy_model: Energy-based model (e.g., RBM).
        x_pos: Positive samples (data).

    Returns:
        Negative samples (model samples), conventionally named ``x_neg``.

    Raises:
        NotImplementedError: Always; the base version has no sampler.
    """
    raise NotImplementedError()

forward

forward(energy_model, x_pos)

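Compute the CD loss $\overline{E}(x_{\text{pos}}) - \overline{E}(x_{\text{neg}})$. This is the mean energy of the positive samples minus the mean energy of the negative samples returned by `sample`.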

Source code in torchebm/losses/contrastive_divergence.py
def forward(self, energy_model, x_pos):
    """Compute the CD loss: mean E(x_pos) - mean E(x_neg).

    Negative samples come from ``self.sample`` and are detached before being
    scored. Gradients therefore reach the energy model's parameters through
    both terms but are not propagated back through the sampling procedure;
    the CD gradient treats negative samples as constants.

    Args:
        energy_model: Callable energy model; ``energy_model(x)`` is presumably
            a tensor of per-sample energies (it must support ``.mean()``).
        x_pos: Positive samples (data).

    Returns:
        Scalar loss: mean energy of ``x_pos`` minus mean energy of the
        negative samples.
    """
    # Detach: the sampler may have built an autograd graph (e.g. iterative
    # gradient-based sampling). Backpropagating through it would add terms
    # that are not part of the CD update and keep the chain in memory.
    x_neg = self.sample(energy_model, x_pos).detach()
    pos_energy = energy_model(x_pos).mean()
    neg_energy = energy_model(x_neg).mean()
    return pos_energy - neg_energy