BaseContrastiveDivergence

Methods and Attributes

Bases: BaseLoss

Abstract base class for Contrastive Divergence (CD) based loss functions.

Contrastive Divergence is a family of methods for training energy-based models that approximates the gradient of the log-likelihood by comparing the energies of real data samples (the positive phase) with those of model samples (the negative phase) generated through MCMC sampling.
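
Concretely, CD approximates the otherwise intractable maximum-likelihood gradient. For a Gibbs distribution p(x) ∝ exp(-E(x)), the exact negative log-likelihood gradient is (notation ours, not taken from the library):

$$
\nabla_\theta \, \mathbb{E}_{x \sim p_{\text{data}}}\big[-\log p_\theta(x)\big]
= \mathbb{E}_{x \sim p_{\text{data}}}\big[\nabla_\theta E_\theta(x)\big]
- \mathbb{E}_{x^- \sim p_\theta}\big[\nabla_\theta E_\theta(x^-)\big],
$$

and CD replaces the second (model) expectation with samples obtained from a few MCMC steps, started either at the data (standard CD) or at a persistent chain (PCD).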

This class provides the common structure for CD variants, including standard CD, Persistent CD (PCD), and others.

Source code in torchebm/core/base_loss.py
class BaseContrastiveDivergence(BaseLoss):
    """
    Abstract base class for Contrastive Divergence (CD) based loss functions.

    Contrastive Divergence is a family of methods for training energy-based models that
    approximate the gradient of the log-likelihood by comparing the energy between real
    data samples (positive phase) and model samples (negative phase) generated through
    MCMC sampling.

    This class provides the common structure for CD variants, including standard CD,
    Persistent CD (PCD), and others.

    Attributes:
        energy_function: The energy function being trained
        sampler: MCMC sampler for generating negative samples
        n_steps: Number of MCMC steps to perform for each update
        persistent: Whether to use persistent chains (PCD)
        dtype: Data type for computations
        device: Device for computations
        chain: Buffer for persistent chains (when using PCD)
    """

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        sampler: BaseSampler,
        n_steps: int = 1,
        persistent: bool = False,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the ContrastiveDivergence loss.

        Args:
            energy_function: Energy function to train
            sampler: MCMC sampler for generating negative samples
            n_steps: Number of MCMC steps for generating negative samples
            persistent: Whether to use persistent CD (maintain chains between updates)
            dtype: Data type for computations
            device: Device for computations
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments
        """
        super().__init__()
        self.energy_function = energy_function
        self.sampler = sampler
        self.n_steps = n_steps
        self.persistent = persistent
        self.dtype = dtype
        self.device = device or (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        # Register buffer for persistent chains (if using PCD)
        self.register_buffer("chain", None)  # For persistent CD

    def __call__(self, x, *args, **kwargs):
        """
        Call the forward method of the loss function.

        Args:
            x: Real data samples (positive samples).
            *args: Positional arguments.
            **kwargs: Keyword arguments.

        Returns:
            torch.Tensor: The computed loss.
        """
        return self.forward(x, *args, **kwargs)

    @abstractmethod
    def forward(
        self, x: torch.Tensor, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute CD loss given real data samples.

        This method should implement the specifics of the contrastive divergence
        variant, typically:
        1. Generate negative samples using the MCMC sampler
        2. Compute energies for real and negative samples
        3. Calculate the contrastive loss

        Args:
            x: Real data samples (positive samples).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - loss: The contrastive divergence loss
                - pred_x: Generated negative samples
        """
        pass

    def initialize_persistent_chain(self, shape: Tuple[int, ...]) -> torch.Tensor:
        """
        Initialize the persistent chain with random noise.

        For persistent CD variants, this method initializes the persistent chain
        buffer with random noise. This is typically called the first time the loss
        is computed or when the batch size changes.

        Args:
            shape: Shape of the initial chain state.

        Returns:
            The initialized chain.
        """

        if self.chain is None or self.chain.shape != shape:
            self.chain = torch.randn(*shape, dtype=self.dtype, device=self.device)

        return self.chain

    @abstractmethod
    def compute_loss(
        self, x: torch.Tensor, pred_x: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """
        Compute the contrastive divergence loss from positive and negative samples.

        This method defines how the loss is calculated given real samples (positive phase)
        and samples from the model (negative phase). Typical implementations compute
        the difference between mean energies of positive and negative samples.

        Args:
            x: Real data samples (positive samples).
            pred_x: Generated negative samples.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The contrastive divergence loss
        """
        pass

    def to(self, device: Union[str, torch.device]) -> "BaseContrastiveDivergence":
        """
        Move loss to specified device.

        Args:
            device: Target device for computations.

        Returns:
            The loss function instance moved to the specified device.
        """
        self.device = device
        return self
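
Since forward and compute_loss are abstract, the class is used by subclassing. The following is a minimal, hedged sketch of a standard CD-k subclass; the sampler.sample(x=..., n_steps=...) call and the convention that energy_function(x) returns per-sample energies are assumptions about the companion APIs, not guarantees from this file:

import torch
from torchebm.core.base_loss import BaseContrastiveDivergence

class SimpleCD(BaseContrastiveDivergence):
    """Illustrative CD-k: negative samples start from the data batch."""

    def forward(self, x, *args, **kwargs):
        # Negative phase: run the MCMC sampler for n_steps starting from x
        # (the standard CD-k initialization), detached from the autograd graph.
        pred_x = self.sampler.sample(x=x, n_steps=self.n_steps).detach()
        loss = self.compute_loss(x, pred_x)
        return loss, pred_x

    def compute_loss(self, x, pred_x, *args, **kwargs):
        # Standard CD objective: mean energy of positives minus negatives.
        return self.energy_function(x).mean() - self.energy_function(pred_x).mean()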

energy_function instance-attribute

energy_function = energy_function

sampler instance-attribute

sampler = sampler

n_steps instance-attribute

n_steps = n_steps

persistent instance-attribute

persistent = persistent

dtype instance-attribute

dtype = dtype

device instance-attribute

device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))

forward abstractmethod

forward(x: Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]

Compute CD loss given real data samples.

This method should implement the specifics of the contrastive divergence variant, typically:

1. Generate negative samples using the MCMC sampler.
2. Compute energies for real and negative samples.
3. Calculate the contrastive loss.

Parameters:

    x (Tensor): Real data samples (positive samples). Required.

Returns:

    Tuple[Tensor, Tensor]: (loss, pred_x), where loss is the contrastive divergence loss and pred_x holds the generated negative samples.

Source code in torchebm/core/base_loss.py
@abstractmethod
def forward(
    self, x: torch.Tensor, *args, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute CD loss given real data samples.

    This method should implement the specifics of the contrastive divergence
    variant, typically:
    1. Generate negative samples using the MCMC sampler
    2. Compute energies for real and negative samples
    3. Calculate the contrastive loss

    Args:
        x: Real data samples (positive samples).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - loss: The contrastive divergence loss
            - pred_x: Generated negative samples
    """
    pass
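
In training code, the loss object is simply called on a batch; __call__ dispatches to forward. A hedged loop sketch, where energy_function, sampler, and data_loader are placeholders and SimpleCD is the illustrative subclass above (assuming the energy function is a standard nn.Module whose parameters are being optimized):

cd_loss = SimpleCD(energy_function, sampler, n_steps=10)
optimizer = torch.optim.Adam(energy_function.parameters(), lr=1e-3)

for batch in data_loader:
    optimizer.zero_grad()
    loss, pred_x = cd_loss(batch)  # __call__ dispatches to forward()
    loss.backward()
    optimizer.step()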

initialize_persistent_chain

initialize_persistent_chain(shape: Tuple[int, ...]) -> torch.Tensor

Initialize the persistent chain with random noise.

For persistent CD variants, this method initializes the persistent chain buffer with random noise. This is typically called the first time the loss is computed or when the batch size changes.

Parameters:

    shape (Tuple[int, ...]): Shape of the initial chain state. Required.

Returns:

    Tensor: The initialized chain.

Source code in torchebm/core/base_loss.py
def initialize_persistent_chain(self, shape: Tuple[int, ...]) -> torch.Tensor:
    """
    Initialize the persistent chain with random noise.

    For persistent CD variants, this method initializes the persistent chain
    buffer with random noise. This is typically called the first time the loss
    is computed or when the batch size changes.

    Args:
        shape: Shape of the initial chain state.

    Returns:
        The initialized chain.
    """

    if self.chain is None or self.chain.shape != shape:
        self.chain = torch.randn(*shape, dtype=self.dtype, device=self.device)

    return self.chain
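
A persistent variant would typically seed the negative phase from this buffer instead of from the data, then store the final sampler state for the next update. A hedged sketch of such a forward, with the same assumed sampler.sample(x=..., n_steps=...) interface as above:

def forward(self, x, *args, **kwargs):
    # PCD: negatives continue from the last chain state, not from x.
    chain = self.initialize_persistent_chain(x.shape)
    pred_x = self.sampler.sample(x=chain, n_steps=self.n_steps).detach()
    self.chain = pred_x  # carry the chain over to the next update
    loss = self.compute_loss(x, pred_x)
    return loss, pred_x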

compute_loss abstractmethod

compute_loss(x: Tensor, pred_x: Tensor, *args, **kwargs) -> torch.Tensor

Compute the contrastive divergence loss from positive and negative samples.

This method defines how the loss is calculated given real samples (positive phase) and samples from the model (negative phase). Typical implementations compute the difference between mean energies of positive and negative samples.

Parameters:

    x (Tensor): Real data samples (positive samples). Required.
    pred_x (Tensor): Generated negative samples. Required.
    *args: Additional positional arguments. Default: ().
    **kwargs: Additional keyword arguments. Default: {}.

Returns:

    Tensor: The contrastive divergence loss.

Source code in torchebm/core/base_loss.py
@abstractmethod
def compute_loss(
    self, x: torch.Tensor, pred_x: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
    """
    Compute the contrastive divergence loss from positive and negative samples.

    This method defines how the loss is calculated given real samples (positive phase)
    and samples from the model (negative phase). Typical implementations compute
    the difference between mean energies of positive and negative samples.

    Args:
        x: Real data samples (positive samples).
        pred_x: Generated negative samples.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        torch.Tensor: The contrastive divergence loss
    """
    pass
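
In symbols, the mean-energy difference described above is (notation ours): for a batch of positives $x_i$ and negatives $\tilde{x}_i$,

$$
\mathcal{L}_{\text{CD}}(\theta)
= \frac{1}{N} \sum_{i=1}^{N} E_\theta(x_i)
- \frac{1}{N} \sum_{i=1}^{N} E_\theta(\tilde{x}_i),
$$

so minimizing the loss lowers the energy of real data while raising the energy of model samples.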

to

to(device: Union[str, device]) -> BaseContrastiveDivergence

Move loss to specified device.

Parameters:

    device (Union[str, torch.device]): Target device for computations. Required.

Returns:

    BaseContrastiveDivergence: The loss function instance moved to the specified device.

Source code in torchebm/core/base_loss.py
def to(self, device: Union[str, torch.device]) -> "BaseContrastiveDivergence":
    """
    Move loss to specified device.

    Args:
        device: Target device for computations.

    Returns:
        The loss function instance moved to the specified device.
    """
    self.device = device
    return self
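
Note that to records the target device on the loss object; as initialize_persistent_chain shows, chain buffers created after the move are allocated on that device. A short hedged usage sketch, reusing the illustrative SimpleCD from above:

cd_loss = SimpleCD(energy_function, sampler, n_steps=10)
cd_loss = cd_loss.to("cuda")  # chain buffers created after this land on CUDA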