from functools import reduce
from operator import mul
from random import choice

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.optim import Adam

from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten


class CNNRouteGeneratorModel(LightningBaseModule):
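    """Conditional VAE that generates alternative routes on a map.

    A convolutional map encoder and a trajectory encoder produce feature
    vectors that are concatenated with a scalar label, mixed, and pushed
    through a variational bottleneck (mu/logvar + reparameterization); a
    deconvolution stack decodes the latent sample into an alternative
    trajectory map.
    """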

    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
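        """Compute the VAE objective: reconstruction loss plus KL divergence."""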
        batch_x, alternative = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        element_wise_loss = self.criterion(generated_alternative, alternative)  # reconstruction term
        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional resizing; TODO: sanity-check whether normalizing the KLD
        # by the number of input elements makes sense here.
        # kld_loss /= reduce(mul, self.in_shape)

        loss = (kld_loss + element_wise_loss) / 2
        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))

    def _test_val_step(self, batch_xy, batch_nb, *args):
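        """Shared validation/test step.

        The plain generator has no discriminator, so `pred_label` is returned
        as the placeholder -1.
        """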
        batch_x, _ = batch_xy
        map_array, trajectory, label = batch_x

        generated_alternative, z, mu, logvar = self(batch_x)

        return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1)

    def _test_val_epoch_end(self, outputs, test=False):
        maps, trajectories, labels, val_result_dict = self.generate_random()

        from lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

        return dict(epoch=self.current_epoch)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    def __init__(self, *params, issubclassed=False):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset -- constructed unconditionally so `map_shapes_max` is available
        # for shape inference below; subclasses (issubclassed=True) replace it
        # with their own mode after this __init__ returns.
        self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays',
                                length=self.hparams.data_param.dataset_length, normalized=True)
        if not issubclassed:
            self.criterion = nn.MSELoss()

        # Additional Attributes
        self.in_shape = self.dataset.map_shapes_max
        # TODO: better naming; expose the feature size as a hyperparameter
        self.feature_dim = self.hparams.model_param.lat_dim * 10
        # map features + trajectory features + one scalar label
        self.feature_mixed_dim = 2 * self.feature_dim + 1

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        #
        # Map Encoder
        self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                     conv_filters=self.hparams.model_param.filters[0],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[0],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[1],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[1],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[2],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[2],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[2]*2,
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.map_flat = Flatten(self.map_conv_3.shape)
        self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)

        #
        # Trajectory Encoder
        self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[0],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[0],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[0],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.traj_flat = Flatten(self.traj_conv_3.shape)
        self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)

        #
        # Mixed Encoder
        self.mixed_lin = nn.Linear(self.feature_mixed_dim, self.feature_mixed_dim)
        self.mixed_norm = nn.BatchNorm1d(self.feature_mixed_dim) if self.hparams.model_param.use_norm else nn.Identity()

        #
        # Variational Bottleneck
        self.mu = nn.Linear(self.feature_mixed_dim, self.hparams.model_param.lat_dim)
        self.logvar = nn.Linear(self.feature_mixed_dim, self.hparams.model_param.lat_dim)

        #
        # Alternative Generator
        self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
        self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, self.traj_conv_3.shape))

        # Flatten in reverse: un-flattens the feature vector back to the conv feature-map shape
        self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)

        self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
                                         conv_padding=0, conv_kernel=5, conv_stride=1,
                                         use_norm=self.hparams.model_param.use_norm)
        self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
                                         conv_padding=0, conv_kernel=3, conv_stride=1,
                                         use_norm=self.hparams.model_param.use_norm)
        self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
                                         conv_padding=1, conv_kernel=3, conv_stride=1,
                                         use_norm=self.hparams.model_param.use_norm)
        self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
                                           conv_padding=1, conv_kernel=3, conv_stride=1,
                                           use_norm=self.hparams.model_param.use_norm)

    def forward(self, batch_x):
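        """Run the full pass: encode (map, trajectory, label), then decode z."""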
        #
        # Sorting the Input
        map_array, trajectory, label = batch_x

        #
        # Encode
        z, mu, logvar = self.encode(map_array, trajectory, label)

        #
        # Generate
        alt_tensor = self.generate(z)
        return alt_tensor, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
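        """Sample z ~ N(mu, sigma^2) via the reparameterization trick.

        Rewriting the sample as z = mu + sigma * eps with eps ~ N(0, I)
        keeps the sampling step differentiable w.r.t. mu and logvar.
        """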
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def generate(self, z):
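        """Decode a latent sample z into an alternative-trajectory map.

        Two linear layers expand z to the flattened conv shape, the deconv
        stack upsamples it back to map resolution, and a final sigmoid
        squashes the output into [0, 1].
        """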
        alt_tensor = self.alt_lin_1(z)
        alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.alt_lin_2(alt_tensor)
        alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.reshape_to_map(alt_tensor)
        alt_tensor = self.alt_deconv_1(alt_tensor)
        alt_tensor = self.alt_deconv_2(alt_tensor)
        alt_tensor = self.alt_deconv_3(alt_tensor)
        alt_tensor = self.alt_deconv_out(alt_tensor)
        # alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.sigmoid(alt_tensor)
        return alt_tensor

    def encode(self, map_array, trajectory, label):
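        """Encode map, trajectory, and label into (z, mu, logvar).

        Map and trajectory pass through separate conv encoders; their feature
        vectors are concatenated with the scalar label, mixed through a
        linear/norm block, and projected to the variational parameters.
        """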
        map_tensor = self.map_conv_0(map_array)
        map_tensor = self.map_res_1(map_tensor)
        map_tensor = self.map_conv_1(map_tensor)
        map_tensor = self.map_res_2(map_tensor)
        map_tensor = self.map_conv_2(map_tensor)
        map_tensor = self.map_res_3(map_tensor)
        map_tensor = self.map_conv_3(map_tensor)
        map_tensor = self.map_flat(map_tensor)
        map_tensor = self.map_lin(map_tensor)

        traj_tensor = self.traj_conv_1(trajectory)
        traj_tensor = self.traj_conv_2(traj_tensor)
        traj_tensor = self.traj_conv_3(traj_tensor)
        traj_tensor = self.traj_flat(traj_tensor)
        traj_tensor = self.traj_lin(traj_tensor)

        mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
        mixed_tensor = self.mixed_norm(mixed_tensor)
        mixed_tensor = self.activation(mixed_tensor)
        mixed_tensor = self.mixed_lin(mixed_tensor)
        mixed_tensor = self.mixed_norm(mixed_tensor)
        mixed_tensor = self.activation(mixed_tensor)

        #
        # Parameter and Sampling
        mu = self.mu(mixed_tensor)
        logvar = self.logvar(mixed_tensor)
        # logvar = torch.clamp(logvar, min=0, max=10)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate_random(self, n=6):
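        """Sample `n` random maps with one random trajectory each.

        Every map/trajectory pair appears twice in the returned batch, once
        per label (labels are [0] * n + [1] * n), and the batch is pushed
        through `_test_val_step` with the sentinel batch number -9999.
        """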
        maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]

        trajectories = [x.get_random_trajectory() for x in maps]
        trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
        # Duplicate the batch so each map/trajectory pair is scored with both labels
        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
        trajectories = self._move_to_model_device(torch.stack(trajectories))

        maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
        maps = self._move_to_model_device(torch.stack(maps))

        labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
        # batch_nb=-9999 is a sentinel marking this synthetic (non-dataloader) batch
        return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)


class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
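    """Generator variant trained against an external discriminator.

    Instead of an element-wise reconstruction loss, the generated alternative
    is stacked with map and trajectory and scored by the discriminator set
    via `set_discriminator()`; training minimizes the BCE on that prediction
    plus the KLD term.
    """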

    name = 'CNNRouteGeneratorDiscriminated'

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, label = batch_xy
        map_array, trajectory = batch_x

        # forward() expects the full (map, trajectory, label) triple
        generated_alternative, z, mu, logvar = self((map_array, trajectory, label))

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)
        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional resizing: normalize the KLD by the number of input elements
        kld_loss /= reduce(mul, self.in_shape)

        loss = (kld_loss + discriminated_bce_loss) / 2
        return dict(loss=loss, log=dict(loss=loss,
                                        discriminated_bce_loss=discriminated_bce_loss,
                                        kld_loss=kld_loss)
                    )

    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, label = batch_xy
        if label is None:
            # Synthetic batch from generate_random(): the label travels inside batch_x
            map_array, trajectory, label = batch_x
        else:
            map_array, trajectory = batch_x

        # forward() expects the full (map, trajectory, label) triple
        generated_alternative, z, mu, logvar = self((map_array, trajectory, label))

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)

        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def _test_val_epoch_end(self, outputs, test=False):
        evaluation = ROCEvaluation(plot_roc=True)
        pred_label = torch.cat([x['pred_label'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()

        # ROC evaluation expects (true_labels, predicted_scores), following scikit-learn's convention
        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy())
        if test:
            # self.logger.log_metrics(score_dict)
            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
        plt.clf()
        maps, trajectories, labels, val_result_dict = self.generate_random()

        from lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    @property
    def discriminator(self):
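        """The externally supplied discriminator model (see `set_discriminator`)."""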
        if self._disc is None:
            raise RuntimeError('Set the discriminator first via set_discriminator(disc_model).')
        return self._disc

    def set_discriminator(self, disc_model):
        if self._disc is not None:
            raise RuntimeError('Discriminator has already been set; refusing to overwrite it.')
        self._disc = disc_model

    def __init__(self, *params):
        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)

        self._disc = None

        self.criterion = nn.BCELoss()

        # Replace the base-class dataset with the 'just_route' variant
        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
                                length=self.hparams.data_param.dataset_length, normalized=True)