From 2c9cb2e94a21634edd6a8630f53283060b0d7557 Mon Sep 17 00:00:00 2001
From: Steffen <steffen.illium@ifi.lmu.de>
Date: Thu, 18 Mar 2021 12:12:43 +0100
Subject: [PATCH] Fix random seeding and make performer_pytorch an optional dependency

---
 main.py             |  10 ++-
 models/performer.py | 197 +++++++++++++++++++++++---------------------
 2 files changed, 108 insertions(+), 99 deletions(-)
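Usage note: run_lightning_loop() now takes an explicit `seed` argument (default 69, see the main.py diff below), and parse_comandline_args_add_defaults() is expected to return that seed as a fourth value. A minimal calling sketch under exactly those assumptions, mirroring what __main__ does:

    from argparse import Namespace

    from main import run_lightning_loop
    from ml_lib.utils.config import parse_comandline_args_add_defaults

    # Parse the config the same way __main__ does, then forward the seed explicitly
    # so repeated runs reproduce the same initialisation and data shuffling.
    cmd_args, data_class, model_class, seed = parse_comandline_args_add_defaults('_parameters.ini')
    run_lightning_loop(Namespace(**cmd_args), data_class, model_class, seed=seed)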

diff --git a/main.py b/main.py
index 560827a..85cd039 100644
--- a/main.py
+++ b/main.py
@@ -10,12 +10,16 @@ from ml_lib.utils.config import parse_comandline_args_add_defaults
 from ml_lib.utils.loggers import Logger
 
 import variables as v
+from ml_lib.utils.tools import fix_all_random_seeds
 
 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
 
 
-def run_lightning_loop(h_params, data_class, model_class, additional_callbacks=None):
+def run_lightning_loop(h_params, data_class, model_class, seed=69, additional_callbacks=None):
+
+    fix_all_random_seeds(seed)
+
     with Logger.from_argparse_args(h_params) as logger:
         # Callbacks
         # =============================================================================
@@ -79,13 +83,13 @@ def run_lightning_loop(h_params, data_class, model_class, additional_callbacks=N
 
 if __name__ == '__main__':
     # Parse comandline args, read config and get model
-    cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults('_parameters.ini')
+    cmd_args, found_data_class, found_model_class, found_seed = parse_comandline_args_add_defaults('_parameters.ini')
 
     # To NameSpace
     hparams = Namespace(**cmd_args)
 
     # Start
     # -----------------
-    run_lightning_loop(hparams, found_data_class, found_model_class)
+    run_lightning_loop(hparams, found_data_class, found_model_class, found_seed)
     print('done')
     pass
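Note: fix_all_random_seeds() is imported from ml_lib.utils.tools and is not part of this patch. A minimal sketch of what such a helper is assumed to do; the actual implementation in ml_lib may differ:

    import random

    import numpy as np
    import torch

    def fix_all_random_seeds(seed: int) -> None:
        # Seed every RNG the training loop can touch so runs are repeatable.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)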
diff --git a/models/performer.py b/models/performer.py
index 88f525b..2d45112 100644
--- a/models/performer.py
+++ b/models/performer.py
@@ -4,7 +4,6 @@ from argparse import Namespace
 import warnings
 
 import torch
-from performer_pytorch import Performer
 from torch import nn
 
 from einops import rearrange, repeat
@@ -15,119 +14,125 @@ from util.module_mixins import CombinedModelMixins
 
 MIN_NUM_PATCHES = 16
 
+try:
+    from performer_pytorch import Performer
 
-class VisualPerformer(CombinedModelMixins,
-                      LightningBaseModule
-                      ):
+    class VisualPerformer(CombinedModelMixins,
+                          LightningBaseModule
+                          ):
 
-    def __init__(self, in_shape, n_classes, weight_init, activation,
-                 embedding_size, heads, attn_depth, patch_size, use_residual,
-                 use_bias, use_norm, dropout, lat_dim, loss, scheduler,
-                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
+        def __init__(self, in_shape, n_classes, weight_init, activation,
+                     embedding_size, heads, attn_depth, patch_size, use_residual,
+                     use_bias, use_norm, dropout, lat_dim, loss, scheduler,
+                     lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):
 
-        # TODO: Move this to parent class, or make it much easieer to access... But How...
-        a = dict(locals())
-        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
-        super(VisualPerformer, self).__init__(params)
+            # TODO: Move this to the parent class, or make it much easier to access. But how?
+            a = dict(locals())
+            params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
+            super(VisualPerformer, self).__init__(params)
 
-        self.in_shape = in_shape
-        assert len(self.in_shape) == 3, 'There need to be three Dimensions'
-        channels, height, width = self.in_shape
+            self.in_shape = in_shape
+            assert len(self.in_shape) == 3, 'in_shape must have exactly three dimensions: (channels, height, width)'
+            channels, height, width = self.in_shape
 
-        # Model Paramters
-        # =============================================================================
-        # Additional parameters
-        self.embed_dim = self.params.embedding_size
+            # Model Parameters
+            # =============================================================================
+            # Additional parameters
+            self.embed_dim = self.params.embedding_size
 
-        # Automatic Image Shaping
-        self.patch_size = self.params.patch_size
-        image_size = (max(height, width) // self.patch_size) * self.patch_size
-        self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
+            # Automatic Image Shaping
+            self.patch_size = self.params.patch_size
+            image_size = (max(height, width) // self.patch_size) * self.patch_size
+            self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
 
-        # This should be obsolete
-        assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
+            # This should be redundant after the automatic image shaping above
+            assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'
 
-        num_patches = (self.image_size // self.patch_size) ** 2
-        patch_dim = channels * self.patch_size ** 2
-        assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
-                                               f'attention. Try decreasing your patch size'
+            num_patches = (self.image_size // self.patch_size) ** 2
+            patch_dim = channels * self.patch_size ** 2
+            assert num_patches >= MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for ' + \
+                                                   f'attention. Try decreasing your patch size'
 
-        # Correct the Embedding Dim
-        if not self.embed_dim % self.params.heads == 0:
-            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
-            message = ('Embedding Dimension was fixed to be devideable by the number' +
-                       f' of attention heads, is now: {self.embed_dim}')
-            for func in print, warnings.warn:
-                func(message)
+            # Correct the Embedding Dim
+            if self.embed_dim % self.params.heads != 0:
+                self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
+                message = ('Embedding dimension was adjusted to be divisible by the number' +
+                           f' of attention heads; it is now: {self.embed_dim}')
+                for func in print, warnings.warn:
+                    func(message)
 
-        # Utility Modules
-        self.autopad = AutoPadToShape((self.image_size, self.image_size))
+            # Utility Modules
+            self.autopad = AutoPadToShape((self.image_size, self.image_size))
 
-        # Modules with Parameters
-        self.performer = Performer(
-            dim=self.embed_dim,                         # dimension
-            depth=self.params.attn_depth,               # layers
-            heads=self.params.heads,                    # heads
-            causal=True,                                # auto-regressive or not
-            nb_features=None,  # 256,                   # number of random features, if not set, will default to
-                                                        # (d * log(d)), where d is the dimension of each head
-            feature_redraw_interval=1000,               # how frequently to redraw the projection matrix,
-                                                        # the more frequent, the slower the training
-            generalized_attention=False,                # defaults to softmax approximation,
-                                                        # but can be set to True for generalized attention
-            kernel_fn=self.params.activation(),         # the kernel function to be used,
-                                                        # if generalized attention is turned on, defaults to Relu
-            reversible=True,                            # reversible layers, from Reformer paper
-            ff_chunks=10,                               # chunk feedforward layer, from Reformer paper
-            use_scalenorm=False,                        # use scale norm, from 'Transformers without Tears' paper
-            use_rezero=False,                           # use rezero, from 'Rezero is all you need' paper
-            ff_glu=True,                                # use GLU variant for feedforward
-            ff_dropout=self.params.dropout,             # feedforward dropout
-            attn_dropout=self.params.dropout,           # post-attn dropout
-            local_attn_heads=self.params.heads // 2,    # 4 heads are local attention, 4 others are global performers
-            local_window_size=(patch_dim // self.params.heads) * 2    # window size of local attention
-        )
+            # Modules with Parameters
+            self.performer = Performer(
+                dim=self.embed_dim,                         # dimension
+                depth=self.params.attn_depth,               # layers
+                heads=self.params.heads,                    # heads
+                causal=True,                                # auto-regressive or not
+                nb_features=None,  # 256,                   # number of random features, if not set, will default to
+                                                            # (d * log(d)), where d is the dimension of each head
+                feature_redraw_interval=1000,               # how frequently to redraw the projection matrix,
+                                                            # the more frequent, the slower the training
+                generalized_attention=False,                # defaults to softmax approximation,
+                                                            # but can be set to True for generalized attention
+                kernel_fn=self.params.activation(),         # the kernel function to be used,
+                                                            # if generalized attention is turned on, defaults to Relu
+                reversible=True,                            # reversible layers, from Reformer paper
+                ff_chunks=10,                               # chunk feedforward layer, from Reformer paper
+                use_scalenorm=False,                        # use scale norm, from 'Transformers without Tears' paper
+                use_rezero=False,                           # use rezero, from 'Rezero is all you need' paper
+                ff_glu=True,                                # use GLU variant for feedforward
+                ff_dropout=self.params.dropout,             # feedforward dropout
+                attn_dropout=self.params.dropout,           # post-attn dropout
+                local_attn_heads=self.params.heads // 2,    # half of the heads use local attention, the rest are global performers
+                local_window_size=(patch_dim // self.params.heads) * 2    # window size of local attention
+            )
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
-        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
-            else F_x(self.embed_dim)
-        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
-        self.dropout = nn.Dropout(self.params.dropout)
+            self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
+            self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
+                else F_x(self.embed_dim)
+            self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
+            self.dropout = nn.Dropout(self.params.dropout)
 
-        self.to_cls_token = nn.Identity()
+            self.to_cls_token = nn.Identity()
 
-        self.mlp_head = nn.Sequential(
-            nn.LayerNorm(self.embed_dim),
-            nn.Linear(self.embed_dim, self.params.lat_dim),
-            nn.GELU(),
-            nn.Dropout(self.params.dropout),
-            nn.Linear(self.params.lat_dim, n_classes),
-            nn.Softmax()
-        )
+            self.mlp_head = nn.Sequential(
+                nn.LayerNorm(self.embed_dim),
+                nn.Linear(self.embed_dim, self.params.lat_dim),
+                nn.GELU(),
+                nn.Dropout(self.params.dropout),
+                nn.Linear(self.params.lat_dim, n_classes),
+                nn.Softmax(dim=-1)  # softmax over the class dimension
+            )
 
-    def forward(self, x):
-        """
-        :param x: the sequence to the encoder (required).
-        :return:
-        """
-        tensor = self.autopad(x)
-        p = self.params.patch_size
-        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
+        def forward(self, x):
+            """
+            :param x: the batch of input images (required).
+            :return: a Namespace whose `main_out` field holds the class probabilities.
+            """
+            tensor = self.autopad(x)
+            p = self.params.patch_size
+            tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
 
-        tensor = self.patch_to_embedding(tensor)
-        b, n, _ = tensor.shape
+            tensor = self.patch_to_embedding(tensor)
+            b, n, _ = tensor.shape
 
-        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
+            cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
 
-        tensor = torch.cat((cls_tokens, tensor), dim=1)
-        tensor += self.pos_embedding[:, :(n + 1)]
-        tensor = self.dropout(tensor)
+            tensor = torch.cat((cls_tokens, tensor), dim=1)
+            tensor += self.pos_embedding[:, :(n + 1)]
+            tensor = self.dropout(tensor)
 
-        tensor = self.performer(tensor)
+            tensor = self.performer(tensor)
 
-        tensor = self.to_cls_token(tensor[:, 0])
-        tensor = self.mlp_head(tensor)
-        return Namespace(main_out=tensor)
+            tensor = self.to_cls_token(tensor[:, 0])
+            tensor = self.mlp_head(tensor)
+            return Namespace(main_out=tensor)
 
-    def additional_scores(self, outputs):
-        return MultiClassScores(self)(outputs)
+        def additional_scores(self, outputs):
+            return MultiClassScores(self)(outputs)
+
+except ImportError:  # pragma: no cover - only provide the model class when the dependency is installed
+    print('You are trying to use the `performer_pytorch` plugin, which is not installed yet;'
+          ' install it with `pip install performer_pytorch`.')
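Note: the performer.py change wraps the model class behind an optional-dependency guard so that importing the module no longer fails when performer_pytorch is missing. A standalone sketch of the same pattern; build_performer() is a hypothetical helper used only for illustration:

    try:
        from performer_pytorch import Performer  # optional heavy dependency
    except ImportError:  # pragma: no cover
        Performer = None

    def build_performer(**performer_kwargs):
        # Fail late, with an actionable message, only when the model is actually requested.
        if Performer is None:
            raise ImportError('`performer_pytorch` is not installed; '
                              'install it with `pip install performer_pytorch`.')
        return Performer(**performer_kwargs)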