This guide explains how to implement neural network models in TorchEBM, covering architecture design, training workflows, and integration with energy functions.
```python
import torch
import torch.nn as nn
from typing import Tuple, List, Dict, Any, Optional, Union


class BaseModel(nn.Module):
    """Base class for all neural network models."""

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        activation: Optional[nn.Module] = None,
    ):
        """Initialize base model.

        Args:
            input_dim: Input dimension
            hidden_dims: List of hidden dimensions
            activation: Activation function
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.activation = activation or nn.ReLU()

        # Build network architecture
        self._build_network()

    def _build_network(self):
        """Build the neural network architecture."""
        layers = []

        # Input layer
        prev_dim = self.input_dim

        # Hidden layers
        for dim in self.hidden_dims:
            layers.append(nn.Linear(prev_dim, dim))
            layers.append(self.activation)
            prev_dim = dim

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the network.

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            Output tensor
        """
        return self.network(x)
```
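As a quick sanity check, the base model can be instantiated and run on a random batch. A minimal sketch with arbitrary dimensions:

```python
import torch
import torch.nn as nn

# Build a small feature network: 2 -> 64 -> 64
model = BaseModel(input_dim=2, hidden_dims=[64, 64], activation=nn.SiLU())

x = torch.randn(16, 2)   # batch of 16 two-dimensional points
features = model(x)      # shape: (16, 64)
print(features.shape)    # torch.Size([16, 64])
```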
```python
class MLPEnergyModel(BaseModel):
    """Multi-layer perceptron energy model."""

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        activation: Optional[nn.Module] = None,
        use_spectral_norm: bool = False,
    ):
        """Initialize MLP energy model.

        Args:
            input_dim: Input dimension
            hidden_dims: List of hidden dimensions
            activation: Activation function
            use_spectral_norm: Whether to use spectral normalization
        """
        super().__init__(input_dim, hidden_dims, activation)

        self.output_layer = nn.Linear(hidden_dims[-1], 1)

        # Apply spectral normalization if requested
        if use_spectral_norm:
            self._apply_spectral_norm()

    def _apply_spectral_norm(self):
        """Apply spectral normalization to all linear layers.

        Linear layers may be nested inside submodules (e.g. the Sequential
        built by BaseModel), so we recurse over children and replace each
        layer in its parent module.
        """
        def _wrap(module: nn.Module):
            for name, child in module.named_children():
                if isinstance(child, nn.Linear):
                    setattr(module, name, nn.utils.spectral_norm(child))
                else:
                    _wrap(child)

        _wrap(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass to compute energy.

        Args:
            x: Input tensor of shape (batch_size, input_dim)

        Returns:
            Energy values of shape (batch_size,)
        """
        features = super().forward(x)
        energy = self.output_layer(features)
        return energy.squeeze(-1)
```
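A minimal usage sketch for the MLP energy model; the hyperparameters here are arbitrary examples:

```python
import torch

# Energy model over 2D inputs with spectral normalization enabled
energy_model = MLPEnergyModel(
    input_dim=2,
    hidden_dims=[128, 128],
    use_spectral_norm=True,
)

x = torch.randn(32, 2)
energy = energy_model(x)   # shape: (32,) -- one scalar energy per sample
print(energy.shape)
```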
```python
class ConvEnergyModel(nn.Module):
    """Convolutional energy model for image data."""

    def __init__(
        self,
        input_channels: int,
        image_size: int,
        channels: List[int] = [32, 64, 128, 256],
        kernel_size: int = 3,
        activation: Optional[nn.Module] = None,
    ):
        """Initialize convolutional energy model.

        Args:
            input_channels: Number of input channels
            image_size: Size of input images (assumed square)
            channels: List of channel dimensions for conv layers
            kernel_size: Size of convolutional kernel
            activation: Activation function
        """
        super().__init__()
        self.input_channels = input_channels
        self.image_size = image_size
        self.activation = activation or nn.LeakyReLU(0.2)

        # Build convolutional layers, each halving the spatial resolution
        layers = []
        in_channels = input_channels

        for out_channels in channels:
            layers.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=kernel_size // 2,
                )
            )
            layers.append(self.activation)
            in_channels = out_channels

        self.conv_net = nn.Sequential(*layers)

        # Spatial size after the stride-2 convolutions; exact only when
        # image_size is divisible by 2 ** len(channels)
        feature_size = image_size // (2 ** len(channels))

        # Final layers
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_channels * feature_size * feature_size, 128),
            self.activation,
            nn.Linear(128, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass to compute energy.

        Args:
            x: Input tensor of shape (batch_size, channels, height, width)

        Returns:
            Energy values of shape (batch_size,)
        """
        # Reshape flattened inputs back to image form if necessary
        if len(x.shape) == 2:
            x = x.view(-1, self.input_channels, self.image_size, self.image_size)

        features = self.conv_net(x)
        energy = self.fc(features)
        return energy.squeeze(-1)
```
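A usage sketch with 32x32 RGB images; note that `image_size` should be divisible by `2 ** len(channels)` so that the flattened feature size is computed exactly:

```python
import torch

# Energy model for 32x32 RGB images; 32 is divisible by 2**4 = 16,
# so the flattened feature size matches the actual conv output
energy_model = ConvEnergyModel(input_channels=3, image_size=32)

images = torch.randn(8, 3, 32, 32)
energy = energy_model(images)   # shape: (8,)
print(energy.shape)
```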
```python
from torchebm.core import BaseEnergyFunction


class NeuralEnergyFunction(BaseEnergyFunction):
    """Energy function implemented using a neural network."""

    def __init__(self, model: nn.Module):
        """Initialize neural energy function.

        Args:
            model: Neural network model
        """
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute energy values for inputs.

        Args:
            x: Input tensor

        Returns:
            Energy values
        """
        return self.model(x)
```
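Wrapping a model is a one-liner; for example, reusing the MLP energy model defined above:

```python
import torch

# Wrap the MLP so it can be passed anywhere TorchEBM expects a
# BaseEnergyFunction (samplers, losses, trainers)
model = MLPEnergyModel(input_dim=2, hidden_dims=[64, 64])
energy_fn = NeuralEnergyFunction(model)

x = torch.randn(10, 2)
print(energy_fn(x).shape)   # torch.Size([10])
```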
```python
class EBMTrainer:
    """Trainer for energy-based models."""

    def __init__(
        self,
        energy_function: BaseEnergyFunction,
        sampler: "Sampler",
        optimizer: torch.optim.Optimizer,
        loss_fn: "BaseLoss",
    ):
        """Initialize EBM trainer.

        Args:
            energy_function: Energy function to train
            sampler: Sampler for negative samples
            optimizer: Optimizer for model parameters
            loss_fn: Loss function (a BaseLoss instance)
        """
        self.energy_function = energy_function
        self.sampler = sampler
        self.optimizer = optimizer
        self.loss_fn = loss_fn

    def train_step(
        self,
        pos_samples: torch.Tensor,
        neg_samples: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Perform one training step.

        Args:
            pos_samples: Positive samples from data
            neg_samples: Optional negative samples

        Returns:
            Dictionary of metrics
        """
        # Generate negative samples if not provided
        if neg_samples is None:
            with torch.no_grad():
                neg_samples = self.sampler.sample(
                    n_samples=pos_samples.shape[0],
                    dim=pos_samples.shape[1],
                )

        # Zero gradients
        self.optimizer.zero_grad()

        # Compute loss
        loss, metrics = self.loss_fn(pos_samples, neg_samples)

        # Backward and optimize
        loss.backward()
        self.optimizer.step()

        return metrics
```
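A minimal training sketch. `GaussianSampler` and `energy_difference_loss` below are hypothetical stand-ins that only match the interfaces `EBMTrainer` expects (`sampler.sample(n_samples, dim)` and `loss_fn(pos, neg) -> (loss, metrics)`); in practice you would plug in TorchEBM's own samplers and losses:

```python
import torch

# Hypothetical stand-in sampler: draws negatives from a unit Gaussian
class GaussianSampler:
    def sample(self, n_samples: int, dim: int) -> torch.Tensor:
        return torch.randn(n_samples, dim)

model = MLPEnergyModel(input_dim=2, hidden_dims=[64, 64])
energy_fn = NeuralEnergyFunction(model)
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=1e-4)

# Hypothetical stand-in loss: push energy down on data, up on samples
def energy_difference_loss(pos, neg):
    loss = energy_fn(pos).mean() - energy_fn(neg).mean()
    return loss, {"loss": loss.detach()}

trainer = EBMTrainer(energy_fn, GaussianSampler(), optimizer, energy_difference_loss)

for step in range(100):
    pos_samples = torch.randn(32, 2) * 0.5 + 1.0   # toy "data" distribution
    metrics = trainer.train_step(pos_samples)
```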
```python
from torch.cuda.amp import autocast, GradScaler


class MixedPrecisionEBMTrainer(EBMTrainer):
    """Trainer with mixed precision for faster training."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.scaler = GradScaler()

    def train_step(
        self,
        pos_samples: torch.Tensor,
        neg_samples: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Perform one training step with mixed precision."""
        # Generate negative samples if not provided
        if neg_samples is None:
            with torch.no_grad():
                neg_samples = self.sampler.sample(
                    n_samples=pos_samples.shape[0],
                    dim=pos_samples.shape[1],
                )

        # Zero gradients
        self.optimizer.zero_grad()

        # Forward pass with mixed precision
        with autocast():
            loss, metrics = self.loss_fn(pos_samples, neg_samples)

        # Backward and optimize with gradient scaling
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return metrics
```
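The mixed precision trainer is a drop-in replacement for `EBMTrainer`, but it only pays off on CUDA devices, and the model and data must live on the GPU. A minimal sketch assuming a CUDA device is available; `DeviceSampler` is a hypothetical stand-in that draws noise directly on the GPU (on recent PyTorch releases `torch.amp` supersedes `torch.cuda.amp`, though the older import still works):

```python
import torch

device = torch.device("cuda")   # autocast/GradScaler require a GPU

# Hypothetical stand-in sampler emitting tensors on the right device
class DeviceSampler:
    def sample(self, n_samples: int, dim: int) -> torch.Tensor:
        return torch.randn(n_samples, dim, device=device)

model = MLPEnergyModel(input_dim=2, hidden_dims=[64, 64]).to(device)
energy_fn = NeuralEnergyFunction(model)
optimizer = torch.optim.Adam(energy_fn.parameters(), lr=1e-4)

# Same hypothetical loss interface as in the previous example
def loss_fn(pos, neg):
    loss = energy_fn(pos).mean() - energy_fn(neg).mean()
    return loss, {"loss": loss.detach()}

trainer = MixedPrecisionEBMTrainer(energy_fn, DeviceSampler(), optimizer, loss_fn)
metrics = trainer.train_step(torch.randn(32, 2, device=device))
```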