
BaseModel

Methods and Attributes

Bases: ABC

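Abstract base class for energy-based models. A model pairs an energy function with a sampler; subclasses must implement energy, sample, and train_step.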

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| energy_function | EnergyFunction | Energy function to sample from. | required |
| sampler | BaseSampler | Sampler to use for sampling. | required |

Methods:

| Name | Description |
|---|---|
| energy | Compute the energy of the input. |
| sample | Sample from the model. |
| train_step | Perform a single training step. |
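Because BaseModel derives from ABC and marks all three methods abstract, it cannot be instantiated directly; a concrete model must override energy, sample, and train_step. The sketch below uses a local copy of the interface so it runs on its own. ToyModel, IdentitySampler, and the dim attribute are hypothetical, and the plain lambda used as the energy function stands in for torchebm's EnergyFunction; none of these reflect the library's real classes.

```python
from abc import ABC, abstractmethod

import torch


class BaseModel(ABC):
    # Local copy of the interface documented above, so this sketch runs on its own.
    def __init__(self, energy_function, sampler):
        self.energy_function = energy_function
        self.sampler = sampler

    @abstractmethod
    def energy(self, x: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def sample(self, num_samples: int) -> torch.Tensor: ...

    @abstractmethod
    def train_step(self, real_data: torch.Tensor) -> dict: ...


class IdentitySampler:
    # Hypothetical placeholder sampler: returns its starting points unchanged.
    def sample(self, x: torch.Tensor) -> torch.Tensor:
        return x


class ToyModel(BaseModel):
    # Hypothetical subclass; `dim` is an assumed attribute fixing the sample shape.
    dim = 2

    def energy(self, x):
        # Delegate to the stored energy function: one value per row of x.
        return self.energy_function(x)

    def sample(self, num_samples):
        # Start from Gaussian noise and hand the points to the stored sampler.
        return self.sampler.sample(torch.randn(num_samples, self.dim))

    def train_step(self, real_data):
        # Placeholder: report the mean energy of the batch; nothing is updated.
        return {"mean_energy": self.energy(real_data).mean().item()}


try:
    BaseModel(None, None)
except TypeError as err:
    print("BaseModel is abstract:", err)

model = ToyModel(lambda x: 0.5 * (x ** 2).sum(dim=-1), IdentitySampler())
print(model.energy(torch.tensor([[1.0, 2.0]])))  # tensor([2.5000])
print(model.sample(4).shape)  # torch.Size([4, 2])
```

Leaving out any one of the three overrides makes ToyModel abstract as well, so the missing method is reported at construction time rather than at the first call.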

Source code in torchebm/models/base_model.py
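```python
from abc import ABC, abstractmethod

import torch

# EnergyFunction and BaseSampler are torchebm classes defined elsewhere in the
# package; their import lines are not part of this excerpt.


class BaseModel(ABC):
    """
    Base class for models.

    Args:
        energy_function (EnergyFunction): Energy function to sample from.
        sampler (BaseSampler): Sampler to use for sampling.

    Methods:
        energy(x): Compute the energy of the input.
        sample(num_samples): Sample from the model.
        train_step(real_data): Perform a single training step.
    """

    def __init__(self, energy_function: EnergyFunction, sampler: BaseSampler):
        # Keep both components so that subclasses can use them.
        self.energy_function = energy_function
        self.sampler = sampler

    @abstractmethod
    def energy(self, x: torch.Tensor) -> torch.Tensor:
        # Subclasses return the energy of the input x.
        pass

    @abstractmethod
    def sample(self, num_samples: int) -> torch.Tensor:
        # Subclasses return num_samples samples drawn from the model.
        pass

    @abstractmethod
    def train_step(self, real_data: torch.Tensor) -> dict:
        # Subclasses perform one training step on real_data and return a dict.
        pass
```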

energy_function (instance attribute)

The energy_function argument passed to the constructor.

sampler (instance attribute)

The sampler argument passed to the constructor.

energy (abstract method)

energy(x: torch.Tensor) -> torch.Tensor

Compute the energy of the input.
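In an energy-based model the energy is conventionally read as an unnormalized negative log-density, p(x) ∝ exp(-E(x)), so lower energy means higher probability. The normalizing constant is usually intractable, but it cancels in ratios of densities. The sketch assumes a hypothetical quadratic energy E(x) = 0.5 * ||x||^2 that returns one value per row; density_ratio is likewise a hypothetical helper, not part of torchebm.

```python
import torch


def quadratic_energy(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical example energy: E(x) = 0.5 * ||x||^2, one value per row of x.
    return 0.5 * (x ** 2).sum(dim=-1)


def density_ratio(energy_fn, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # p(a) / p(b) = exp(-E(a)) / exp(-E(b)) = exp(E(b) - E(a)); the
    # normalizing constant cancels, so it never has to be computed.
    return torch.exp(energy_fn(b) - energy_fn(a))


a = torch.tensor([[0.0, 0.0]])  # E(a) = 0.0
b = torch.tensor([[1.0, 1.0]])  # E(b) = 1.0
print(density_ratio(quadratic_energy, a, b))  # tensor([2.7183]): density at a is e times that at b
```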

sample (abstract method)

sample(num_samples: int) -> torch.Tensor

Sample from the model.
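Since sample receives only a count, an implementation has to pick its own starting points before running a sampler. One common choice, assumed here rather than taken from torchebm, is to draw Gaussian noise and refine it with unadjusted Langevin dynamics: x ← x - η∇E(x) + sqrt(2η)·ξ with ξ ~ N(0, I). The function langevin_sample and its parameters dim, n_steps, step_size, and init_scale are hypothetical.

```python
import torch


def langevin_sample(energy_fn, num_samples: int, dim: int,
                    n_steps: int = 100, step_size: float = 0.1,
                    init_scale: float = 3.0) -> torch.Tensor:
    # Start from broad Gaussian noise; the chain should forget this initialization.
    x = init_scale * torch.randn(num_samples, dim)
    for _ in range(n_steps):
        x = x.detach().requires_grad_(True)
        # Gradient of the summed energy gives every sample its own gradient.
        grad = torch.autograd.grad(energy_fn(x).sum(), x)[0]
        # Unadjusted Langevin step: drift down the energy plus injected noise.
        x = x - step_size * grad + (2 * step_size) ** 0.5 * torch.randn_like(x)
    return x.detach()


def quadratic_energy(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical example energy whose target density is a standard Gaussian.
    return 0.5 * (x ** 2).sum(dim=-1)


torch.manual_seed(0)
samples = langevin_sample(quadratic_energy, num_samples=2000, dim=2)
print(samples.shape)  # torch.Size([2000, 2])
# Mean near 0 and variance near 1 (slightly above, from the finite step size).
print(samples.mean().item(), samples.var().item())
```

Given the constructor's sampler argument, a real subclass would more likely delegate this loop to self.sampler; the standalone function only makes the mechanics visible.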

train_step (abstract method)

train_step(real_data: torch.Tensor) -> dict

Perform a single training step.
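The training rule behind train_step is left to subclasses. A common choice for energy-based models, shown here as an assumption and not as torchebm's own method, is a contrastive-divergence-style update: minimize E(real).mean() - E(fake).mean(), which lowers the energy of the data and raises the energy of samples drawn from the model. Here the negative samples come from k Langevin steps started at the real batch (CD-k). ShiftedQuadraticEnergy, cd_train_step, and their parameters are hypothetical.

```python
import torch
from torch import nn


class ShiftedQuadraticEnergy(nn.Module):
    # Hypothetical learnable energy E(x) = 0.5 * ||x - mu||^2 with trainable mu.
    def __init__(self, dim: int):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * ((x - self.mu) ** 2).sum(dim=-1)


def cd_train_step(energy_fn, optimizer, real_data: torch.Tensor,
                  k: int = 10, step_size: float = 0.1) -> dict:
    # Negative samples: k Langevin steps started at the real batch (CD-k).
    x = real_data.clone()
    for _ in range(k):
        x = x.detach().requires_grad_(True)
        grad = torch.autograd.grad(energy_fn(x).sum(), x)[0]
        x = x - step_size * grad + (2 * step_size) ** 0.5 * torch.randn_like(x)
    fake = x.detach()  # no gradient flows back through the sampling chain

    # Push energy down on data and up on model samples.
    loss = energy_fn(real_data).mean() - energy_fn(fake).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {"loss": loss.item()}


torch.manual_seed(0)
energy_fn = ShiftedQuadraticEnergy(dim=1)
optimizer = torch.optim.SGD(energy_fn.parameters(), lr=0.5)
real_data = 2.0 + torch.randn(512, 1)  # data centered at 2
for _ in range(50):
    metrics = cd_train_step(energy_fn, optimizer, real_data)
print(energy_fn.mu.item())  # should end up close to the data mean (about 2)
```

Returning a dict of plain Python numbers, as this sketch does, keeps the result easy to log; what a real train_step puts in its dict is not specified by the interface.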