ContrastiveDivergenceTrainer

Methods and Attributes

Bases: BaseTrainer

Specialized trainer for contrastive divergence training of EBMs.
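Contrastive divergence trains the energy function by lowering the energy of data samples and raising the energy of negative samples produced by running the MCMC sampler for `k_steps` steps (starting from the data batch, or from a replay buffer when `persistent=True`). For orientation only, a standard sketch of the k-step CD gradient estimate is shown below; the exact loss used here is implemented by `ContrastiveDivergence` in `torchebm.losses.contrastive_divergence`.

$$
\nabla_\theta \mathcal{L}_{\mathrm{CD}_k} \;\approx\; \mathbb{E}_{x^{+} \sim p_{\mathrm{data}}}\!\left[\nabla_\theta E_\theta(x^{+})\right] \;-\; \mathbb{E}_{x^{-} \sim q_k}\!\left[\nabla_\theta E_\theta(x^{-})\right]
$$

where $E_\theta$ is the energy function being trained and $q_k$ is the distribution of samples after `k_steps` of MCMC.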

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `energy_function` | `BaseEnergyFunction` | Energy function to train | required |
| `sampler` | `BaseSampler` | MCMC sampler for generating negative samples | required |
| `optimizer` | `Optional[Optimizer]` | PyTorch optimizer | `None` |
| `learning_rate` | `float` | Learning rate (if `optimizer` not provided) | `0.01` |
| `k_steps` | `int` | Number of MCMC steps for generating samples | `10` |
| `persistent` | `bool` | Whether to use persistent contrastive divergence (PCD) | `False` |
| `buffer_size` | `int` | Replay buffer size for PCD | `1000` |
| `device` | `Optional[Union[str, device]]` | Device to run training on | `None` |
| `dtype` | `dtype` | Data type for computations | `float32` |
| `use_mixed_precision` | `bool` | Whether to use mixed precision training | `False` |
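A minimal construction sketch. Only the `ContrastiveDivergenceTrainer` arguments in the table above come from the source; the toy energy module, the import paths for `BaseEnergyFunction`, and the sampler placeholder are illustrative assumptions, not part of this page.

```python
import torch
import torch.nn as nn

# Import paths assumed from this page's references; adjust to your install.
from torchebm.core import BaseEnergyFunction
from torchebm.core.base_trainer import ContrastiveDivergenceTrainer


class MLPEnergy(BaseEnergyFunction):
    """Toy energy network: forward(x) returns per-sample scalar energies.

    Assumes BaseEnergyFunction only requires a standard nn.Module forward.
    """

    def __init__(self, dim: int = 2, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).squeeze(-1)


energy_fn = MLPEnergy(dim=2)

# `sampler` must be a BaseSampler that draws negative samples for energy_fn
# (e.g. one of torchebm's MCMC samplers). Its constructor is not documented
# on this page, so it is left as a placeholder here.
sampler = ...

trainer = ContrastiveDivergenceTrainer(
    energy_function=energy_fn,
    sampler=sampler,
    learning_rate=1e-3,   # Adam is created internally when optimizer=None
    k_steps=10,           # MCMC steps per negative-sample chain
    persistent=True,      # persistent CD with a replay buffer
    buffer_size=1000,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
```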
Source code in torchebm/core/base_trainer.py
class ContrastiveDivergenceTrainer(BaseTrainer):
    """
    Specialized trainer for contrastive divergence training of EBMs.

    Args:
        energy_function: Energy function to train
        sampler: MCMC sampler for generating negative samples
        optimizer: PyTorch optimizer
        learning_rate: Learning rate (if optimizer not provided)
        k_steps: Number of MCMC steps for generating samples
        persistent: Whether to use persistent contrastive divergence (PCD)
        buffer_size: Replay buffer size for PCD
        device: Device to run training on
        dtype: Data type for computations
        use_mixed_precision: Whether to use mixed precision training
    """

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        sampler: BaseSampler,
        optimizer: Optional[torch.optim.Optimizer] = None,
        learning_rate: float = 0.01,
        k_steps: int = 10,
        persistent: bool = False,
        buffer_size: int = 1000,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        use_mixed_precision: bool = False,
    ):
        # Create optimizer if not provided
        if optimizer is None:
            optimizer = torch.optim.Adam(energy_function.parameters(), lr=learning_rate)

        # Import here to avoid circular import
        from torchebm.losses.contrastive_divergence import ContrastiveDivergence

        # Create loss function
        loss_fn = ContrastiveDivergence(
            energy_function=energy_function,
            sampler=sampler,
            k_steps=k_steps,
            persistent=persistent,
            buffer_size=buffer_size,
            dtype=dtype,
            device=device,
            use_mixed_precision=use_mixed_precision,
        )

        # Initialize base trainer
        super().__init__(
            energy_function=energy_function,
            optimizer=optimizer,
            loss_fn=loss_fn,
            device=device,
            dtype=dtype,
            use_mixed_precision=use_mixed_precision,
        )

        self.sampler = sampler

    def train_step(self, batch: torch.Tensor) -> Dict[str, Any]:
        """
        Perform a single contrastive divergence training step.

        Args:
            batch: Batch of real data samples

        Returns:
            Dictionary containing metrics from this step
        """
        # Ensure batch is on the correct device and dtype
        batch = batch.to(device=self.device, dtype=self.dtype)

        # Zero gradients
        self.optimizer.zero_grad()

        # Forward pass with mixed precision if enabled
        if self.use_mixed_precision and self.autocast_available:
            from torch.cuda.amp import autocast
            with autocast():
                # ContrastiveDivergence returns (loss, neg_samples)
                loss, neg_samples = self.loss_fn(batch)

            # Backward pass with gradient scaling
            self.grad_scaler.scale(loss).backward()
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            # Standard training step
            loss, neg_samples = self.loss_fn(batch)
            loss.backward()
            self.optimizer.step()

        # Return metrics
        return {
            'loss': loss.item(),
            'pos_energy': self.energy_function(batch).mean().item(),
            'neg_energy': self.energy_function(neg_samples).mean().item(),
        }

sampler instance-attribute

sampler = sampler

train_step

train_step(batch: Tensor) -> Dict[str, Any]

Perform a single contrastive divergence training step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Tensor` | Batch of real data samples | required |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, Any]` | Dictionary containing metrics from this step |
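A short usage sketch of the step loop. The `trainer` object is assumed to be built as in the constructor sketch above, and `loader` is any DataLoader yielding batches of real samples; only `train_step` and its returned keys come from the source below.

```python
# Assumes `trainer` from the sketch above and `loader`, a
# torch.utils.data.DataLoader yielding batches of real samples.
for epoch in range(10):
    for batch in loader:
        metrics = trainer.train_step(batch)
    # Keys are exactly those returned by train_step in the source below.
    print(
        f"epoch {epoch}: loss={metrics['loss']:.4f}  "
        f"pos_energy={metrics['pos_energy']:.4f}  "
        f"neg_energy={metrics['neg_energy']:.4f}"
    )
```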

Source code in torchebm/core/base_trainer.py
def train_step(self, batch: torch.Tensor) -> Dict[str, Any]:
    """
    Perform a single contrastive divergence training step.

    Args:
        batch: Batch of real data samples

    Returns:
        Dictionary containing metrics from this step
    """
    # Ensure batch is on the correct device and dtype
    batch = batch.to(device=self.device, dtype=self.dtype)

    # Zero gradients
    self.optimizer.zero_grad()

    # Forward pass with mixed precision if enabled
    if self.use_mixed_precision and self.autocast_available:
        from torch.cuda.amp import autocast
        with autocast():
            # ContrastiveDivergence returns (loss, neg_samples)
            loss, neg_samples = self.loss_fn(batch)

        # Backward pass with gradient scaling
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
    else:
        # Standard training step
        loss, neg_samples = self.loss_fn(batch)
        loss.backward()
        self.optimizer.step()

    # Return metrics
    return {
        'loss': loss.item(),
        'pos_energy': self.energy_function(batch).mean().item(),
        'neg_energy': self.energy_function(neg_samples).mean().item(),
    }