import inspect
from argparse import Namespace
import warnings

import torch
from torch import nn
from einops import rearrange, repeat

from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from util.module_mixins import CombinedModelMixins

# Minimum number of image patches accepted; VisualPerformer.__init__ rejects smaller patch grids.
MIN_NUM_PATCHES = 16

try:
    from performer_pytorch import Performer

    class VisualPerformer(CombinedModelMixins,
                          LightningBaseModule
                          ):
        """Vision-Transformer-style image classifier using a ``performer_pytorch.Performer`` encoder.

        The (auto-padded, square) image is cut into ``patch_size x patch_size`` patches. Each
        flattened patch is projected to ``embed_dim``. A learnable class token is prepended and
        learnable positional embeddings are added. After the Performer encoder, only the class
        token is passed through an MLP head that ends in a softmax over ``n_classes``.
        """

        def __init__(self, in_shape, n_classes, weight_init, activation,
                     embedding_size, heads, attn_depth, patch_size, use_residual,
                     use_bias, use_norm, dropout, lat_dim, loss, scheduler,
                     lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
            """Build the patch embedding, Performer encoder and classification head.

            :param in_shape: input shape, unpacked as ``(channels, height, width)``.
            :param n_classes: number of output classes of the head.
            :param activation: callable; ``activation()`` is passed to the Performer as ``kernel_fn``.
            :param embedding_size: requested token dimension; rounded down to a multiple of ``heads``.
                If falsy, no linear patch projection is created (``F_x`` is used instead).
            :param heads: number of attention heads; ``heads // 2`` of them are local-attention heads.
            :param attn_depth: number of Performer layers.
            :param patch_size: side length of the square image patches.
            :param dropout: dropout rate for the input tokens, Performer ff/attn and the MLP head.
            :param lat_dim: hidden width of the MLP head.

            All arguments are collected into a dict that is handed to the parent constructor. The
            attribute accesses below assume the parent exposes that dict as ``self.params``. The
            arguments not listed above are not read directly in this class.
            """
            # TODO: Move this to parent class, or make it much easier to access... But How...
            a = dict(locals())
            params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
            super(VisualPerformer, self).__init__(params)

            self.in_shape = in_shape
            assert len(self.in_shape) == 3, 'There need to be three Dimensions'
            channels, height, width = self.in_shape

            # Model Parameters
            # =============================================================================
            # Additional parameters
            self.embed_dim = self.params.embedding_size

            # Automatic Image Shaping: round the longer image side up to a multiple of patch_size.
            self.patch_size = self.params.patch_size
            image_size = (max(height, width) // self.patch_size) * self.patch_size
            self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size

            # This should be obsolete (the rounding above already guarantees it)
            assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'

            num_patches = (self.image_size // self.patch_size) ** 2
            patch_dim = channels * self.patch_size ** 2
            assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
                                                   f'attention. Try decreasing your patch size'

            # Correct the Embedding Dim so that it splits evenly across the attention heads
            if not self.embed_dim % self.params.heads == 0:
                self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
                message = ('Embedding Dimension was fixed to be devideable by the number' +
                           f' of attention heads, is now: {self.embed_dim}')
                for func in print, warnings.warn:
                    func(message)

            # Utility Modules
            self.autopad = AutoPadToShape((self.image_size, self.image_size))

            # Modules with Parameters
            self.performer = Performer(
                dim=self.embed_dim,  # dimension
                depth=self.params.attn_depth,  # layers
                heads=self.params.heads,  # heads
                # Must be non-causal: the class token sits at position 0 and is the only token
                # read by the head. With a causal mask it could only ever attend to itself, so
                # the prediction would not depend on the image at all.
                causal=False,  # auto-regressive or not
                nb_features=None,  # 256,  # number of random features, if not set, will default to
                # (d * log(d)), where d is the dimension of each head
                feature_redraw_interval=1000,  # how frequently to redraw the projection matrix,
                # the more frequent, the slower the training
                generalized_attention=False,  # defaults to softmax approximation,
                # but can be set to True for generalized attention
                kernel_fn=self.params.activation(),  # the kernel function to be used,
                # if generalized attention is turned on, defaults to Relu
                reversible=True,  # reversible layers, from Reformer paper
                ff_chunks=10,  # chunk feedforward layer, from Reformer paper
                use_scalenorm=False,  # use scale norm, from 'Transformers without Tears' paper
                use_rezero=False,  # use rezero, from 'Rezero is all you need' paper
                ff_glu=True,  # use GLU variant for feedforward
                ff_dropout=self.params.dropout,  # feedforward dropout
                attn_dropout=self.params.dropout,  # post-attn dropout
                local_attn_heads=self.params.heads // 2,  # half of the heads use local attention, the rest are global
                local_window_size=(patch_dim // self.params.heads) * 2  # window size of local attention
            )

            # One positional embedding per patch plus one for the class token.
            self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
            # NOTE(review): when embedding_size is falsy, F_x is used instead of a projection;
            # presumably F_x is identity-like - confirm its output width matches self.embed_dim.
            self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
                else F_x(self.embed_dim)
            self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
            self.dropout = nn.Dropout(self.params.dropout)

            self.to_cls_token = nn.Identity()

            self.mlp_head = nn.Sequential(
                nn.LayerNorm(self.embed_dim),
                nn.Linear(self.embed_dim, self.params.lat_dim),
                nn.GELU(),
                nn.Dropout(self.params.dropout),
                nn.Linear(self.params.lat_dim, n_classes),
                # The head input is (batch, features); give the softmax dim explicitly
                # (the implicit-dim form is deprecated and warns).
                nn.Softmax(dim=-1)
            )

        def forward(self, x):
            """Classify a batch of images.

            :param x: image batch; after auto-padding it must match the einops pattern
                ``b c (h p1) (w p2)``, i.e. (batch, channels, height, width).
            :return: ``Namespace(main_out=...)`` holding softmax class probabilities of
                shape (batch, n_classes).
            """
            tensor = self.autopad(x)
            p = self.params.patch_size
            # Split into non-overlapping p x p patches and flatten each: (b, num_patches, patch_dim)
            tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

            tensor = self.patch_to_embedding(tensor)
            b, n, _ = tensor.shape

            # Prepend one class token per sample, then add positional embeddings
            cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
            tensor = torch.cat((cls_tokens, tensor), dim=1)
            tensor += self.pos_embedding[:, :(n + 1)]
            tensor = self.dropout(tensor)

            tensor = self.performer(tensor)

            # Classify from the encoded class token only
            tensor = self.to_cls_token(tensor[:, 0])
            tensor = self.mlp_head(tensor)
            return Namespace(main_out=tensor)

except ImportError:  # pragma: do not provide model class
    print('You want to use `performer_pytorch` plugins which are not installed yet,'  # pragma: no-cover
          ' install it with `pip install performer_pytorch`.')