from collections import defaultdict
from pathlib import Path

from abc import ABC

import torch
import pandas as pd
from matplotlib import pyplot as plt

from ml_lib.modules.util import LightningBaseModule
from util.loss_mixin import LossMixin
from util.optimizer_mixin import OptimizerMixin


class TrainMixin:
    """Training hooks: per-batch loss computation and per-epoch loss logging."""

    @staticmethod
    def _mean_losses(outputs):
        """Average every loss entry across a list of step outputs.

        Returns ``{'mean_<key>': mean}`` for every key of the first output whose
        name contains ``'loss'``; an empty dict when ``outputs`` is empty.
        """
        if not outputs:
            return {}
        return {f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
                for key in outputs[0] if 'loss' in key}

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the training loss for one ``(files, x, y)`` batch.

        Binary/single-output models (``n_classes <= 2``) use ``self.bce_loss``;
        otherwise the loss named by ``self.params.loss`` is looked up on ``self``.
        """
        assert isinstance(self, LightningBaseModule)
        _, batch_x, batch_y = batch_xy
        y = self(batch_x).main_out
        if self.params.n_classes <= 2:
            # reshape (instead of squeeze) keeps a batch of one as shape (1,), so the
            # prediction always has the same shape as the target.
            loss = self.bce_loss(y.reshape(batch_y.shape).float(), batch_y.float())
        elif self.params.loss == 'focal_loss_rob':
            # one_hot only accepts integer (long) class indices.
            labels_one_hot = torch.nn.functional.one_hot(batch_y.long(), num_classes=self.params.n_classes)
            loss = getattr(self, self.params.loss)(y, labels_one_hot)
        else:
            # getattr (unlike an explicit __getattribute__ call) also falls back to
            # nn.Module.__getattr__, so losses registered as submodules are found.
            loss = getattr(self, self.params.loss)(y, batch_y.long())
        return dict(loss=loss)

    def training_epoch_end(self, outputs):
        """Log the epoch-mean of every loss returned by ``training_step``."""
        assert isinstance(self, LightningBaseModule)
        summary_dict = self._mean_losses(outputs)
        summary_dict.update(epoch=self.current_epoch)
        self.log_dict(summary_dict)


class ValMixin:
    """Validation hooks that aggregate sample-level predictions per source file.

    A file may contribute several samples to a batch (and to an epoch); its
    predictions are averaged and compared against that file's target.
    """

    @staticmethod
    def _group_by_file(samples):
        """Group ``(file_name, prediction, target)`` triples by file.

        Returns a dict mapping each file name (in first-seen order) to its stacked
        predictions, plus a ``long`` tensor holding one target per file in the same
        order. When a file occurs several times, its last target is kept.
        """
        predictions = defaultdict(list)
        targets = dict()
        for file_name, prediction, target in samples:
            predictions[file_name].append(prediction)
            targets[file_name] = target
        stacked = {file_name: torch.stack(preds) for file_name, preds in predictions.items()}
        return stacked, torch.stack(tuple(targets.values())).long()

    @staticmethod
    def _mean_per_file(grouped):
        """Stack the mean prediction of every file (mean over that file's samples)."""
        return torch.stack([preds.mean(dim=0) for preds in grouped.values()])

    @staticmethod
    def _vote_per_file(grouped):
        """Return a 1-D tensor with the (flattened) argmax of each file's mean prediction.

        The result stays 1-D even for a single file, so ``one_hot`` always yields
        shape ``(n_files, n_classes)``.
        """
        return torch.stack([preds.mean(dim=0).argmax() for preds in grouped.values()])

    def validation_step(self, batch_xy, batch_idx, *args, **kwargs):
        """Update ``self.metrics`` with per-file predictions and return the batch loss and outputs."""
        assert isinstance(self, LightningBaseModule)
        batch_files, batch_x, batch_y = batch_xy
        y = self(batch_x).main_out

        grouped_y, target_y = self._group_by_file(zip(batch_files, y, batch_y))
        if self.params.n_classes <= 2:
            self.metrics.update(self._mean_per_file(grouped_y), target_y)
            # reshape keeps a batch of one as shape (1,), matching batch_y.
            val_loss = self.bce_loss(y.reshape(batch_y.shape).float(), batch_y.float())
        else:
            y_max = self._vote_per_file(grouped_y)
            y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
            self.metrics.update(y_one_hot, target_y)
            val_loss = self.ce_loss(y, batch_y.long())

        return dict(batch_files=batch_files, val_loss=val_loss,
                    batch_idx=batch_idx, y=y, batch_y=batch_y)

    def validation_epoch_end(self, outputs, *_, **__):
        """Log mean step losses, the per-file "max vote" loss, extra scores and metrics.

        Predictions from all validation batches are grouped per file before the
        vote loss and ``self.additional_scores`` are computed. Figures returned by
        ``self.metrics.compute_and_prepare()`` are sent to the logger and closed.
        """
        assert isinstance(self, LightningBaseModule)
        if not outputs:
            return
        summary_dict = {f'mean_{key}': torch.mean(torch.stack([output[key] for output in outputs]))
                        for key in outputs[0] if 'loss' in key}

        grouped_y, target_y = self._group_by_file(
            (file_name, prediction, target)
            for output in outputs
            for file_name, prediction, target in zip(output['batch_files'], output['y'], output['batch_y'])
        )

        if self.params.n_classes <= 2:
            mean_y = self._mean_per_file(grouped_y).reshape(-1)
            max_vote_loss = self.bce_loss(mean_y.float(), target_y.float())
            # Sklearn Scores
            additional_scores = self.additional_scores(dict(y=mean_y, batch_y=target_y))
        else:
            y_max = self._vote_per_file(grouped_y)
            y_one_hot = torch.nn.functional.one_hot(y_max, num_classes=self.params.n_classes).float()
            max_vote_loss = self.ce_loss(y_one_hot, target_y)
            # Sklearn Scores
            additional_scores = self.additional_scores(dict(y=y_one_hot, batch_y=target_y))

        summary_dict.update(val_max_vote_loss=max_vote_loss, **additional_scores)

        pl_metrics, pl_images = self.metrics.compute_and_prepare()
        self.metrics.reset()
        summary_dict.update(**pl_metrics)
        summary_dict.update(epoch=self.current_epoch)
        self.log_dict(summary_dict)

        for name, image in pl_images.items():
            self.logger.log_image(name, image, step=self.global_step)
            # Release the figure so figures do not accumulate across epochs.
            plt.close(image)


class TestMixin:
    """Inference hooks that write one predicted label per source file to ``predictions.csv``."""

    # The class name starts with "Test"; this stops pytest from collecting the mixin
    # (and its ``test_*`` hooks) as a test suite when the module is imported by tests.
    __test__ = False

    # Label names indexed by predicted class id; subclasses may override them.
    _MULTICLASS_NAMES = ('background', 'chimpanze', 'geunon', 'mandrille', 'redcap')
    _BINARY_NAMES = ('negative', 'positive')
    # Binary outputs strictly above this value are labelled as class 1.
    _BINARY_THRESHOLD = 0.5

    def test_step(self, batch_xy, batch_idx, *_, **__):
        """Run the model on one batch and keep its raw outputs for ``test_epoch_end``."""
        assert isinstance(self, LightningBaseModule)
        batch_files, batch_x, _ = batch_xy
        y = self(batch_x).main_out
        return dict(batch_files=batch_files, batch_idx=batch_idx, y=y)

    @staticmethod
    def _predict_per_file(outputs, n_classes, threshold=0.5):
        """Average each file's outputs and turn them into one class id per file.

        Returns the file names in first-seen order and a 1-D ``long`` tensor of
        class ids: the argmax of the mean output when ``n_classes > 2``, otherwise
        ``mean > threshold``. The tensor is 1-D even for a single file.
        """
        grouped = defaultdict(list)
        for output in outputs:
            for file_name, prediction in zip(output['batch_files'], output['y']):
                grouped[file_name].append(prediction.cpu())
        if not grouped:
            return [], torch.empty(0, dtype=torch.long)
        means = [torch.stack(predictions).mean(dim=0) for predictions in grouped.values()]
        if n_classes > 2:
            labels = torch.stack([mean.argmax() for mean in means])
        else:
            labels = (torch.stack(means).flatten() > threshold).long()
        return list(grouped), labels

    def test_epoch_end(self, outputs, *_, **__):
        """Write ``predictions.csv`` (filename, prediction) into the logger's log directory.

        No metrics are logged; this is inference only.
        """
        assert isinstance(self, LightningBaseModule)
        n_classes = self.params.n_classes
        file_names, labels = self._predict_per_file(outputs, n_classes, self._BINARY_THRESHOLD)
        class_names = self._MULTICLASS_NAMES if n_classes > 2 else self._BINARY_NAMES

        df = pd.DataFrame(data=dict(
            filename=[Path(file_name).name.replace('.npy', '.wav') for file_name in file_names],
            prediction=[class_names[label] for label in labels.tolist()],
        ))
        # Path() accepts both str and Path log dirs. Writing to the path truncates
        # any previous predictions file, so no separate delete step is needed.
        result_file = Path(self.logger.log_dir) / 'predictions.csv'
        df.to_csv(result_file, index=False)
        # TODO: optionally log result_file as a logger artifact (a disabled Neptune
        # upload previously lived here behind `if False:`).


class CombinedModelMixins(LossMixin, TrainMixin, ValMixin, TestMixin,
                          OptimizerMixin, LightningBaseModule, ABC):
    """Bundle the loss, train/val/test and optimizer mixins on top of ``LightningBaseModule``.

    The base-class order defines the method resolution order: an attribute defined
    by an earlier base (e.g. ``LossMixin``) takes precedence over a later one.
    """