import inspect
import warnings
from argparse import Namespace

import torch
from torch import nn
from einops import rearrange, repeat
from performer_pytorch import Performer

from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape, F_x)
from util.module_mixins import CombinedModelMixins

MIN_NUM_PATCHES = 16


class VisualPerformer(CombinedModelMixins,
                      LightningBaseModule
                      ):

    def __init__(self, in_shape, n_classes, weight_init, activation,
                 embedding_size, heads, attn_depth, patch_size, use_residual,
                 use_bias, use_norm, dropout, lat_dim, loss, scheduler,
                 lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval):

        # TODO: Move this to the parent class, or make it much easier to access... But how?
        a = dict(locals())
        params = {arg: a[arg] for arg in inspect.signature(self.__init__).parameters.keys() if arg != 'self'}
        super().__init__(params)

        self.in_shape = in_shape
        assert len(self.in_shape) == 3, 'in_shape must have exactly three dimensions (channels, height, width)'
        channels, height, width = self.in_shape

        # Model Parameters
        # =============================================================================
        # Additional parameters
        self.embed_dim = self.params.embedding_size

        # Automatic image shaping: round the longer image side up to the next multiple of the patch size
        self.patch_size = self.params.patch_size
        image_size = (max(height, width) // self.patch_size) * self.patch_size
        self.image_size = image_size + self.patch_size if image_size < max(height, width) else image_size
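        # e.g. height = width = 100 with patch_size = 9 gives image_size = 99, padded up to 108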

        # This should always hold after the automatic shaping above
        assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'

        num_patches = (self.image_size // self.patch_size) ** 2
        patch_dim = channels * self.patch_size ** 2
        assert num_patches >= MIN_NUM_PATCHES, f'The number of patches ({num_patches}) is too small for ' + \
                                               f'effective attention. Try decreasing the patch size.'

        # Correct the Embedding Dim
        if self.embed_dim % self.params.heads != 0:
            self.embed_dim = (self.embed_dim // self.params.heads) * self.params.heads
            message = ('Embedding dimension was adjusted to be divisible by the number' +
                       f' of attention heads; it is now: {self.embed_dim}')
            for func in print, warnings.warn:
                func(message)

        # Utility Modules
        self.autopad = AutoPadToShape((self.image_size, self.image_size))

        # Modules with Parameters
        self.performer = Performer(
            dim=self.embed_dim,                         # dimension
            depth=self.params.attn_depth,               # layers
            heads=self.params.heads,                    # heads
            causal=False,                               # non-causal: the CLS token at position 0 must be
                                                        # able to attend to all image patches
            nb_features=None,  # 256,                   # number of random features, if not set, will default to
                                                        # (d * log(d)), where d is the dimension of each head
            feature_redraw_interval=1000,               # how frequently to redraw the projection matrix,
                                                        # the more frequent, the slower the training
            generalized_attention=False,                # defaults to softmax approximation,
                                                        # but can be set to True for generalized attention
            kernel_fn=self.params.activation(),         # the kernel function to be used,
                                                        # if generalized attention is turned on, defaults to Relu
            reversible=True,                            # reversible layers, from Reformer paper
            ff_chunks=10,                               # chunk feedforward layer, from Reformer paper
            use_scalenorm=False,                        # use scale norm, from 'Transformers without Tears' paper
            use_rezero=False,                           # use rezero, from 'Rezero is all you need' paper
            ff_glu=True,                                # use GLU variant for feedforward
            ff_dropout=self.params.dropout,             # feedforward dropout
            attn_dropout=self.params.dropout,           # post-attn dropout
            local_attn_heads=self.params.heads // 2,    # half of the heads use local attention,
                                                        # the rest are global Performer heads
            local_window_size=(patch_dim // self.params.heads) * 2    # window size of local attention
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embed_dim))
        self.patch_to_embedding = nn.Linear(patch_dim, self.embed_dim) if self.params.embedding_size \
            else F_x(self.embed_dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.dropout = nn.Dropout(self.params.dropout)

        self.to_cls_token = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.params.lat_dim),
            nn.GELU(),
            nn.Dropout(self.params.dropout),
            nn.Linear(self.params.lat_dim, n_classes),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """
        :param x: the sequence to the encoder (required).
        :return:
        """
        tensor = self.autopad(x)
        p = self.params.patch_size
        tensor = rearrange(tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        tensor = self.patch_to_embedding(tensor)
        b, n, _ = tensor.shape

        # Prepend a learnable CLS token and add positional embeddings (truncated to the sequence length)
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)

        tensor = torch.cat((cls_tokens, tensor), dim=1)
        tensor += self.pos_embedding[:, :(n + 1)]
        tensor = self.dropout(tensor)

        tensor = self.performer(tensor)

        # Classification read-out: take the CLS position and map it to class probabilities
        tensor = self.to_cls_token(tensor[:, 0])
        tensor = self.mlp_head(tensor)
        return Namespace(main_out=tensor)

    def additional_scores(self, outputs):
        return MultiClassScores(self)(outputs)
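

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline):
# the hyperparameter values below are assumptions chosen merely to satisfy the
# constructor, and the project-local packages (ml_lib, util) must be
# importable for this to run.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = VisualPerformer(
        in_shape=(1, 128, 128), n_classes=5, weight_init='xavier_normal_',
        activation=nn.ReLU, embedding_size=64, heads=4, attn_depth=2,
        patch_size=8, use_residual=True, use_bias=True, use_norm=True,
        dropout=0.1, lat_dim=32, loss='ce_loss', scheduler=None, lr=1e-3,
        weight_decay=0.0, sto_weight_avg=False, lr_warm_restart_epochs=None,
        opt_reset_interval=None,
    )
    dummy_batch = torch.randn(2, 1, 128, 128)   # (batch, channels, height, width)
    prediction = model(dummy_batch)
    print(prediction.main_out.shape)            # expected: torch.Size([2, 5])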