from torch.optim import Adam
from torch.nn.functional import mse_loss
from networks.modules import *
import torch


class SeperatingAAE(Module):
    """Adversarial autoencoder with separate spatial and temporal latent codes.

    Two encoders each map the input batch to a latent code. The codes are
    concatenated, expanded by ``Repeater`` and decoded into a reconstruction.
    Each latent code has its own discriminator that is trained to tell encoder
    outputs apart from samples drawn from ``self.normal``.

    NOTE(review): ``self.normal`` and the module-level name ``device`` are used
    by the discriminator losses but are not defined in this class; they are
    assumed to be provided elsewhere (e.g. by ``networks.modules`` or a
    wrapper) - confirm.
    """

    def __init__(self, latent_dim, features, train_on_predictions=False, use_norm=True):
        """Build the encoders, the decoder and both discriminators.

        Args:
            latent_dim: Size of each of the two latent codes.
            features: Feature count handed to the decoder and discriminators.
            train_on_predictions: If true, the reconstruction loss compares
                against ``y`` instead of ``x`` (see ``training_step``).
            use_norm: Forwarded to the encoders and the decoder.
        """
        super().__init__()

        self.latent_dim = latent_dim
        self.features = features
        self.train_on_predictions = train_on_predictions
        self.spatial_encoder = PoolingEncoder(self.latent_dim, use_norm=use_norm)
        self.temporal_encoder = Encoder(self.latent_dim, use_dense=False, use_norm=use_norm)
        # The decoder consumes both latent codes concatenated on the last axis.
        self.decoder = Decoder(self.latent_dim * 2, self.features, use_norm=use_norm)
        self.spatial_discriminator = Discriminator(self.latent_dim, self.features)
        self.temporal_discriminator = Discriminator(self.latent_dim, self.features)

    @property
    def _net(self):
        """Return the object that owns the sub-modules.

        The training hooks used to reach the sub-modules through
        ``self.network``, an attribute this class never assigns. A ``network``
        attribute is still honoured when something else provides it; otherwise
        the sub-modules registered on ``self`` are used.
        """
        return getattr(self, 'network', self)

    def forward(self, batch):
        """Encode ``batch`` twice, then decode the concatenated latent codes.

        Args:
            batch: Input tensor with at least two dimensions; the first two
                sizes are passed on to ``Repeater``
                (presumably batch size and number of timesteps - TODO confirm).

        Returns:
            Tuple ``(z_spatial, z_temporal, x_hat)``.
        """
        # Encoder
        z_spatial, z_temporal = self.spatial_encoder(batch), self.temporal_encoder(batch)
        # Decoder
        # Join both codes on the last axis, then let Repeater expand them to
        # the leading two dimensions of the input before decoding.
        z_concat = torch.cat((z_spatial, z_temporal), dim=-1)
        z_repeated = Repeater((batch.shape[0], batch.shape[1], -1))(z_concat)
        x_hat = self.decoder(z_repeated)
        return z_spatial, z_temporal, x_hat

    def _discriminator_loss(self, discriminator, latent_fake):
        """Return the scaled least-squares loss of one latent discriminator.

        Samples from ``self.normal`` are labelled 0 and encoder outputs are
        labelled 1. The mean of both MSE terms is scaled by 0.001.
        """
        latent_real = self.normal.sample(latent_fake.shape).to(device)

        # Evaluate the input
        real_prediction = discriminator.forward(latent_real)
        fake_prediction = discriminator.forward(latent_fake)

        loss_real = mse_loss(real_prediction,
                             torch.zeros(real_prediction.shape, device=device))
        loss_fake = mse_loss(fake_prediction,
                             torch.ones(fake_prediction.shape, device=device))

        # Mean over the real and the fake term.
        # TODO: do the two terms need to be optimised separately?
        return 0.5 * torch.add(loss_real, loss_fake) * 0.001

    def training_step(self, batch, _, optimizer_i):
        """Compute the loss belonging to the optimizer with index ``optimizer_i``.

        The index refers to the order returned by ``configure_optimizers``:
        0 = temporal discriminator, 1 = spatial discriminator, 2 = autoencoder.

        Args:
            batch: Pair ``(x, y)``.
            _: Unused (presumably the batch index).
            optimizer_i: Index of the optimizer that is being stepped.

        Returns:
            Dict with the single key ``'loss'``.

        Raises:
            RuntimeError: If ``optimizer_i`` is not 0, 1 or 2.
        """
        x, y = batch
        net = self._net
        spatial_latent_fake, temporal_latent_fake, x_hat = net.forward(x)

        if optimizer_i == 0:
            # ---------------------
            #  Train temporal Discriminator
            # ---------------------
            d_loss = self._discriminator_loss(net.temporal_discriminator, temporal_latent_fake)
            return {'loss': d_loss}

        if optimizer_i == 1:
            # ---------------------
            #  Train spatial Discriminator
            # ---------------------
            d_loss = self._discriminator_loss(net.spatial_discriminator, spatial_latent_fake)
            return {'loss': d_loss}

        if optimizer_i == 2:
            # ---------------------
            #  Train AutoEncoder
            # ---------------------
            reference = y if self.train_on_predictions else x
            return {'loss': mse_loss(reference, x_hat)}

        raise RuntimeError('This should not have happened, catch me if u can.')

    def configure_optimizers(self):
        """Create one Adam optimizer per ``training_step`` branch.

        The list order matches the ``optimizer_i`` branches of
        ``training_step`` (temporal discriminator, spatial discriminator,
        autoencoder), so every loss is stepped by the optimizer that owns the
        parameters it depends on.

        NOTE(review): both discriminator optimizers also hold the parameters of
        the matching encoder, so the encoder would be updated with the
        discriminator loss as well - confirm that this is intended.

        Returns:
            Tuple ``(optimizers, [])``. The trailing empty list is presumably
            the (empty) learning-rate scheduler list expected by the training
            framework - TODO confirm.
        """
        net = self._net
        temporal_optimizer = Adam(
            [*net.temporal_discriminator.parameters(), *net.temporal_encoder.parameters()],
            lr=0.02)
        spatial_optimizer = Adam(
            [*net.spatial_discriminator.parameters(), *net.spatial_encoder.parameters()],
            lr=0.02)
        autoencoder_optimizer = Adam(
            [*net.temporal_encoder.parameters(),
             *net.spatial_encoder.parameters(),
             *net.decoder.parameters()],
            lr=0.02)
        return [temporal_optimizer, spatial_optimizer, autoencoder_optimizer], []


# This file is import-only: executing it directly as a script is refused.
if '__main__' == __name__:
    refusal = 'Get out of here - never run this module'
    raise PermissionError(refusal)