Model Implementation

Implementation Details

This guide explains the implementation of neural network models in TorchEBM, including architecture designs, training workflows, and integration with energy functions.

Base Model Architecture

The BaseModel class provides the foundation for all neural networks in TorchEBM:

import torch
import torch.nn as nn
from typing import Tuple, List, Dict, Any, Optional, Union

class BaseModel(nn.Module):
    """Base class for all neural network models."""

    def __init__(self, input_dim: int, hidden_dims: List[int], activation: Optional[nn.Module] = None):
        """Initialize base model.

        Args:
            input_dim: Input dimension
            hidden_dims: List of hidden dimensions
            activation: Activation function
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.activation = activation or nn.ReLU()

        # Build network architecture
        self._build_network()

    def _build_network(self):
        """Build the neural network architecture."""
        layers = []

        # Input layer
        prev_dim = self.input_dim

        # Hidden layers
        for dim in self.hidden_dims:
            layers.append(nn.Linear(prev_dim, dim))
            layers.append(self.activation)
            prev_dim = dim

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network.

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            Output tensor
        """
        return self.network(x)
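
As a quick sanity check, the base model can be exercised on random inputs (the dimensions below are illustrative):

model = BaseModel(input_dim=2, hidden_dims=[64, 64])

x = torch.randn(16, 2)  # batch of 16 two-dimensional inputs
features = model(x)
print(features.shape)  # torch.Size([16, 64]) -- width of the last hidden layer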

MLP Energy Model

Building on BaseModel, an MLP energy model adds a scalar output head:

class MLPEnergyModel(BaseModel):
    """Multi-layer perceptron energy model."""

    def __init__(
        self, 
        input_dim: int, 
        hidden_dims: List[int], 
        activation: Optional[nn.Module] = None,
        use_spectral_norm: bool = False
    ):
        """Initialize MLP energy model.

        Args:
            input_dim: Input dimension
            hidden_dims: List of hidden dimensions
            activation: Activation function
            use_spectral_norm: Whether to use spectral normalization
        """
        super().__init__(input_dim, hidden_dims, activation)
        self.output_layer = nn.Linear(hidden_dims[-1], 1)

        # Apply spectral normalization if requested
        if use_spectral_norm:
            self._apply_spectral_norm()

    def _apply_spectral_norm(self):
        """Apply spectral normalization to all linear layers."""
        # nn.utils.spectral_norm registers its hooks on the module in place,
        # so no reassignment is needed. (Using setattr with the dotted names
        # from named_modules() would not replace nested modules correctly.)
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.utils.spectral_norm(module)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass to compute energy.

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            Energy values of shape (batch_size,)
        """
        features = super().forward(x)
        energy = self.output_layer(features)
        return energy.squeeze(-1)
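
For example (dimensions illustrative), the model maps a batch of points to one scalar energy each:

model = MLPEnergyModel(input_dim=2, hidden_dims=[128, 128], use_spectral_norm=True)

x = torch.randn(32, 2)
energy = model(x)
print(energy.shape)  # torch.Size([32]) -- one energy value per sample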

Convolutional Energy Models

For image data, convolutional architectures are more appropriate:

class ConvEnergyModel(nn.Module):
    """Convolutional energy model for image data."""

    def __init__(
        self,
        input_channels: int,
        image_size: int,
        channels: Optional[List[int]] = None,
        kernel_size: int = 3,
        activation: Optional[nn.Module] = None
    ):
        """Initialize convolutional energy model.

        Args:
            input_channels: Number of input channels
            image_size: Size of input images (assumed square)
            channels: List of channel dimensions for conv layers
            kernel_size: Size of convolutional kernel
            activation: Activation function
        """
        super().__init__()
        self.input_channels = input_channels
        self.image_size = image_size
        self.activation = activation or nn.LeakyReLU(0.2)

        # Avoid a mutable default argument by resolving the channel list here
        if channels is None:
            channels = [32, 64, 128, 256]

        # Build convolutional layers
        layers = []
        in_channels = input_channels

        for out_channels in channels:
            layers.append(
                nn.Conv2d(
                    in_channels, 
                    out_channels, 
                    kernel_size=kernel_size,
                    stride=2,
                    padding=kernel_size // 2
                )
            )
            layers.append(self.activation)
            in_channels = out_channels

        self.conv_net = nn.Sequential(*layers)

        # Each stride-2 convolution halves the spatial size, so image_size
        # is assumed to be divisible by 2 ** len(channels)
        feature_size = image_size // (2 ** len(channels))

        # Final layers
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels * feature_size * feature_size, 128),
            self.activation,
            nn.Linear(128, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass to compute energy.

        Args:
            x: Input tensor of shape (batch_size, channels, height, width)

        Returns:
            Energy values of shape (batch_size,)
        """
        # Ensure correct input shape
        if len(x.shape) == 2:
            x = x.view(-1, self.input_channels, self.image_size, self.image_size)

        features = self.conv_net(x)
        energy = self.fc(features)
        return energy.squeeze(-1)
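
For example, with the default channel list, a batch of 32x32 grayscale images yields one energy per image (image_size should be divisible by 2 ** len(channels), which is 16 with the defaults):

model = ConvEnergyModel(input_channels=1, image_size=32)

images = torch.randn(8, 1, 32, 32)  # batch of 8 grayscale 32x32 images
energy = model(images)
print(energy.shape)  # torch.Size([8])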

Neural Energy Functions

Any network that maps a batch of inputs to one scalar per sample can be wrapped as an energy function:

from torchebm.core import EnergyFunction

class NeuralEnergyFunction(EnergyFunction):
    """Energy function implemented using a neural network."""

    def __init__(self, model: nn.Module):
        """Initialize neural energy function.

        Args:
            model: Neural network model
        """
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute energy values for inputs.

        Args:
            x: Input tensor

        Returns:
            Energy values
        """
        return self.model(x)
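
Assuming EnergyFunction is an nn.Module subclass (consistent with the super().__init__() call above), the gradient of the energy with respect to the inputs, which gradient-based samplers such as Langevin dynamics rely on, comes directly from autograd. A minimal sketch:

energy_fn = NeuralEnergyFunction(MLPEnergyModel(input_dim=2, hidden_dims=[64, 64]))

x = torch.randn(16, 2, requires_grad=True)
energy = energy_fn(x)

# Gradient of the total energy with respect to the inputs
grad = torch.autograd.grad(energy.sum(), x)[0]
print(grad.shape)  # torch.Size([16, 2])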

Advanced Architectures

Residual Blocks

Residual connections help gradients flow through deeper energy models:

class ResidualBlock(nn.Module):
    """Residual block for energy models."""

    def __init__(self, dim: int, activation: Optional[nn.Module] = None):
        """Initialize residual block.

        Args:
            dim: Feature dimension
            activation: Activation function
        """
        super().__init__()
        # Use the Optional pattern from above rather than a module instance
        # shared across all blocks as a default argument
        self.activation = activation or nn.ReLU()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            self.activation,
            nn.Linear(dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with residual connection.

        Args:
            x: Input tensor

        Returns:
            Output tensor
        """
        return x + self.block(x)
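
A residual block maps features to features of the same shape:

block = ResidualBlock(dim=128)

h = torch.randn(4, 128)
print(block(h).shape)  # torch.Size([4, 128]) -- shape is preserved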

class ResNetEnergyModel(nn.Module):
    """ResNet-style energy model."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 128,
        n_blocks: int = 4,
        activation: Optional[nn.Module] = None
    ):
        """Initialize ResNet energy model.

        Args:
            input_dim: Input dimension
            hidden_dim: Hidden dimension
            n_blocks: Number of residual blocks
            activation: Activation function
        """
        super().__init__()
        activation = activation or nn.ReLU()

        # Input projection
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation
        )

        # Residual blocks
        self.blocks = nn.ModuleList([
            ResidualBlock(hidden_dim, activation)
            for _ in range(n_blocks)
        ])

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through ResNet energy model.

        Args:
            x: Input tensor

        Returns:
            Energy values
        """
        h = self.input_proj(x)

        for block in self.blocks:
            h = block(h)

        energy = self.output_proj(h)
        return energy.squeeze(-1)
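
Since each block preserves the hidden dimension, depth is controlled entirely by n_blocks without changing the model's interface:

shallow = ResNetEnergyModel(input_dim=2, n_blocks=2)
deep = ResNetEnergyModel(input_dim=2, n_blocks=16)

x = torch.randn(8, 2)
print(shallow(x).shape, deep(x).shape)  # torch.Size([8]) torch.Size([8])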

Integration with Trainers

Energy functions are trained through trainer classes that combine a sampler for negative samples, an optimizer, and a loss function:

class EBMTrainer:
    """Trainer for energy-based models."""

    def __init__(
        self,
        energy_function: EnergyFunction,
        sampler: "Sampler",
        optimizer: torch.optim.Optimizer,
        loss_fn: "Loss"
    ):
        """Initialize EBM trainer.

        Args:
            energy_function: Energy function to train
            sampler: Sampler for negative samples
            optimizer: Optimizer for model parameters
            loss_fn: Loss function
        """
        self.energy_function = energy_function
        self.sampler = sampler
        self.optimizer = optimizer
        self.loss_fn = loss_fn

    def train_step(
        self,
        pos_samples: torch.Tensor,
        neg_samples: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Perform one training step.

        Args:
            pos_samples: Positive samples from data
            neg_samples: Optional negative samples

        Returns:
            Dictionary of metrics
        """
        # Generate negative samples if not provided
        if neg_samples is None:
            with torch.no_grad():
                neg_samples = self.sampler.sample(
                    n_samples=pos_samples.shape[0],
                    dim=pos_samples.shape[1]
                )

        # Zero gradients
        self.optimizer.zero_grad()

        # Compute loss
        loss, metrics = self.loss_fn(pos_samples, neg_samples)

        # Backward and optimize
        loss.backward()
        self.optimizer.step()

        return metrics
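
A minimal end-to-end sketch. The sampler and loss below are stand-ins invented purely for illustration (TorchEBM's actual Sampler and Loss classes are not shown in this guide); only the interfaces used by train_step above are assumed:

energy_fn = NeuralEnergyFunction(MLPEnergyModel(input_dim=2, hidden_dims=[64, 64]))
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=1e-3)

class GaussianSampler:
    """Stand-in sampler that draws negatives from a standard normal."""

    def sample(self, n_samples, dim):
        return torch.randn(n_samples, dim)

def toy_loss(pos_samples, neg_samples):
    """Stand-in contrastive loss: mean positive minus mean negative energy."""
    loss = energy_fn(pos_samples).mean() - energy_fn(neg_samples).mean()
    return loss, {"loss": loss.detach()}

trainer = EBMTrainer(energy_fn, GaussianSampler(), optimizer, toy_loss)

pos_batch = torch.randn(32, 2)  # stands in for a batch of real data
metrics = trainer.train_step(pos_batch)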

Performance Optimizations

Mixed Precision Training

# torch.amp supersedes the deprecated torch.cuda.amp namespace
from torch.amp import GradScaler, autocast

class MixedPrecisionEBMTrainer(EBMTrainer):
    """Trainer with mixed precision for faster training."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.scaler = GradScaler("cuda")

    def train_step(
        self,
        pos_samples: torch.Tensor,
        neg_samples: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Perform one training step with mixed precision."""
        # Generate negative samples if not provided
        if neg_samples is None:
            with torch.no_grad():
                neg_samples = self.sampler.sample(
                    n_samples=pos_samples.shape[0],
                    dim=pos_samples.shape[1]
                )

        # Zero gradients
        self.optimizer.zero_grad()

        # Forward pass with mixed precision
        with autocast("cuda"):
            loss, metrics = self.loss_fn(pos_samples, neg_samples)

        # Backward and optimize with gradient scaling
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return metrics
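
Because only train_step is overridden, MixedPrecisionEBMTrainer is a drop-in replacement constructed with the same arguments as EBMTrainer. The GradScaler step matters here: float16 losses can produce gradients small enough to underflow, and scaling keeps them representable.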

Best Practices for Custom Models

When implementing custom models, follow these best practices:

Do

  • Subclass appropriate base classes
  • Handle device placement correctly
  • Use proper initialization
  • Consider normalization techniques
  • Document architecture clearly

Don't

  • Create overly complex architectures
  • Ignore numerical stability
  • Forget to validate inputs
  • Mix different PyTorch versions
  • Ignore gradient flow issues

Custom Model Example

class CustomEnergyModel(nn.Module):
    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),  # SiLU/Swish activation
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Initialize weights properly
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.network(x).squeeze(-1)
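
To use the custom model with the rest of the stack, wrap it in the NeuralEnergyFunction defined earlier:

model = CustomEnergyModel(input_dim=2)
energy_fn = NeuralEnergyFunction(model)

x = torch.randn(10, 2)
print(energy_fn(x).shape)  # torch.Size([10])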