Code Style Guide¶
Consistent Style
Following a consistent code style ensures our codebase remains readable and maintainable. This guide outlines the style conventions used in TorchEBM.
Python Style Guidelines¶
TorchEBM follows PEP 8 with some project-specific guidelines.
Automatic Formatting¶
We use several tools to automatically format and check our code:
- Black: Automatic code formatter with a focus on consistency.
- isort: Sorts imports alphabetically and separates them into sections.
- Flake8: Linter to catch logical and stylistic issues.
Code Structure¶
def function_name(
    param1: type,
    param2: type,
    param3: Optional[type] = None
) -> ReturnType:
    """Short description of the function.

    More detailed explanation if needed.

    Args:
        param1: Description of parameter 1
        param2: Description of parameter 2
        param3: Description of parameter 3

    Returns:
        Description of the return value

    Raises:
        ExceptionType: When and why this exception is raised
    """
    # Function implementation
    pass


class ClassName(BaseClass):
    """Short description of the class.

    More detailed explanation if needed.

    Args:
        attr1: Description of attribute 1
        attr2: Description of attribute 2
    """

    def __init__(
        self,
        attr1: type,
        attr2: type = default_value
    ):
        """Initialize the class.

        Args:
            attr1: Description of attribute 1
            attr2: Description of attribute 2
        """
        self.attr1 = attr1
        self.attr2 = attr2

    def method_name(self, param: type) -> ReturnType:
        """Short description of the method.

        Args:
            param: Description of parameter

        Returns:
            Description of the return value
        """
        # Method implementation
        pass
Naming Conventions¶
Functions and Variables¶
Use snake_case for functions and variables.
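A brief illustration of the snake_case convention follows. The names `compute_step_size`, `initial_step`, and `decay_rate` are hypothetical and are not taken from the TorchEBM codebase.

```python
def compute_step_size(initial_step: float, decay_rate: float, n_steps: int) -> float:
    """Return the step size after applying exponential decay for n_steps."""
    # Local variables use snake_case, just like the function name.
    current_step = initial_step * (decay_rate ** n_steps)
    return current_step


final_step = compute_step_size(initial_step=0.1, decay_rate=0.5, n_steps=2)
```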
Constants¶
Use UPPER_CASE for constants.
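As a small sketch of the UPPER_CASE convention, module-level constants might look like the following. The constant names and values are assumptions chosen for illustration, not values defined by TorchEBM.

```python
# Module-level constants use UPPER_CASE (hypothetical names and values).
DEFAULT_STEP_SIZE = 0.01
MAX_CHAIN_LENGTH = 1000


def clamp_chain_length(n_steps: int) -> int:
    """Limit a requested chain length to MAX_CHAIN_LENGTH."""
    return min(n_steps, MAX_CHAIN_LENGTH)
```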
Documentation Style¶
TorchEBM uses Google-style docstrings for all code documentation.
Docstring Example
def sample_chain(
    self,
    dim: int,
    n_steps: int,
    n_samples: int = 1
) -> torch.Tensor:
    """Generate samples using a Markov chain of specified length.

    Args:
        dim: Dimensionality of samples
        n_steps: Number of steps in the chain
        n_samples: Number of parallel chains to run

    Returns:
        Tensor of shape (n_samples, dim) containing final samples

    Examples:
        >>> energy_fn = GaussianEnergy(torch.zeros(2), torch.eye(2))
        >>> sampler = LangevinDynamics(energy_fn, step_size=0.01)
        >>> samples = sampler.sample_chain(dim=2, n_steps=100, n_samples=10)
    """
Type Annotations¶
We use Python's type hints to improve code readability and enable static type checking:
from typing import Optional, List, Union, Dict, Tuple, Callable

import torch


def function(
    tensor: torch.Tensor,
    scale: float = 1.0,
    use_cuda: bool = False,
    callback: Optional[Callable[[torch.Tensor], None]] = None
) -> Tuple[torch.Tensor, float]:
    # Implementation
    pass
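To show how such annotations guide an implementation, here is a minimal standard-library analogue of the signature above with a filled-in body. The function `rescale` and its behavior are assumptions made for illustration; plain lists stand in for tensors so the sketch runs without PyTorch.

```python
from typing import Callable, List, Optional, Tuple


def rescale(
    values: List[float],
    scale: float = 1.0,
    callback: Optional[Callable[[List[float]], None]] = None,
) -> Tuple[List[float], float]:
    """Scale a list of floats and return it together with its sum."""
    scaled = [v * scale for v in values]
    # Optional[...] means the argument may be None, so check before calling.
    if callback is not None:
        callback(scaled)
    return scaled, sum(scaled)


seen: List[List[float]] = []
scaled, total = rescale([1.0, 2.0], scale=2.0, callback=seen.append)
```

A static checker such as mypy would be expected to flag a call like `rescale("abc")`, since a `str` does not match `List[float]`.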
CUDA Code Style¶
For CUDA extensions, we follow these conventions:
Imports Organization¶
Organize imports in the following order:
- Standard library imports
- Related third-party imports
- Local application/library specific imports
# Standard library
import os
import sys
from typing import Optional, Dict

# Third-party
import numpy as np
import torch
import torch.nn as nn

# Local application
from torchebm.core import EnergyFunction
from torchebm.utils import get_device
Comments¶
- Use comments sparingly; prefer self-documenting code with clear variable names
- Add comments for complex algorithms or non-obvious implementations
- Update comments when you change code
Good Comments Example
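A comment earns its place when it explains why the code is written a certain way rather than restating what it does. The sketch below is illustrative only; `log_sum_exp` is a hypothetical helper, not a TorchEBM function.

```python
import math
from typing import List


def log_sum_exp(values: List[float]) -> float:
    """Compute log(sum(exp(v) for v in values)) in a numerically stable way."""
    # Subtract the maximum before exponentiating so exp() cannot overflow
    # for large inputs; the shift is added back after taking the log.
    max_value = max(values)
    shifted_sum = sum(math.exp(v - max_value) for v in values)
    return max_value + math.log(shifted_sum)
```

The comment documents the non-obvious reason for the shift, while the variable names carry the rest of the meaning.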
Pre-commit Hooks¶
TorchEBM uses pre-commit hooks to enforce code style. Make sure to install them as described in the Development Setup guide.
Recommended Editor Settings¶
- Install Black and isort plugins
- Configure Code Style for Python to match PEP 8
- Set Black as the external formatter
- Enable "Reformat code on save"
- Configure isort for import optimization
Style Enforcement¶
Our CI pipeline checks for style compliance. Pull requests failing style checks will be automatically rejected.