BaseContrastiveDivergence

Methods and Attributes

Bases: BaseLoss

Abstract base class for Contrastive Divergence (CD) based loss functions.

Parameters:

    model (BaseModel): The energy-based model to be trained. Required.
    sampler (BaseSampler): The MCMC sampler for generating negative samples. Required.
    k_steps (int): The number of MCMC steps to perform for each update. Default: 1.
    persistent (bool): If True, uses a replay buffer for Persistent CD (PCD). Default: False.
    buffer_size (int): The size of the replay buffer for PCD. Default: 100.
    new_sample_ratio (float): The ratio of new random samples to introduce into the MCMC chain. Default: 0.0.
    init_steps (int): The number of MCMC steps to run when initializing new chain elements. Default: 0.
    dtype (dtype): Data type for computations. Default: float32.
    device (Optional[Union[str, device]]): Device for computations. Default: None.
    use_mixed_precision (bool): Whether to use mixed precision training. Default: False.
    clip_value (Optional[float]): Optional value to clamp the loss. Default: None.
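
Example: a framework-free sketch of the CD-k objective this class abstracts, in plain PyTorch. Negative samples come from a short Langevin chain started at the data, and the loss is the mean energy gap between data and negatives. The toy network, step size, and helper names below are illustrative only and are not part of the torchebm API.

import torch
import torch.nn as nn

# Toy energy network: maps a batch of 2-D points to scalar energies.
energy_net = nn.Sequential(nn.Linear(2, 64), nn.SiLU(), nn.Linear(64, 1))

def langevin_negatives(x0: torch.Tensor, k_steps: int = 10, step_size: float = 0.1) -> torch.Tensor:
    """Draw negative samples with k steps of unadjusted Langevin dynamics."""
    x = x0.clone().requires_grad_(True)
    for _ in range(k_steps):
        grad = torch.autograd.grad(energy_net(x).sum(), x)[0]
        noise = (2 * step_size) ** 0.5 * torch.randn_like(x)
        x = (x - step_size * grad + noise).detach().requires_grad_(True)
    return x.detach()

optimizer = torch.optim.Adam(energy_net.parameters(), lr=1e-3)
data = torch.randn(128, 2)                          # stand-in for a real batch

negatives = langevin_negatives(data)                # CD-k: chains start at the data
loss = energy_net(data).mean() - energy_net(negatives).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
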
Source code in torchebm/core/base_loss.py
class BaseContrastiveDivergence(BaseLoss):
    """
    Abstract base class for Contrastive Divergence (CD) based loss functions.

    Args:
        model (BaseModel): The energy-based model to be trained.
        sampler (BaseSampler): The MCMC sampler for generating negative samples.
        k_steps (int): The number of MCMC steps to perform for each update.
        persistent (bool): If `True`, uses a replay buffer for Persistent CD (PCD).
        buffer_size (int): The size of the replay buffer for PCD.
        new_sample_ratio (float): The ratio of new random samples to introduce into the MCMC chain.
        init_steps (int): The number of MCMC steps to run when initializing new chain elements.
        dtype (torch.dtype): Data type for computations.
        device (Optional[Union[str, torch.device]]): Device for computations.
        use_mixed_precision (bool): Whether to use mixed precision training.
        clip_value (Optional[float]): Optional value to clamp the loss.
    """

    def __init__(
        self,
        model: BaseModel,
        sampler: BaseSampler,
        k_steps: int = 1,
        persistent: bool = False,
        buffer_size: int = 100,
        new_sample_ratio: float = 0.0,
        init_steps: int = 0,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[str, torch.device]] = None,
        use_mixed_precision: bool = False,
        clip_value: Optional[float] = None,
        *args,
        **kwargs,
    ):
        super().__init__(
            dtype=dtype,
            device=device,
            use_mixed_precision=use_mixed_precision,
            clip_value=clip_value,
            *args,
            **kwargs,
        )
        self.model = model
        self.sampler = sampler
        self.k_steps = k_steps
        self.persistent = persistent
        self.buffer_size = buffer_size
        self.new_sample_ratio = new_sample_ratio
        self.init_steps = init_steps

        self.model = self.model.to(device=self.device)
        if hasattr(self.sampler, "to") and callable(getattr(self.sampler, "to")):
            self.sampler = self.sampler.to(device=self.device)

        self.register_buffer("replay_buffer", None)
        self.register_buffer(
            "buffer_ptr", torch.tensor(0, dtype=torch.long, device=self.device)
        )
        self.buffer_initialized = False

    def initialize_buffer(
        self,
        data_shape_no_batch: Tuple[int, ...],
        buffer_chunk_size: int = 1024,
        init_noise_scale: float = 0.01,
    ) -> torch.Tensor:
        """
        Initializes the replay buffer with random noise for PCD.

        Args:
            data_shape_no_batch (Tuple[int, ...]): The shape of the data excluding the batch dimension.
            buffer_chunk_size (int): The size of chunks to process during initialization.
            init_noise_scale (float): The scale of the initial noise.

        Returns:
            torch.Tensor: The initialized replay buffer.
        """
        if not self.persistent or self.buffer_initialized:
            return self.replay_buffer  # nothing to (re)initialize; None unless already set up

        if self.buffer_size <= 0:
            raise ValueError(
                f"Replay buffer size must be positive, got {self.buffer_size}"
            )

        buffer_shape = (
            self.buffer_size,
        ) + data_shape_no_batch  # shape: [buffer_size, *data_shape]
        print(f"Initializing replay buffer with shape {buffer_shape}...")

        self.replay_buffer = (
            torch.randn(buffer_shape, dtype=self.dtype, device=self.device)
            * init_noise_scale
        )

        if self.init_steps > 0:
            print(f"Running {self.init_steps} MCMC steps to populate buffer...")
            with torch.no_grad():
                chunk_size = min(self.buffer_size, buffer_chunk_size)
                for i in range(0, self.buffer_size, chunk_size):
                    end = min(i + chunk_size, self.buffer_size)
                    current_chunk = self.replay_buffer[i:end].clone()
                    try:
                        with self.autocast_context():
                            updated_chunk = self.sampler.sample(
                                x=current_chunk, n_steps=self.init_steps
                            ).detach()

                        if updated_chunk.shape == current_chunk.shape:
                            self.replay_buffer[i:end] = updated_chunk
                        else:
                            warnings.warn(
                                f"Sampler output shape mismatch during buffer init. Expected {current_chunk.shape}, got {updated_chunk.shape}. Skipping update for chunk {i}-{end}."
                            )
                    except Exception as e:
                        warnings.warn(
                            f"Error during buffer initialization sampling for chunk {i}-{end}: {e}. Keeping noise for this chunk."
                        )

        self.buffer_ptr.zero_()
        self.buffer_initialized = True
        print(f"Replay buffer initialized.")

        return self.replay_buffer

    def get_start_points(self, x: torch.Tensor) -> torch.Tensor:
        """
        Gets the starting points for the MCMC sampler.

        For standard CD, this is the input data. For PCD, it's samples from the replay buffer.

        Args:
            x (torch.Tensor): The input data batch.

        Returns:
            torch.Tensor: The tensor of starting points for the sampler.
        """
        x = x.to(device=self.device, dtype=self.dtype)

        batch_size = x.shape[0]
        data_shape_no_batch = x.shape[1:]

        if self.persistent:
            if not self.buffer_initialized:
                self.initialize_buffer(data_shape_no_batch)
                if not self.buffer_initialized:
                    raise RuntimeError("Buffer initialization failed.")

            if self.buffer_size < batch_size:
                warnings.warn(
                    f"Buffer size ({self.buffer_size}) is smaller than batch size ({batch_size}). Sampling with replacement.",
                    UserWarning,
                )
                indices = torch.randint(
                    0, self.buffer_size, (batch_size,), device=self.device
                )
            else:
                # stratified sampling for better buffer coverage
                stride = self.buffer_size // batch_size
                base_indices = torch.arange(0, batch_size, device=self.device) * stride
                offset = torch.randint(0, stride, (batch_size,), device=self.device)
                indices = (base_indices + offset) % self.buffer_size

            start_points = self.replay_buffer[indices].detach().clone()

            # add some noise for exploration
            if self.new_sample_ratio > 0.0:
                n_new = max(1, int(batch_size * self.new_sample_ratio))
                noise_indices = torch.randperm(batch_size, device=self.device)[:n_new]
                noise_scale = 0.01
                start_points[noise_indices] = (
                    start_points[noise_indices]
                    + torch.randn_like(
                        start_points[noise_indices],
                        device=self.device,
                        dtype=self.dtype,
                    )
                    * noise_scale
                )
        else:
            # standard CD-k uses data as starting points
            start_points = x.detach().clone()

        return start_points

    def get_negative_samples(
        self, x: Optional[torch.Tensor], batch_size: int, data_shape: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Gets negative samples using the replay buffer strategy.

        Args:
            x: (Unused) The input data tensor.
            batch_size (int): The number of samples to generate.
            data_shape (Tuple[int, ...]): The shape of the data samples (excluding batch size).

        Returns:
            torch.Tensor: Negative samples.
        """
        if not self.persistent or not self.buffer_initialized:
            # For non-persistent CD, just return random noise
            return torch.randn(
                (batch_size,) + data_shape, dtype=self.dtype, device=self.device
            )

        n_new = max(1, int(batch_size * self.new_sample_ratio))
        n_old = batch_size - n_new

        all_samples = torch.empty(
            (batch_size,) + data_shape, dtype=self.dtype, device=self.device
        )

        # new random samples
        if n_new > 0:
            all_samples[:n_new] = torch.randn(
                (n_new,) + data_shape, dtype=self.dtype, device=self.device
            )

        # samples from buffer
        if n_old > 0:
            indices = torch.randint(0, self.buffer_size, (n_old,), device=self.device)
            all_samples[n_new:] = self.replay_buffer[indices]

        return all_samples

    def update_buffer(self, samples: torch.Tensor) -> None:
        """
        Updates the replay buffer with new samples using a FIFO strategy.

        Args:
            samples (torch.Tensor): New samples to add to the buffer.
        """
        if not self.persistent or not self.buffer_initialized:
            return

        # Ensure samples are on the correct device and dtype
        samples = samples.to(device=self.device, dtype=self.dtype).detach()

        batch_size = samples.shape[0]

        # FIFO strategy
        ptr = int(self.buffer_ptr.item())

        if batch_size >= self.buffer_size:
            # batch larger than buffer, use latest samples
            self.replay_buffer[:] = samples[-self.buffer_size :].detach()
            self.buffer_ptr[...] = 0
        else:
            # handle buffer wraparound
            end_ptr = (ptr + batch_size) % self.buffer_size

            if end_ptr > ptr:
                self.replay_buffer[ptr:end_ptr] = samples.detach()
            else:
                # wraparound case - split update
                first_part = self.buffer_size - ptr
                self.replay_buffer[ptr:] = samples[:first_part].detach()
                self.replay_buffer[:end_ptr] = samples[first_part:].detach()

            self.buffer_ptr[...] = end_ptr

    @abstractmethod
    def forward(
        self, x: torch.Tensor, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the CD loss given real data samples.

        Args:
            x (torch.Tensor): Real data samples (positive samples).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                - The contrastive divergence loss.
                - The generated negative samples.
        """
        pass

    @abstractmethod
    def compute_loss(
        self, x: torch.Tensor, pred_x: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        """
        Computes the contrastive divergence loss from positive and negative samples.

        Args:
            x (torch.Tensor): Real data samples (positive samples).
            pred_x (torch.Tensor): Generated negative samples.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The contrastive divergence loss.
        """
        pass

    def __repr__(self):
        """Return a string representation of the loss function."""
        return f"{self.__class__.__name__}(model={self.model}, sampler={self.sampler})"

    def __str__(self):
        """Return a string representation of the loss function."""
        return self.__repr__()

Instance attributes:

    model: The energy-based model, moved to the configured device.
    sampler: The MCMC sampler, moved to the configured device when it supports .to().
    k_steps, persistent, buffer_size, new_sample_ratio, init_steps: Stored as passed to __init__.
    replay_buffer: Registered buffer holding the PCD chain states (None until initialized).
    buffer_ptr: Registered buffer holding the FIFO write pointer.
    buffer_initialized: False until initialize_buffer() completes successfully.

initialize_buffer

initialize_buffer(data_shape_no_batch: Tuple[int, ...], buffer_chunk_size: int = 1024, init_noise_scale: float = 0.01) -> torch.Tensor

Initializes the replay buffer with random noise for PCD.

Parameters:

    data_shape_no_batch (Tuple[int, ...]): The shape of the data excluding the batch dimension. Required.
    buffer_chunk_size (int): The size of chunks to process during initialization. Default: 1024.
    init_noise_scale (float): The scale of the initial noise. Default: 0.01.

Returns:

    torch.Tensor: The initialized replay buffer.

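Example: the buffer layout and chunked warm-up reduce to a few lines. This standalone sketch mirrors the loop in the source above; `mcmc` is a stand-in for sampler.sample(x=..., n_steps=init_steps), and all names are illustrative.

import torch

buffer_size, chunk_size = 100, 32
data_shape_no_batch = (3, 32, 32)                   # e.g. image-shaped samples
buffer = 0.01 * torch.randn((buffer_size,) + data_shape_no_batch)  # (100, 3, 32, 32)

def mcmc(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the real sampler; returns a slightly perturbed chunk.
    return x + 0.01 * torch.randn_like(x)

with torch.no_grad():
    for i in range(0, buffer_size, chunk_size):     # chunks start at 0, 32, 64, 96
        end = min(i + chunk_size, buffer_size)
        buffer[i:end] = mcmc(buffer[i:end].clone())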

get_start_points

get_start_points(x: Tensor) -> torch.Tensor

Gets the starting points for the MCMC sampler.

For standard CD, this is the input data. For PCD, it's samples from the replay buffer.

Parameters:

    x (Tensor): The input data batch. Required.

Returns:

    torch.Tensor: The tensor of starting points for the sampler.

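Example: when the buffer is at least as large as the batch, the method partitions the buffer into batch_size strata of width `stride` and draws one random element from each, so a single batch touches the whole buffer instead of clustering. A standalone sketch of that index computation:

import torch

buffer_size, batch_size = 100, 16
stride = buffer_size // batch_size                  # 6
base_indices = torch.arange(batch_size) * stride    # 0, 6, 12, ..., 90
offset = torch.randint(0, stride, (batch_size,))    # random jitter within each stratum
indices = (base_indices + offset) % buffer_size     # one draw per stratum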

get_negative_samples

get_negative_samples(x: Optional[Tensor], batch_size: int, data_shape: Tuple[int, ...]) -> torch.Tensor

Gets negative samples using the replay buffer strategy.

Parameters:

    x: (Unused) The input data tensor. Required.
    batch_size (int): The number of samples to generate. Required.
    data_shape (Tuple[int, ...]): The shape of the data samples (excluding batch size). Required.

Returns:

    torch.Tensor: Negative samples.

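Example: a worked instance of the new/old split. Note that the max(1, ...) floor means at least one fresh noise sample is injected per batch even when new_sample_ratio is 0.0.

batch_size, new_sample_ratio = 64, 0.05
n_new = max(1, int(batch_size * new_sample_ratio))  # max(1, 3) -> 3 fresh noise samples
n_old = batch_size - n_new                          # 61 samples drawn from the buffer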

update_buffer

update_buffer(samples: Tensor) -> None

Updates the replay buffer with new samples using a FIFO strategy.

Parameters:

    samples (Tensor): New samples to add to the buffer. Required.
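
Example: the wraparound arithmetic of the FIFO update, worked through with a 10-slot buffer and the write pointer at slot 8.

import torch

buffer_size = 10
buffer = torch.zeros(buffer_size, 2)
samples = torch.ones(5, 2)
ptr = 8                                             # write pointer near the end

end_ptr = (ptr + samples.shape[0]) % buffer_size    # (8 + 5) % 10 = 3 -> wraps around
first_part = buffer_size - ptr                      # 2 samples fit at the tail
buffer[ptr:] = samples[:first_part]                 # fill slots 8-9
buffer[:end_ptr] = samples[first_part:]             # wrap remaining 3 into slots 0-2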

forward abstractmethod

forward(x: Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]

Computes the CD loss given real data samples.

Parameters:

    x (Tensor): Real data samples (positive samples). Required.

Returns:

    Tuple[torch.Tensor, torch.Tensor]: The contrastive divergence loss and the generated negative samples.


compute_loss abstractmethod

compute_loss(x: Tensor, pred_x: Tensor, *args, **kwargs) -> torch.Tensor

Computes the contrastive divergence loss from positive and negative samples.

Parameters:

    x (Tensor): Real data samples (positive samples). Required.
    pred_x (Tensor): Generated negative samples. Required.
    *args: Additional positional arguments.
    **kwargs: Additional keyword arguments.

Returns:

    torch.Tensor: The contrastive divergence loss.

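
Example: putting the two abstract methods together, a minimal subclass might look like the sketch below. It uses only interfaces visible on this page (get_start_points, sampler.sample, update_buffer, self.model) and assumes the model returns per-sample energies; the class name and exact loss form are illustrative, not part of torchebm.

import torch
from typing import Tuple

from torchebm.core.base_loss import BaseContrastiveDivergence

class VanillaCD(BaseContrastiveDivergence):
    """Hypothetical CD loss: mean energy gap between data and sampled negatives."""

    def forward(self, x: torch.Tensor, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        start_points = self.get_start_points(x)      # data for CD, buffer draws for PCD
        with torch.no_grad():
            pred_x = self.sampler.sample(x=start_points, n_steps=self.k_steps).detach()
        if self.persistent:
            self.update_buffer(pred_x)               # keep the chains alive between batches
        return self.compute_loss(x, pred_x), pred_x

    def compute_loss(self, x: torch.Tensor, pred_x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Classic CD objective: lower the energy of data, raise it on model samples.
        return self.model(x).mean() - self.model(pred_x).mean()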