BaseTrainer

Methods and Attributes

Base class for training energy-based models.

This class provides a generic interface for training EBMs, supporting various training methods and mixed precision training.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| energy_function | BaseEnergyFunction | Energy function to train | required |
| optimizer | Optimizer | PyTorch optimizer to use | required |
| loss_fn | BaseLoss | Loss function for training | required |
| device | Optional[Union[str, device]] | Device to run training on | None |
| dtype | dtype | Data type for computations | float32 |
| use_mixed_precision | bool | Whether to use mixed precision training | False |
| callbacks | Optional[List[Callable]] | List of callback functions for training events | None |

Methods:

| Name | Description |
|------|-------------|
| train_step | Perform a single training step |
| train_epoch | Train for a full epoch |
| train | Train for multiple epochs |
| validate | Validate the model |
| save_checkpoint | Save model checkpoint |
| load_checkpoint | Load model from checkpoint |

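The snippet below is a minimal usage sketch. `MyEnergyFunction` and `MyContrastiveLoss` are hypothetical placeholders for your own `BaseEnergyFunction` and `BaseLoss` subclasses, and the import path is assumed from the source location shown below; everything else is standard PyTorch.

```python
import torch
from torch.utils.data import DataLoader

from torchebm.core.base_trainer import BaseTrainer  # assumed import path

# Hypothetical concrete implementations -- substitute your own subclasses.
energy_fn = MyEnergyFunction()            # assumed BaseEnergyFunction subclass
loss_fn = MyContrastiveLoss(energy_fn)    # assumed BaseLoss subclass

optimizer = torch.optim.Adam(energy_fn.parameters(), lr=1e-3)

trainer = BaseTrainer(
    energy_function=energy_fn,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device="cuda" if torch.cuda.is_available() else "cpu",
    use_mixed_precision=False,
)

# Toy data: 1024 two-dimensional samples. Passing the tensor directly to
# DataLoader yields plain tensor batches, which is what train_step expects.
data = torch.randn(1024, 2)
dataloader = DataLoader(data, batch_size=128, shuffle=True)

history = trainer.train(dataloader, num_epochs=5)
```
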
Source code in torchebm/core/base_trainer.py
class BaseTrainer:
    """
    Base class for training energy-based models.

    This class provides a generic interface for training EBMs, supporting various
    training methods and mixed precision training.

    Args:
        energy_function: Energy function to train
        optimizer: PyTorch optimizer to use
        loss_fn: Loss function for training
        device: Device to run training on
        dtype: Data type for computations
        use_mixed_precision: Whether to use mixed precision training
        callbacks: List of callback functions for training events

    Methods:
        train_step: Perform a single training step
        train_epoch: Train for a full epoch
        train: Train for multiple epochs
        validate: Validate the model
        save_checkpoint: Save model checkpoint
        load_checkpoint: Load model from checkpoint
    """

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        optimizer: torch.optim.Optimizer,
        loss_fn: BaseLoss,
        device: Optional[Union[str, torch.device]] = None,
        dtype: torch.dtype = torch.float32,
        use_mixed_precision: bool = False,
        callbacks: Optional[List[Callable]] = None,
    ):
        self.energy_function = energy_function
        self.optimizer = optimizer
        self.loss_fn = loss_fn

        # Set up device
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Set up dtype and mixed precision
        self.dtype = dtype
        self.use_mixed_precision = use_mixed_precision

        # Initialize callbacks
        self.callbacks = callbacks or []

        # Configure mixed precision
        if self.use_mixed_precision:
            try:
                from torch.cuda.amp import autocast, GradScaler
                self.autocast_available = True
                self.grad_scaler = GradScaler()

                # Ensure device is CUDA for mixed precision
                if not self.device.type.startswith('cuda'):
                    warnings.warn(
                        f"Mixed precision requested but device is {self.device}. "
                        f"Mixed precision requires CUDA. Falling back to full precision.",
                        UserWarning,
                    )
                    self.use_mixed_precision = False
                    self.autocast_available = False
            except ImportError:
                warnings.warn(
                    "Mixed precision requested but torch.cuda.amp not available. "
                    "Falling back to full precision. Requires PyTorch 1.6+.",
                    UserWarning,
                )
                self.use_mixed_precision = False
                self.autocast_available = False
        else:
            self.autocast_available = False

        # Move model and loss function to appropriate device/dtype
        self.energy_function = self.energy_function.to(device=self.device, dtype=self.dtype)

        # Propagate mixed precision settings to components
        if hasattr(self.loss_fn, 'use_mixed_precision'):
            self.loss_fn.use_mixed_precision = self.use_mixed_precision
        if hasattr(self.energy_function, 'use_mixed_precision'):
            self.energy_function.use_mixed_precision = self.use_mixed_precision

        # Move loss function to appropriate device
        if hasattr(self.loss_fn, 'to'):
            self.loss_fn = self.loss_fn.to(device=self.device, dtype=self.dtype)

        # Create metrics dictionary for tracking
        self.metrics: Dict[str, Any] = {'loss': []}

    def train_step(self, batch: torch.Tensor) -> Dict[str, Any]:
        """
        Perform a single training step.

        Args:
            batch: Batch of training data

        Returns:
            Dictionary containing metrics from this step
        """
        # Ensure batch is on the correct device and dtype
        batch = batch.to(device=self.device, dtype=self.dtype)

        # Zero gradients
        self.optimizer.zero_grad()

        # Forward pass with mixed precision if enabled
        if self.use_mixed_precision and self.autocast_available:
            from torch.cuda.amp import autocast
            with autocast():
                loss = self.loss_fn(batch)

            # Backward pass with gradient scaling
            self.grad_scaler.scale(loss).backward()
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            # Standard training step
            loss = self.loss_fn(batch)
            loss.backward()
            self.optimizer.step()

        # Return metrics
        return {'loss': loss.item()}

    def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        """
        Train for one epoch.

        Args:
            dataloader: DataLoader containing training data

        Returns:
            Dictionary with average metrics for the epoch
        """
        # Set model to training mode
        self.energy_function.train()

        # Initialize metrics for this epoch
        epoch_metrics: Dict[str, List[float]] = {'loss': []}

        # Iterate through batches
        for batch in dataloader:
            # Call any batch start callbacks
            for callback in self.callbacks:
                if hasattr(callback, 'on_batch_start'):
                    callback.on_batch_start(self, batch)

            # Perform training step
            step_metrics = self.train_step(batch)

            # Update epoch metrics
            for key, value in step_metrics.items():
                if key not in epoch_metrics:
                    epoch_metrics[key] = []
                epoch_metrics[key].append(value)

            # Call any batch end callbacks
            for callback in self.callbacks:
                if hasattr(callback, 'on_batch_end'):
                    callback.on_batch_end(self, batch, step_metrics)

        # Calculate average metrics
        avg_metrics = {key: sum(values) / len(values) for key, values in epoch_metrics.items()}

        return avg_metrics

    def train(self, dataloader: DataLoader, num_epochs: int, validate_fn: Optional[Callable] = None) -> Dict[str, List[float]]:
        """
        Train the model for multiple epochs.

        Args:
            dataloader: DataLoader containing training data
            num_epochs: Number of epochs to train for
            validate_fn: Optional function for validation after each epoch

        Returns:
            Dictionary with metrics over all epochs
        """
        # Initialize training history
        history: Dict[str, List[float]] = {'loss': []}

        # Call any training start callbacks
        for callback in self.callbacks:
            if hasattr(callback, 'on_train_start'):
                callback.on_train_start(self)

        # Train for specified number of epochs
        for epoch in range(num_epochs):
            # Call any epoch start callbacks
            for callback in self.callbacks:
                if hasattr(callback, 'on_epoch_start'):
                    callback.on_epoch_start(self, epoch)

            # Train for one epoch
            epoch_metrics = self.train_epoch(dataloader)

            # Update training history
            for key, value in epoch_metrics.items():
                if key not in history:
                    history[key] = []
                history[key].append(value)

            # Print progress
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_metrics['loss']:.6f}")

            # Validate if function provided
            if validate_fn is not None:
                val_metrics = validate_fn(self.energy_function)
                print(f"Validation: {val_metrics}")

                # Update validation metrics in history
                for key, value in val_metrics.items():
                    val_key = f"val_{key}"
                    if val_key not in history:
                        history[val_key] = []
                    history[val_key].append(value)

            # Call any epoch end callbacks
            for callback in self.callbacks:
                if hasattr(callback, 'on_epoch_end'):
                    callback.on_epoch_end(self, epoch, epoch_metrics)

        # Call any training end callbacks
        for callback in self.callbacks:
            if hasattr(callback, 'on_train_end'):
                callback.on_train_end(self, history)

        return history

    def save_checkpoint(self, path: str) -> None:
        """
        Save a checkpoint of the current training state.

        Args:
            path: Path to save the checkpoint to
        """
        checkpoint = {
            'energy_function_state_dict': self.energy_function.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'metrics': self.metrics,
        }

        if self.use_mixed_precision and hasattr(self, 'grad_scaler'):
            checkpoint['grad_scaler_state_dict'] = self.grad_scaler.state_dict()

        torch.save(checkpoint, path)

    def load_checkpoint(self, path: str) -> None:
        """
        Load a checkpoint to resume training.

        Args:
            path: Path to the checkpoint file
        """
        checkpoint = torch.load(path, map_location=self.device)

        self.energy_function.load_state_dict(checkpoint['energy_function_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'metrics' in checkpoint:
            self.metrics = checkpoint['metrics']

        if self.use_mixed_precision and 'grad_scaler_state_dict' in checkpoint and hasattr(self, 'grad_scaler'):
            self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state_dict'])

optimizer instance-attribute

optimizer = optimizer

loss_fn instance-attribute

loss_fn = loss_fn

device instance-attribute

device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

dtype instance-attribute

dtype = dtype

use_mixed_precision instance-attribute

use_mixed_precision = use_mixed_precision

callbacks instance-attribute

callbacks = callbacks or []

autocast_available instance-attribute

autocast_available = True

grad_scaler instance-attribute

grad_scaler = GradScaler()

energy_function instance-attribute

energy_function = energy_function.to(device=device, dtype=dtype)

metrics instance-attribute

metrics: Dict[str, Any] = {'loss': []}

train_step

train_step(batch: Tensor) -> Dict[str, Any]

Perform a single training step.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| batch | Tensor | Batch of training data | required |

Returns:

| Type | Description |
|------|-------------|
| Dict[str, Any] | Dictionary containing metrics from this step |

Source code in torchebm/core/base_trainer.py
def train_step(self, batch: torch.Tensor) -> Dict[str, Any]:
    """
    Perform a single training step.

    Args:
        batch: Batch of training data

    Returns:
        Dictionary containing metrics from this step
    """
    # Ensure batch is on the correct device and dtype
    batch = batch.to(device=self.device, dtype=self.dtype)

    # Zero gradients
    self.optimizer.zero_grad()

    # Forward pass with mixed precision if enabled
    if self.use_mixed_precision and self.autocast_available:
        from torch.cuda.amp import autocast
        with autocast():
            loss = self.loss_fn(batch)

        # Backward pass with gradient scaling
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
    else:
        # Standard training step
        loss = self.loss_fn(batch)
        loss.backward()
        self.optimizer.step()

    # Return metrics
    return {'loss': loss.item()}
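
A single step can also be driven manually. The sketch below reuses the `trainer` and toy 2-D data from the usage example above; the batch is moved to `trainer.device` and `trainer.dtype` inside `train_step`.

```python
# One manual optimization step on a single batch (sketch; assumes `trainer`
# from the usage example above).
batch = torch.randn(128, 2)
metrics = trainer.train_step(batch)   # device/dtype handling happens internally
print(metrics["loss"])
```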

train_epoch

train_epoch(dataloader: DataLoader) -> Dict[str, float]

Train for one epoch.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| dataloader | DataLoader | DataLoader containing training data | required |

Returns:

| Type | Description |
|------|-------------|
| Dict[str, float] | Dictionary with average metrics for the epoch |

Source code in torchebm/core/base_trainer.py
def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
    """
    Train for one epoch.

    Args:
        dataloader: DataLoader containing training data

    Returns:
        Dictionary with average metrics for the epoch
    """
    # Set model to training mode
    self.energy_function.train()

    # Initialize metrics for this epoch
    epoch_metrics: Dict[str, List[float]] = {'loss': []}

    # Iterate through batches
    for batch in dataloader:
        # Call any batch start callbacks
        for callback in self.callbacks:
            if hasattr(callback, 'on_batch_start'):
                callback.on_batch_start(self, batch)

        # Perform training step
        step_metrics = self.train_step(batch)

        # Update epoch metrics
        for key, value in step_metrics.items():
            if key not in epoch_metrics:
                epoch_metrics[key] = []
            epoch_metrics[key].append(value)

        # Call any batch end callbacks
        for callback in self.callbacks:
            if hasattr(callback, 'on_batch_end'):
                callback.on_batch_end(self, batch, step_metrics)

    # Calculate average metrics
    avg_metrics = {key: sum(values) / len(values) for key, values in epoch_metrics.items()}

    return avg_metrics
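
Callbacks are dispatched by duck typing: the trainer only calls a hook if the callback object defines it (`hasattr` checks in the source above). The sketch below shows a simple loss logger; the hook names and signatures match the calls made in `train_epoch` and `train`.

```python
# Sketch of a callback object; only the hooks you define are invoked.
class LossLogger:
    def __init__(self, log_every: int = 50):
        self.log_every = log_every
        self.step = 0

    def on_batch_end(self, trainer, batch, step_metrics):
        self.step += 1
        if self.step % self.log_every == 0:
            print(f"step {self.step}: loss={step_metrics['loss']:.4f}")

    def on_epoch_end(self, trainer, epoch, epoch_metrics):
        print(f"epoch {epoch}: avg loss={epoch_metrics['loss']:.4f}")

# trainer = BaseTrainer(..., callbacks=[LossLogger()])
```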

train

train(dataloader: DataLoader, num_epochs: int, validate_fn: Optional[Callable] = None) -> Dict[str, List[float]]

Train the model for multiple epochs.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| dataloader | DataLoader | DataLoader containing training data | required |
| num_epochs | int | Number of epochs to train for | required |
| validate_fn | Optional[Callable] | Optional function for validation after each epoch | None |

Returns:

| Type | Description |
|------|-------------|
| Dict[str, List[float]] | Dictionary with metrics over all epochs |

Source code in torchebm/core/base_trainer.py
def train(self, dataloader: DataLoader, num_epochs: int, validate_fn: Optional[Callable] = None) -> Dict[str, List[float]]:
    """
    Train the model for multiple epochs.

    Args:
        dataloader: DataLoader containing training data
        num_epochs: Number of epochs to train for
        validate_fn: Optional function for validation after each epoch

    Returns:
        Dictionary with metrics over all epochs
    """
    # Initialize training history
    history: Dict[str, List[float]] = {'loss': []}

    # Call any training start callbacks
    for callback in self.callbacks:
        if hasattr(callback, 'on_train_start'):
            callback.on_train_start(self)

    # Train for specified number of epochs
    for epoch in range(num_epochs):
        # Call any epoch start callbacks
        for callback in self.callbacks:
            if hasattr(callback, 'on_epoch_start'):
                callback.on_epoch_start(self, epoch)

        # Train for one epoch
        epoch_metrics = self.train_epoch(dataloader)

        # Update training history
        for key, value in epoch_metrics.items():
            if key not in history:
                history[key] = []
            history[key].append(value)

        # Print progress
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_metrics['loss']:.6f}")

        # Validate if function provided
        if validate_fn is not None:
            val_metrics = validate_fn(self.energy_function)
            print(f"Validation: {val_metrics}")

            # Update validation metrics in history
            for key, value in val_metrics.items():
                val_key = f"val_{key}"
                if val_key not in history:
                    history[val_key] = []
                history[val_key].append(value)

        # Call any epoch end callbacks
        for callback in self.callbacks:
            if hasattr(callback, 'on_epoch_end'):
                callback.on_epoch_end(self, epoch, epoch_metrics)

    # Call any training end callbacks
    for callback in self.callbacks:
        if hasattr(callback, 'on_train_end'):
            callback.on_train_end(self, history)

    return history
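
As the source shows, `validate_fn` is called with the energy function after every epoch and must return a dict of floats; each key is recorded in the history under `val_<key>`. The sketch below assumes the energy function is an `nn.Module` whose forward pass returns per-sample energies, and uses hypothetical held-out data.

```python
# Sketch of a validate_fn (assumptions noted above).
@torch.no_grad()
def validate_fn(energy_function):
    energy_function.eval()
    val_batch = torch.randn(256, 2)   # hypothetical held-out data
    device = next(energy_function.parameters()).device
    energies = energy_function(val_batch.to(device))
    energy_function.train()
    return {"mean_energy": energies.mean().item()}

history = trainer.train(dataloader, num_epochs=5, validate_fn=validate_fn)
# history now also contains a "val_mean_energy" entry per epoch
```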

save_checkpoint

save_checkpoint(path: str) -> None

Save a checkpoint of the current training state.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| path | str | Path to save the checkpoint to | required |

Source code in torchebm/core/base_trainer.py
def save_checkpoint(self, path: str) -> None:
    """
    Save a checkpoint of the current training state.

    Args:
        path: Path to save the checkpoint to
    """
    checkpoint = {
        'energy_function_state_dict': self.energy_function.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'metrics': self.metrics,
    }

    if self.use_mixed_precision and hasattr(self, 'grad_scaler'):
        checkpoint['grad_scaler_state_dict'] = self.grad_scaler.state_dict()

    torch.save(checkpoint, path)

load_checkpoint

load_checkpoint(path: str) -> None

Load a checkpoint to resume training.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| path | str | Path to the checkpoint file | required |

Source code in torchebm/core/base_trainer.py
def load_checkpoint(self, path: str) -> None:
    """
    Load a checkpoint to resume training.

    Args:
        path: Path to the checkpoint file
    """
    checkpoint = torch.load(path, map_location=self.device)

    self.energy_function.load_state_dict(checkpoint['energy_function_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if 'metrics' in checkpoint:
        self.metrics = checkpoint['metrics']

    if self.use_mixed_precision and 'grad_scaler_state_dict' in checkpoint and hasattr(self, 'grad_scaler'):
        self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state_dict'])
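
A checkpoint round trip, reusing `energy_fn`, `loss_fn`, and `dataloader` from the usage sketch above. The checkpoint bundles model weights, optimizer state, tracked metrics, and, when mixed precision is enabled, the GradScaler state.

```python
trainer.save_checkpoint("ebm_checkpoint.pt")

# Later: rebuild a trainer around the same components, restore its state,
# and continue training from where it stopped.
resumed = BaseTrainer(
    energy_function=energy_fn,
    optimizer=torch.optim.Adam(energy_fn.parameters(), lr=1e-3),
    loss_fn=loss_fn,
)
resumed.load_checkpoint("ebm_checkpoint.pt")
resumed.train(dataloader, num_epochs=5)
```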