from argparse import Namespace
import warnings

import torch
from torch import nn

from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x, SlidingWindow)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, DatasetMixin,
                                BaseDataloadersMixin, BaseTestMixin)

# Lower bound on the number of patches; __init__ refuses configurations that yield fewer,
# since attention over so few tokens is considered too small to be useful.
MIN_NUM_PATCHES = 16


class VerticalVisualTransformer(DatasetMixin,
                                BaseDataloadersMixin,
                                BaseTrainMixin,
                                BaseValMixin,
                                BaseTestMixin,
                                BaseOptimizerMixin,
                                LightningBaseModule
                                ):
    """ViT-style model over full-height ("vertical") patches with a single sigmoid output.

    Each patch window spans the full input height and ``patch_size`` columns. Patches are
    linearly embedded, prefixed with a learnable class token, given learned positional
    embeddings and run through a ``TransformerModule``. The class-token output is then mapped
    by an MLP head to one value in ``(0, 1)``.
    """

    def __init__(self, hparams):
        """Build the dataset and all sub-modules from ``hparams``.

        :param hparams: hyper-parameter container handed to ``LightningBaseModule``; the fields
            used here are read through ``self.params`` (presumably populated by the base class
            from ``hparams`` -- TODO confirm).
        """
        super().__init__(hparams)

        # Dataset
        # =============================================================================
        self.dataset = self.build_dataset()

        self.in_shape = self.dataset.train_dataset.sample_shape
        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.embed_dim = self.params.embedding_size
        self.patch_size = self.params.patch_size
        self.height = height
        self.width = width
        self.channels = channels

        # Count of start positions for a stride-1 window of width `patch_size` along the width.
        self.new_width = ((self.width - self.patch_size)//1) + 1

        # NOTE(review): pos_embedding below has num_patches + 1 rows and forward() slices it to
        # n + 1, so this must be >= the number of windows the slider actually yields, or the
        # positional add will fail to broadcast. Confirm against SlidingWindow's padding.
        num_patches = self.new_width - (self.patch_size // 2)
        # Each patch covers every channel, the full height and `patch_size` columns.
        patch_dim = channels * self.patch_size * self.height

        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                               f'attention. Try decreasing your patch size'

        # Correct the Embedding Dim so multi-head attention can split it evenly across heads.
        if not self.embed_dim % self.params.heads == 0:
            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
            message = ('Embedding Dimension was fixed to be divisible by the number' +
                       f' of attention heads, is now: {self.embed_dim}')
            # Deliberately both printed and raised as a warning.
            for func in print, warnings.warn:
                func(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.height, self.new_width))
        self.dropout = nn.Dropout(self.params.dropout)
        # Window of (full height, patch_size); keepdim=False semantics are defined by
        # SlidingWindow -- presumably flattens each window into one patch vector (TODO confirm).
        self.slider = SlidingWindow((channels, *self.autopad.target_shape), (self.height, self.patch_size),
                                    keepdim=False)

        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, hidden_size=self.params.lat_dim,
                                             n_heads=self.params.heads, num_layers=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation_as_string
                                             )

        # One learned positional vector per patch plus one for the prepended class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
            else F_x(self.embed_dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x, mask=None):
        """Run patching, the transformer and the MLP head on a batch.

        :param x: the sequence to the encoder (required); presumably a batch of samples shaped
            like ``self.in_shape`` = ``(channels, height, width)`` -- TODO confirm.
        :param mask: the mask for the src sequence (optional); passed unchanged to
            ``self.transformer``.
        :return: ``Namespace(main_out=...)``; ``main_out`` is the sigmoid output of the MLP
            head for the class token, i.e. shape ``(batch, 1)`` with values in ``(0, 1)``
            assuming the transformer keeps the batch-first layout built here.
        """
        tensor = self.autopad(x)
        tensor = self.slider(tensor)
        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        # Broadcast the single learned class token across the batch. expand() creates a view
        # (no copy); torch.cat below materializes the combined tensor anyway.
        cls_tokens = self.cls_token.expand(b, -1, -1)
        tensor = torch.cat((cls_tokens, tensor), dim=1)
        # Add positional embeddings for the class token plus the n patches actually present.
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        tensor = self.transformer(tensor, mask)

        # Classify from the class-token position only.
        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor)