Design Principles¶
Project Philosophy
TorchEBM is built on a set of core design principles that guide its development. Understanding these principles will help you contribute in a way that aligns with the project's vision.
Core Philosophy¶
TorchEBM aims to be:
- Performant: High-performance implementations that leverage PyTorch's capabilities and CUDA acceleration.
- Modular: Components that can be easily combined, extended, and customized.
- Intuitive: Clear, well-documented APIs that are easy to understand and use.
- Educational: Serves as both a practical tool and a learning resource for energy-based modeling.
Key Design Patterns¶
Composable Base Classes¶
TorchEBM is built around a set of extensible base classes that provide a common interface:
import torch
import torch.nn as nn

class BaseEnergyFunction(nn.Module):
    """Base class for all energy functions."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute energy for input x."""
        raise NotImplementedError

    def score(self, x: torch.Tensor) -> torch.Tensor:
        """Compute score (gradient of energy) for input x."""
        # Detach first so requires_grad_ works even when x is a non-leaf tensor.
        x = x.detach().requires_grad_(True)
        energy = self.forward(x)
        return torch.autograd.grad(energy.sum(), x, create_graph=True)[0]
This design allows for:
- Composition: Combining energy functions via addition, multiplication, etc. (see the sketch below)
- Extension: Creating new energy functions by subclassing
- Integration: Using energy functions with any sampler that follows the interface
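For example, composition could be realized by a wrapper energy that sums two components. The following is a minimal sketch; SumEnergy is an illustrative name, not necessarily the library's actual API:

class SumEnergy(BaseEnergyFunction):
    """Energy function formed by summing two component energies."""

    def __init__(self, first: BaseEnergyFunction, second: BaseEnergyFunction):
        super().__init__()
        self.first = first
        self.second = second

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The composite energy is the sum of the component energies.
        return self.first(x) + self.second(x)

An __add__ method on BaseEnergyFunction returning SumEnergy(self, other) would then let users write combined = gaussian + double_well.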
Factory Methods¶
Factory methods create configured instances with sensible defaults:
@classmethod
def create_standard(cls, dim: int = 2) -> 'GaussianEnergy':
    """Create a standard Gaussian energy function."""
    return cls(mean=torch.zeros(dim), cov=torch.eye(dim))
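A caller then gets a ready-to-use instance in one line (assuming a GaussianEnergy subclass with mean and cov constructor arguments, as implied above):

energy_fn = GaussianEnergy.create_standard(dim=2)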
Configuration through Constructor¶
Classes are configured through their constructor rather than setter methods:
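A minimal sketch, reusing the Langevin parameters shown later on this page:

# All configuration is supplied once, at construction time.
sampler = LangevinDynamics(
    energy_fn,
    step_size=0.01,
    noise_scale=1.0,
)

# There is no post-hoc mutation such as sampler.set_step_size(0.01).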
This approach:
- Makes the configuration explicit and clear
- Encourages immutability of key parameters
- Simplifies object creation and usage
Method Chaining¶
Methods return the object itself when appropriate to allow method chaining:
result = (
    sampler
    .set_device("cuda" if torch.cuda.is_available() else "cpu")
    .set_seed(42)
    .sample(dim=2, n_steps=1000)
)
Lazily-Evaluated Operations¶
Computations are performed lazily when possible to avoid unnecessary work:
# Create a sampler with a sampling trajectory
sampler = LangevinDynamics(energy_fn)
trajectory = sampler.sample_trajectory(dim=2, n_steps=1000)
# Compute statistics only when needed
mean = trajectory.mean()  # Computation happens here
variance = trajectory.variance()  # Computation happens here
Architecture Principles¶
Separation of Concerns¶
Components have clearly defined responsibilities:
- Energy Functions: Define the energy landscape
- Samplers: Generate samples from energy functions
- Losses: Train energy functions from data
- Models: Parameterize energy functions using neural networks
- Utils: Provide supporting functionality
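These pieces meet only through narrow interfaces. In the hypothetical training sketch below, MLPEnergy, ContrastiveDivergence, dataloader, and optimizer are illustrative names, not confirmed APIs; only LangevinDynamics appears elsewhere on this page:

energy_fn = MLPEnergy(input_dim=2)        # Models: neural-net energy
sampler = LangevinDynamics(energy_fn)     # Samplers: draw from the energy
loss_fn = ContrastiveDivergence(sampler)  # Losses: train from data

for batch in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(batch)   # the loss talks to the sampler, not vice versa
    loss.backward()
    optimizer.step()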
Minimizing Dependencies¶
Each module has minimal dependencies on other modules:
- Core modules (e.g., core, samplers) don't depend on higher-level modules
- Utility modules are designed to be used by all other modules
- CUDA implementations are separated to allow for CPU-only usage
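One way to keep the CUDA code optional is a guarded import at the package boundary. The sketch below is hypothetical; torchebm.cuda and fused_langevin_step are invented names for illustration:

try:
    # Optional compiled extension; absent on CPU-only installs.
    from torchebm.cuda import fused_langevin_step
    _HAS_CUDA_EXT = True
except ImportError:
    fused_langevin_step = None
    _HAS_CUDA_EXT = False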
Consistent Error Handling¶
Error handling follows consistent patterns:
- Use descriptive error messages that suggest solutions
- Validate inputs early with helpful validation errors
- Provide debug information when operations fail
def validate_dimensions(tensor: torch.Tensor, expected_dims: int) -> None:
    """Validate that tensor has the expected number of dimensions."""
    if tensor.dim() != expected_dims:
        raise ValueError(
            f"Expected tensor with {expected_dims} dimensions, "
            f"but got tensor with batch_shape {tensor.shape}"
        )
Consistent API Design¶
APIs are designed consistently across the library:
- Similar operations have similar interfaces
- Parameters follow consistent naming conventions
- Return types are consistent and well-documented
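Concretely, consistency means a user can switch samplers without relearning the call shape. A hypothetical illustration (HamiltonianMonteCarlo is an invented second sampler):

langevin = LangevinDynamics(energy_fn)
hmc = HamiltonianMonteCarlo(energy_fn)  # hypothetical second sampler

# Both expose the same parameters under the same names.
samples_a = langevin.sample(n_samples=1000, n_steps=500)
samples_b = hmc.sample(n_samples=1000, n_steps=500)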
Progressive Disclosure¶
Simple use cases are simple, while advanced functionality is available but not required:
# Simple usage
sampler = LangevinDynamics(energy_fn)
samples = sampler.sample(n_samples=1000)
# Advanced usage
sampler = LangevinDynamics(
    energy_fn,
    step_size=0.01,
    noise_scale=1.0,
    step_size_schedule=LinearSchedule(0.01, 0.001),
    metropolis_correction=True
)
samples = sampler.sample(
    n_samples=1000,
    initial_samples=initial_x,
    callback=logging_callback
)
Implementation Principles¶
PyTorch First¶
TorchEBM is built on PyTorch and follows PyTorch patterns:
- Use PyTorch's tensor operations whenever possible
- Follow PyTorch's model design patterns (e.g., nn.Module)
- Leverage PyTorch's autograd for gradient computation
- Support both CPU and CUDA execution
Vectorized Operations¶
Operations are vectorized where possible for efficiency:
# Good: Vectorized operations
def compute_pairwise_distances(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.cdist(x, y, p=2)
# Avoid: Explicit loops
def compute_pairwise_distances_slow(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    result = torch.zeros(x.size(0), y.size(0))
    for i in range(x.size(0)):
        for j in range(y.size(0)):
            result[i, j] = torch.norm(x[i] - y[j])
    return result
CUDA Optimization¶
Performance-critical operations are optimized with CUDA when appropriate:
- CPU implementations as fallback
- CUDA implementations for performance
- Automatic selection based on available hardware
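Putting the three points together, dispatch might look like the following sketch, which reuses the guarded import from the dependency example above (fused_langevin_step remains hypothetical):

def langevin_step(x: torch.Tensor, energy_fn: BaseEnergyFunction,
                  step_size: float) -> torch.Tensor:
    """One Langevin update, using the fused kernel when it is available."""
    if _HAS_CUDA_EXT and x.is_cuda:
        return fused_langevin_step(x, energy_fn, step_size)
    # Pure-PyTorch fallback; runs on CPU or GPU.
    noise = torch.randn_like(x)
    # score() is the energy gradient, as defined in BaseEnergyFunction above.
    return x - step_size * energy_fn.score(x) + (2.0 * step_size) ** 0.5 * noise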
Type Annotations¶
Code uses type annotations for clarity and static analysis:
from typing import Optional, Union

def sample_chain(
    self,
    dim: int,
    n_steps: int,
    n_samples: int = 1,
    initial_samples: Optional[torch.Tensor] = None,
    return_trajectory: bool = False
) -> Union[torch.Tensor, Trajectory]:
    """Generate samples using a Markov chain."""
    # Implementation
Testing Principles¶
- Unit Testing: Individual components are thoroughly tested
- Integration Testing: Component interactions are tested
- Property Testing: Properties of algorithms are tested
- Numerical Testing: Numerical algorithms are tested for stability and accuracy
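For instance, a numerical test might check that Langevin samples from a standard Gaussian energy recover its known moments. A hypothetical pytest sketch, with the sample() signature assumed from the examples on this page:

import torch

def test_langevin_recovers_gaussian_moments():
    energy_fn = GaussianEnergy.create_standard(dim=2)
    sampler = LangevinDynamics(energy_fn)
    samples = sampler.sample(n_samples=5000, n_steps=1000)
    # A standard Gaussian has zero mean and unit variance per dimension.
    assert torch.allclose(samples.mean(dim=0), torch.zeros(2), atol=0.1)
    assert torch.allclose(samples.var(dim=0), torch.ones(2), atol=0.2)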
Documentation Principles¶
Documentation is comprehensive and includes:
- API Documentation: Clear documentation of all public APIs
- Tutorials: Step-by-step guides for common tasks
- Examples: Real-world examples of using the library
- Theory: Explanations of the underlying mathematical concepts
Future Compatibility¶
TorchEBM is designed with future compatibility in mind:
- API Stability: Breaking changes are minimized and clearly documented
- Feature Flags: Experimental features are clearly marked
- Deprecation Warnings: Deprecated features emit warnings before removal
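The last point is typically implemented with the standard-library warnings module. A minimal sketch around a hypothetical legacy method:

import warnings

def sample_legacy(self, *args, **kwargs):
    """Deprecated alias for sample(); hypothetical example."""
    warnings.warn(
        "sample_legacy() is deprecated and will be removed in a future "
        "release; use sample() instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.sample(*args, **kwargs)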
Contributing Guidelines¶
When contributing to TorchEBM, adhere to these design principles:
- Make sure new components follow existing patterns
- Keep interfaces consistent with the rest of the library
- Write thorough tests for new functionality
- Document public APIs clearly
- Optimize for readability and maintainability
Design Example: Adding a New Sampler
When adding a new sampler:
- Subclass the Sampler base class
- Implement required methods (sample, sample_chain)
- Follow the existing parameter naming conventions
- Add comprehensive documentation
- Write tests that verify the sampler's properties
- Optimize performance-critical sections
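A minimal skeleton following those steps. The Sampler base-class name and method signatures are inferred from the examples on this page, so treat the details as illustrative:

from typing import Optional, Union

class MyNewSampler(Sampler):
    """Short description of the algorithm, with a reference to its source."""

    def __init__(self, energy_fn: BaseEnergyFunction, step_size: float = 0.01):
        super().__init__(energy_fn)
        self.step_size = step_size

    def sample(self, n_samples: int = 1, n_steps: int = 100) -> torch.Tensor:
        """Draw n_samples after running the chain for n_steps."""
        ...  # performance-critical inner loop goes here

    def sample_chain(
        self,
        dim: int,
        n_steps: int,
        n_samples: int = 1,
        initial_samples: Optional[torch.Tensor] = None,
        return_trajectory: bool = False,
    ) -> Union[torch.Tensor, Trajectory]:
        """Generate samples using a Markov chain (interface shown earlier)."""
        ...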
- Project Structure: Understand how the project is organized.
- Core Components: Learn about the core components in detail.
- Code Style: Follow the project's coding standards.