import inspect
from argparse import Namespace

import torch
from torch import nn
from einops import rearrange, repeat

from ml_lib.metrics.binary_class_classifictaion import BinaryScores
from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import (TransformerModule, F_x)
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
from util.module_mixins import CombinedModelMixins

MIN_NUM_PATCHES = 16


class VisualTransformer(CombinedModelMixins, LightningBaseModule):

    def __init__(self, in_shape, n_classes, weight_init, activation,
                 embedding_size, heads, attn_depth, patch_size, use_residual, variable_length,
                 use_bias, use_norm, dropout, lat_dim, loss, scheduler, mlp_dim, head_dim,
                 lr, weight_decay, sto_weight_avg, lr_scheduler_parameter, opt_reset_interval,
                 return_logits=False):
        # TODO: Move this to the parent class, or make it much easier to access... But how?
        a = dict(locals())
        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
        super(VisualTransformer, self).__init__(params)

        self.in_shape = in_shape
        assert len(self.in_shape) == 3, 'There need to be three dimensions (channels, height, width)'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.embed_dim = self.params.embedding_size

        # Automatic image shaping: round height and width up to the next multiple of the patch size.
        self.patch_size = self.params.patch_size
        new_height = (height // self.patch_size) * self.patch_size
        new_height = new_height + self.patch_size if new_height < height else new_height
        new_width = (width // self.patch_size) * self.patch_size
        new_width = new_width + self.patch_size if new_width < width else new_width

        num_patches = (new_height // self.patch_size) * (new_width // self.patch_size)
        patch_dim = channels * self.patch_size ** 2
        assert num_patches >= MIN_NUM_PATCHES, (f'Your number of patches ({num_patches}) is way too small for '
                                                f'attention. Try decreasing your patch size.')

        # Utility Modules
        self.autopad = AutoPadToShape((new_height, new_width))

        # Modules with Parameters
        self.transformer = TransformerModule(in_shape=self.embed_dim, mlp_dim=self.params.mlp_dim,
                                             head_dim=self.params.head_dim,
                                             heads=self.params.heads, depth=self.params.attn_depth,
                                             dropout=self.params.dropout, use_norm=self.params.use_norm,
                                             activation=self.params.activation,
                                             use_residual=self.params.use_residual
                                             )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.dropout = nn.Dropout(self.params.dropout)
        self.to_cls_token = nn.Identity()

        # Binary classification uses a single logit; multi-class uses one logit per class.
        logits = self.params.n_classes if self.params.n_classes > 2 else 1

        if return_logits:
            outbound_activation = nn.Identity()
        else:
            outbound_activation = nn.Softmax(dim=-1) if logits > 1 else nn.Sigmoid()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            self.params.activation(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, logits),
            outbound_activation
        )

    def forward(self, x, mask=None, return_attn_weights=False):
        """
        :param x: the sequence to the encoder (required).
        :param mask: the mask for the src sequence (optional).
        :return: a Namespace holding ``main_out`` (class probabilities or logits) and ``attn_weights``
                 (None unless ``return_attn_weights`` is set).
        """
        tensor = self.autopad(x)
        p = self.params.patch_size
        # Flatten the image into a sequence of (patch_size * patch_size * channels) patch vectors.
        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        b, n, _ = tensor.shape

        if self.params.variable_length and mask is None:
            # Mask out patches that are entirely zero (padding).
            lengths = torch.count_nonzero(tensor, dim=-1)
            mask = (lengths == torch.zeros_like(lengths))
            # CLS-token awareness
            # mask = torch.cat((torch.zeros(b, 1), mask), dim=-1)
            # mask = repeat(mask, 'b n -> b h n', h=self.params.heads)
        # otherwise keep the mask passed in by the caller

        tensor = self.patch_to_embedding(tensor)

        # Prepend the learnable CLS token and add positional embeddings.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        if return_attn_weights:
            tensor, attn_weights = self.transformer(tensor, mask, return_attn_weights)
        else:
            attn_weights = None
            tensor = self.transformer(tensor, mask)

        # Classify from the CLS token only.
        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor, attn_weights=attn_weights)

    def additional_scores(self, outputs):
        if self.params.n_classes <= 2:
            return BinaryScores(self)(outputs)
        else:
            return MultiClassScores(self)(outputs)
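

if __name__ == '__main__':
    # Minimal sketch, not part of the original module: it only illustrates the einops
    # patch flattening used in `forward` above, assuming a dummy 3-channel 32x32 batch
    # and a hypothetical patch size of 8.
    dummy = torch.randn(2, 3, 32, 32)  # (batch, channels, height, width)
    patches = rearrange(dummy, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=8, p2=8)
    print(patches.shape)  # torch.Size([2, 16, 192]): 4x4 = 16 patches, 8*8*3 = 192 values each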