Skip to content

FeedForward

Methods and Attributes

Bases: torch.nn.Module

Source code in torchebm/models/components/transformer.py
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU(tanh) -> Dropout -> Linear.

    The input's last dimension is expanded from ``embed_dim`` to
    ``int(embed_dim * mlp_ratio)`` features and then projected back to
    ``embed_dim``.

    Args:
        embed_dim: Feature size of the input's last dimension and of the output.
        mlp_ratio: Multiplier for the hidden width. The product is truncated
            with ``int()``.
        dropout: Dropout probability applied after the activation only.
    """

    def __init__(self, embed_dim: int, mlp_ratio: float = 4.0, dropout: float = 0.0):
        super().__init__()
        hidden_width = int(embed_dim * mlp_ratio)
        # The order below fixes the submodule indices (net.0 .. net.3) and thus
        # the state_dict keys; it also fixes parameter-initialisation order.
        layers = [
            nn.Linear(embed_dim, hidden_width, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, embed_dim, bias=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP along the last dimension of ``x`` (expected size ``embed_dim``)."""
        mlp = self.net
        return mlp(x)

net instance-attribute

net = nn.Sequential(nn.Linear(embed_dim, hidden, bias=True), nn.GELU(approximate='tanh'), nn.Dropout(dropout), nn.Linear(hidden, embed_dim, bias=True)), where hidden = int(embed_dim * mlp_ratio)

forward

forward(x: torch.Tensor) -> torch.Tensor
Source code in torchebm/models/components/transformer.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run ``x`` through the ``self.net`` module and return its output."""
    mlp = self.net
    return mlp(x)