from collections import defaultdict
from itertools import cycle

from abc import ABC
from argparse import Namespace

import torch

import numpy as np

from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay, f1_score, roc_auc_score
import matplotlib.pyplot as plt

from torch import nn
from torch.optim import Adam
from torch_geometric.data import Data, DataLoader

from torchcontrib.optim import SWA


from ml_lib.modules.util import LightningBaseModule
from ml_lib.utils.tools import to_one_hot

from .project_settings import GlobalVar


class BaseOptimizerMixin:
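    """Shared optimizer behaviour: Adam, optional stochastic weight averaging
    (SWA) and periodic optimizer-state resets.

    Expects the inheriting class to be a `LightningBaseModule` exposing
    `params.lr`, `params.weight_decay`, `params.sto_weight_avg` and
    `params.opt_reset_interval`.
    """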

    def configure_optimizers(self):
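        """Build an Adam optimizer; wrap it in torchcontrib's SWA if `sto_weight_avg` is set."""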
        assert isinstance(self, LightningBaseModule)
        opt = Adam(params=self.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
        if self.params.sto_weight_avg:
            # TODO: Make this globally available.
            opt = SWA(opt, swa_start=10, swa_freq=5, swa_lr=0.05)
        return opt

    def on_train_end(self):
        assert isinstance(self, LightningBaseModule)
        for opt in self.trainer.optimizers:
            if isinstance(opt, SWA):
                opt.swap_swa_sgd()

    def on_epoch_end(self):
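        """Reset the optimizer state (e.g. Adam's moment estimates) every
        `opt_reset_interval` epochs, if that interval is configured."""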
        assert isinstance(self, LightningBaseModule)
        if self.params.opt_reset_interval:
            if self.current_epoch % self.params.opt_reset_interval == 0:
                for opt in self.trainer.optimizers:
                    opt.state = defaultdict(dict)


class BaseTrainMixin:
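    """Training step (NLL loss) and epoch-end mean-loss aggregation."""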

    # Absolute Error
    absolute_loss = nn.L1Loss()
    # Negative Log-Likelihood
    nll_loss = nn.NLLLoss()
    # Binary Cross Entropy
    bce_loss = nn.BCELoss()

    def training_step(self, batch_norm_pos_y, batch_nb, *_, **__):
        assert isinstance(self, LightningBaseModule)
        data = self.batch_to_data(batch_norm_pos_y) if not isinstance(batch_norm_pos_y, Data) else batch_norm_pos_y
        y = self(data).main_out
        nll_loss = self.nll_loss(y, data.y)
        return dict(loss=nll_loss, log=dict(batch_nb=batch_nb))

    def training_epoch_end(self, outputs):
        assert isinstance(self, LightningBaseModule)
        keys = list(outputs[0].keys())

        summary_dict = dict(log={f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                        for output in outputs]))
                                 for key in keys if 'loss' in key})
        return summary_dict


class BaseValMixin:
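    """Validation step (NLL loss) plus epoch-end metrics: mean losses,
    F1 scores, ROC curves/AUC and a confusion matrix."""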

    # Absolute Error
    absolute_loss = nn.L1Loss()
    # Negative Log-Likelihood
    nll_loss = nn.NLLLoss()
    # Binary Cross Entropy
    bce_loss = nn.BCELoss()

    def validation_step(self, batch_pos_x_n_y_c, batch_idx, *_, **__):
        assert isinstance(self, LightningBaseModule)
        data = self.batch_to_data(batch_pos_x_n_y_c) if not isinstance(batch_pos_x_n_y_c, Data) else batch_pos_x_n_y_c
        y = self(data).main_out
        nll_loss = self.nll_loss(y, data.y)
        return dict(val_nll_loss=nll_loss,
                    batch_idx=batch_idx, y=y, batch_y=data.y)

    def validation_epoch_end(self, outputs, *_, **__):
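        """Aggregate the per-batch outputs into mean losses and additional
        scores (F1, ROC/AUC, confusion matrix); plots are logged as images."""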
        assert isinstance(self, LightningBaseModule)
        summary_dict = dict(log=dict())
        # With multiple dataloaders, `outputs` would be a list[list[dict]];
        # with a single dataloader it is a list[dict].
        keys = list(outputs[0].keys())
        # Average every value whose key contains 'loss' over all batches.
        summary_dict['log'].update({f'mean_{key}': torch.mean(torch.stack([output[key]
                                                                           for output in outputs]))
                                    for key in keys if 'loss' in key}
                                   )

        #######################################################################################
        # Additional Score  -  UAR  -  ROC  -  Conf. Matrix  -  F1
        #######################################################################################
        #
        # INIT
        y_true = torch.cat([output['batch_y'] for output in outputs]).cpu().numpy()
        y_true_one_hot = to_one_hot(y_true, self.n_classes)

        y_pred = torch.cat([output['y'] for output in outputs]).squeeze().cpu().float().numpy()
        y_pred_max = np.argmax(y_pred, axis=1)

        class_names = {val: key for key, val in GlobalVar.classes.items()}
        ######################################################################################
        #
        # F1 SCORE
        # `zero_division=1` sets the score to 1 when a class has no predicted
        # (or true) samples, instead of emitting a warning.
        micro_f1_score = f1_score(y_true, y_pred_max, average='micro', zero_division=1)
        macro_f1_score = f1_score(y_true, y_pred_max, average='macro', zero_division=1)
        summary_dict['log'].update(dict(micro_f1_score=micro_f1_score, macro_f1_score=macro_f1_score))

        #######################################################################################
        #
        # ROC Curve

        # Compute ROC curve and ROC area for each class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for i in range(self.n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true_one_hot[:, i], y_pred[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])

        # Compute micro-average ROC curve and ROC area
        fpr["micro"], tpr["micro"], _ = roc_curve(y_true_one_hot.ravel(), y_pred.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        # First aggregate all false positive rates
        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(self.n_classes)]))

        # Then interpolate all ROC curves at these points
        mean_tpr = np.zeros_like(all_fpr)
        for i in range(self.n_classes):
            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])

        # Finally average it and compute AUC
        mean_tpr /= self.n_classes

        fpr["macro"] = all_fpr
        tpr["macro"] = mean_tpr
        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

        # Plot all ROC curves
        plt.figure()
        plt.plot(fpr["micro"], tpr["micro"],
                 label=f'micro ROC ({round(roc_auc["micro"], 2)})',
                 color='deeppink', linestyle=':', linewidth=4)

        plt.plot(fpr["macro"], tpr["macro"],
                 label=f'macro ROC ({round(roc_auc["macro"], 2)})',
                 color='navy', linestyle=':', linewidth=4)

        colors = cycle(['firebrick', 'orangered', 'gold', 'olive', 'limegreen', 'aqua',
                        'dodgerblue', 'slategrey', 'royalblue', 'indigo', 'fuchsia'])

        for i, color in zip(range(self.n_classes), colors):
            plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'{class_names[i]} ({round(roc_auc[i], 2)})')

        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.legend(loc="lower right")

        self.logger.log_image('ROC', image=plt.gcf(), step=self.current_epoch)
        plt.clf()

        #######################################################################################
        #
        # ROC SCORE

        try:
            macro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
                                              average="macro")
            summary_dict['log'].update(macro_roc_auc_ovr=macro_roc_auc_ovr)
        except ValueError:
            # Macro averaging fails when a class is missing from y_true;
            # fall back to the micro-averaged score in that case.
            micro_roc_auc_ovr = roc_auc_score(y_true_one_hot, y_pred, multi_class="ovr",
                                              average="micro")
            summary_dict['log'].update(micro_roc_auc_ovr=micro_roc_auc_ovr)

        #######################################################################################
        #
        # Confusion matrix

        label_names = [class_names[key] for key in class_names.keys()]
        cm = confusion_matrix([class_names[x] for x in y_true], [class_names[x] for x in y_pred_max],
                              labels=label_names, normalize='all')
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_names)
        disp.plot(include_values=True)
        self.logger.log_image('Confusion Matrix', image=disp.figure_, step=self.current_epoch)

        plt.close('all')

        return summary_dict


class DatasetMixin:
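    """Builds the train/devel/test splits of a dataset class."""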

    def build_dataset(self, dataset_class, **kwargs):
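        """Instantiate the train, validation and test splits of `dataset_class`
        and return them bundled in a `Namespace`."""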
        assert isinstance(self, LightningBaseModule)
        assert dataset_class.name == self.params.dataset_type, \
            f'Check the dataset! Expected {self.params.dataset_type}, ' \
            f'but got {dataset_class.name}.'

        # Dataset
        # =============================================================================
        # Data augmentations or utility transformations would be applied here.
        dataset = Namespace(
            **dict(
                # TRAIN DATASET
                train_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.train,
                                            **kwargs),

                # VALIDATION DATASET
                val_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.devel,
                                          **kwargs),

                # TEST DATASET
                test_dataset=dataset_class(self.params.root, mode=GlobalVar.data_split.predict,
                                           **kwargs),
            )
        )
        return dataset


class BaseDataloadersMixin(ABC):
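    """Default train/val/test dataloaders over the datasets built by `DatasetMixin`."""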

    # Dataloaders
    # ================================================================================
    # Train Dataloader
    def train_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        # In case you want to implement bootstrapping:
        # sampler = RandomSampler(self.dataset.train_dataset, True, len(self.dataset.train_dataset))
        sampler = None
        # Shuffle the training data whenever no custom sampler is given
        # (sampler and shuffle are mutually exclusive in DataLoader).
        return DataLoader(dataset=self.dataset.train_dataset,
                          shuffle=True if sampler is None else False, sampler=sampler,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

    # Test Dataloader
    def test_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        return DataLoader(dataset=self.dataset.test_dataset, shuffle=False,
                          batch_size=self.params.batch_size,
                          num_workers=self.params.worker)

    # Validation Dataloader
    def val_dataloader(self):
        assert isinstance(self, LightningBaseModule)
        val_dataloader = DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
                                    batch_size=self.params.batch_size, num_workers=self.params.worker)
        # Alternatively return [val_dataloader, another_dataloader]; validation_step
        # will then receive a dataloader_idx argument.
        return val_dataloader
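

# A minimal usage sketch (illustrative only; `MyGraphModel` and `MyDataset`
# are hypothetical placeholders, not part of this module): a concrete model
# combines the mixins with `LightningBaseModule` roughly like this.
#
# class MyGraphModel(BaseOptimizerMixin, BaseTrainMixin, BaseValMixin,
#                    DatasetMixin, BaseDataloadersMixin, LightningBaseModule):
#
#     def __init__(self, hparams):
#         super().__init__(hparams)
#         self.dataset = self.build_dataset(MyDataset)
#
#     def forward(self, data):
#         ...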