import variables as V

import torch
from torch import nn

from ml_lib.modules.blocks import TransformerModule
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
from util.module_mixins import (BaseOptimizerMixin, BaseTrainMixin, BaseValMixin, BinaryMaskDatasetMixin,
                                BaseDataloadersMixin)

# The patch count must be strictly greater than this (checked in VisualTransformer.__init__).
MIN_NUM_PATCHES = 16


class VisualTransformer(BinaryMaskDatasetMixin,
                        BaseDataloadersMixin,
                        BaseTrainMixin,
                        BaseValMixin,
                        BaseOptimizerMixin,
                        LightningBaseModule
                        ):
    """Vision-Transformer style classifier.

    Pipeline:
      1. The input is padded (``AutoPadToShape``) to a square whose side is the dataset's longer
         image side rounded up to a multiple of ``patch_size``.
      2. The padded image is cut into non-overlapping ``patch_size`` x ``patch_size`` patches,
         and each patch is linearly embedded.
      3. A class token is prepended and positional embeddings are added.
      4. The sequence goes through ``TransformerModule``.
      5. The class-token output is mapped by an MLP head to ``V.NUM_CLASSES`` outputs.

    Hyperparameters read from ``self.params``: ``patch_size``, ``features``, ``dropout``,
    ``attn_depth``, ``heads``, ``lat_dim``. (``self.params`` is presumably populated from
    ``hparams`` by ``LightningBaseModule`` — not visible here.)
    """

    def __init__(self, hparams):
        super(VisualTransformer, self).__init__(hparams)

        # Dataset
        # =============================================================================
        # Must be built before anything below reads ``self.dataset``.
        self.dataset = self.build_dataset()

        self.in_shape = self.dataset.train_dataset.sample_shape
        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
        channels, height, width = self.in_shape

        # Automatic Image Shaping
        # Round the longer side up to the next multiple of the patch size. Inputs are padded to a
        # square of this side in ``forward``, so the image always splits evenly into patches.
        patch_size = self.params.patch_size
        longest_side = max(height, width)
        self.image_size = ((longest_side + patch_size - 1) // patch_size) * patch_size

        num_patches = (self.image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' \
                                              f'attention. Try decreasing your patch size'

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.attention_dim = self.params.features

        # Utility Modules
        self.autopad = AutoPadToShape((self.image_size, self.image_size))

        # Modules with Parameters
        # One positional embedding per patch plus one for the class token.
        # NOTE(review): pos_embedding and cls_token are created with requires_grad=False, so they stay
        # at their random initialisation during training — confirm this is intended.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.attention_dim), False)
        self.embedding = nn.Linear(patch_dim, self.attention_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.attention_dim), False)
        self.dropout = nn.Dropout(self.params.dropout)

        self.transformer = TransformerModule(self.attention_dim, self.params.attn_depth, self.params.heads,
                                             self.params.lat_dim, self.params.dropout)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.attention_dim),
            nn.Linear(self.attention_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, V.NUM_CLASSES)
        )

    def _to_patches(self, tensor):
        """Rearrange ``'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'`` with ``p1 = p2 = patch_size``.

        :param tensor: 4-D tensor (batch, channels, height, width); height and width must be
            multiples of ``patch_size`` (otherwise ``reshape`` raises).
        :return: tensor (batch, num_patches, patch_size * patch_size * channels).
        """
        p = self.params.patch_size
        b, c, height, width = tensor.shape
        h, w = height // p, width // p
        tensor = tensor.reshape(b, c, h, p, w, p)
        tensor = tensor.permute(0, 2, 4, 3, 5, 1)  # -> b, h, w, p1, p2, c
        return tensor.reshape(b, h * w, p * p * c)

    def forward(self, x, mask=None):
        """Classify a batch of images.

        :param x: input batch, assumed (batch, channels, height, width) with per-sample shape
            ``self.in_shape`` — TODO confirm against the dataloader.
        :param mask: the mask for the src sequence (optional); passed unchanged to the transformer.
            The sequence seen by the transformer has one leading class-token position before the patches.
        :return: output of ``mlp_head`` for the class token, last dimension ``V.NUM_CLASSES``
            (assuming ``TransformerModule`` keeps the (batch, seq, dim) layout).
        """
        # Pad to (image_size, image_size) so the patch grid matches ``pos_embedding``.
        tensor = self._to_patches(self.autopad(x))
        tensor = self.embedding(tensor)
        b, n, _ = tensor.shape

        # '() n d -> b n d': one class token per sample.
        cls_tokens = self.cls_token.expand(b, -1, -1)
        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        tensor = self.transformer(tensor, mask)

        # Classify from the class-token position only.
        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return tensor