import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        # Storing the target shape under `self.to` would shadow nn.Module.to()
        # and break calling .to() on this module, so use a distinct name.
        self.target_shape = args

    def forward(self, x):
        return x.view(x.shape[0], *self.target_shape)

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        return x.view(x.shape[0], -1)
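
# Note (not in the original code): newer PyTorch releases ship built-in
# equivalents, nn.Flatten() and nn.Unflatten(1, shape), which could replace
# these two helpers directly if older releases need not be supported.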


class AE(nn.Module):
    def __init__(self, in_dim=400):
        super(AE, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 8),
            nn.ReLU(),
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, in_dim),
            nn.ReLU()  # final ReLU constrains reconstructions to non-negative values
        )

    def forward(self, data):
        x = data.view(data.shape[0], -1)
        return self.net(x), x

    def train_loss(self, data):
        criterion = nn.MSELoss()
        y_hat, y = self.forward(data)
        loss = criterion(y_hat, y)
        return loss

    def test_loss(self, data):
        y_hat, y = self.forward(data)
        # Per-sample squared reconstruction error, used as an anomaly score.
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def init_weights(self):
        def _weight_init(m):
            if hasattr(m, 'weight'):
                if isinstance(m.weight, torch.Tensor):
                    torch.nn.init.xavier_uniform_(m.weight,
                                                  gain=nn.init.calculate_gain('relu'))
            if hasattr(m, 'bias'):
                if isinstance(m.bias, torch.Tensor):
                    m.bias.data.fill_(0.01)

        self.apply(_weight_init)
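
# Minimal training-step sketch for AE (not part of the original module;
# assumes `batch` is a tensor whose trailing dims flatten to 400 features):
#
#   model = AE(in_dim=400)
#   model.init_weights()
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   loss = model.train_loss(batch)
#   optimizer.zero_grad()
#   loss.backward()
#   optimizer.step()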



class SubSpecCAE(nn.Module):
    def __init__(self, F=20, T=80, norm='batch',
                 activation='relu', dropout_prob=0.25):
        super(SubSpecCAE, self).__init__()
        self.T = T
        self.F = F
        self.activation = activation
        self.loss_fn = nn.MSELoss()
        Norm = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        Activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU
        a, b = 20, 40
        # Shape comments assume the defaults F=20, T=80 (N = batch size).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=a, kernel_size=7, stride=1, padding=3),  # N x 20 x 20 x 80
            Norm(a),
            Activation(),
            nn.MaxPool2d((F//10, 5)),  # N x 20 x 10 x 16
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=a, out_channels=b, kernel_size=7, stride=1, padding=3),  # N x 40 x 10 x 16
            Norm(b),
            Activation(),
            # Kernel 4 with stride T collapses the 10 x 16 map to one value
            # per channel (only the top-left 4 x 4 window is actually pooled).
            nn.MaxPool2d(4, T),  # N x 40 x 1 x 1
            nn.Dropout(dropout_prob),
            Flatten(),  # N x 40
            nn.Linear(b, 16)  # bottleneck
        )

        self.decoder = nn.Sequential(
            nn.Linear(16, b),
            Reshape(b, 1, 1),
            nn.Upsample(size=(10, 16), mode='bilinear', align_corners=False),
            nn.ConvTranspose2d(in_channels=b, out_channels=a, kernel_size=7, stride=1, padding=3),
            Norm(a),
            Activation(),
            nn.Upsample(size=(20, 80), mode='bilinear', align_corners=False),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=a, out_channels=1, kernel_size=7, stride=1, padding=3),
            nn.Sigmoid()
        )

    def forward(self, sub_x):
        # Band selection happens upstream (see FullSubSpecCAE.select_sub_band);
        # sub_x is already a single sub-spectrogram of shape N x 1 x F x T.
        encoded = self.encoder(sub_x)
        decoded = self.decoder(encoded)
        return decoded, sub_x

    def init_weights(self):
        def weight_init(m):
            # Note: nn.ConvTranspose2d is not a subclass of nn.Conv2d, so the
            # decoder's transposed convolutions keep PyTorch's default init.
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                torch.nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.fill_(0.2)
        self.apply(weight_init)
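
# Single-band sketch (not part of the original module; assumes `spec` is one
# sub-band batch of shape [N, 1, 20, 80]):
#
#   cae = SubSpecCAE()
#   cae.init_weights()
#   decoded, target = cae(spec)
#   loss = cae.loss_fn(decoded, target)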


class FullSubSpecCAE(nn.Module):
    def __init__(self, F=20, T=80,
                 norm='batch', activation='relu',
                 dropout_prob=0.25, weight_sharing=False,
                 sub_bands=(0, 1, 2, 3, 4, 5, 6)):  # tuple avoids a mutable default
        super(FullSubSpecCAE, self).__init__()
        self.bands = sub_bands
        self.weight_sharing = weight_sharing
        self.aes = nn.ModuleList([
            SubSpecCAE(F, T, norm, activation, dropout_prob)
            for _ in range(1 if weight_sharing else len(sub_bands))
        ])
        for ae in self.aes:
            ae.init_weights()
        self.loss_fn = nn.MSELoss()

    def select_sub_ae(self, band):
        if self.weight_sharing:
            return self.aes[0]
        else:
            return self.aes[band]

    def select_sub_band(self, x, band):
        # Pick one sub-band and restore the channel dimension: N x 1 x F x T.
        return x[:, band].unsqueeze(1)

    def forward(self, x):
        y_hat, y = [], []
        for band in self.bands:
            sub_ae = self.select_sub_ae(band)
            sub_x = self.select_sub_band(x, band)
            decoded, target = sub_ae(sub_x)
            y.append(target)
            y_hat.append(decoded)
        decoded = torch.cat(y_hat, dim=1)
        y = torch.cat(y, dim=1)
        return decoded, y

    def train_loss(self, data):
        y_hat, y = self.forward(data)  # N x num_bands x F x T
        loss = self.complete_mse(y_hat, y)
        return loss

    def test_loss(self, data):
        y_hat, y = self.forward(data)
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def sub_band_mse(self, y_hat, y):
        losses = []
        for band in self.bands:
            sub_y = self.select_sub_band(y, band)
            sub_y_hat = self.select_sub_band(y_hat, band)
            sub_loss = torch.mean((sub_y_hat - sub_y) ** 2, dim=tuple(range(1, sub_y.dim())))  # shape: N
            losses.append(sub_loss)
        losses = torch.stack(losses, dim=1)  # shape: N x num_bands
        return losses

    def complete_mse(self, y_hat, y):
        # Average the per-sample MSE over the batch, then sum over sub-bands.
        return self.sub_band_mse(y_hat, y).mean(dim=0).sum()

    def gather_predictions(self, dataloader):
        device = next(self.parameters()).device
        predictions = []
        labels = []
        self.eval()
        with torch.no_grad():
            for batch in dataloader:
                data, batch_labels = batch
                data = data.to(device)
                y_hat, y = self.forward(data)
                mse = self.sub_band_mse(y_hat, y)  # N x num_bands
                predictions.append(mse)
                labels += batch_labels.tolist()
        predictions = torch.cat(predictions).cpu().numpy()
        self.train()
        return predictions, labels
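
# Sketch of scoring a dataset with FullSubSpecCAE (not part of the original
# module; assumes `loader` yields (data, label) batches with data of shape
# [N, 7, 20, 80]):
#
#   model = FullSubSpecCAE()
#   scores, labels = model.gather_predictions(loader)  # scores: [num_samples, 7]
#   anomaly_scores = scores.sum(axis=1)                # aggregate sub-band errors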



class GlobalCAE(nn.Module):
    def __init__(self, F=20, T=80, norm='batch', activation='relu', dropout_prob=0.25):
        super(GlobalCAE, self).__init__()
        # F and T are accepted for API parity with SubSpecCAE but unused here;
        # the layer sizes below are hard-coded for 80 x 80 inputs.
        self.activation = activation
        Norm = nn.BatchNorm2d if norm == 'batch' else nn.InstanceNorm2d
        Activation = nn.ReLU if activation == 'relu' else nn.LeakyReLU
        # Shape comments assume an 80 x 80 input, which is what the
        # Linear(60*8*8, 64) bottleneck below requires (N = batch size).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=20, kernel_size=8, stride=2),  # N x 20 x 37 x 37
            Norm(20),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=20, out_channels=40, kernel_size=5, stride=2),  # N x 40 x 17 x 17
            Norm(40),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.Conv2d(in_channels=40, out_channels=60, kernel_size=3, stride=2),  # N x 60 x 8 x 8
            Norm(60),
            Flatten(),  # N x 3840
            nn.Linear(60*8*8, 64)
        )
        self.decoder = nn.Sequential(
            nn.Linear(64, 60*8*8),
            Reshape(60, 8, 8),
            nn.ConvTranspose2d(in_channels=60, out_channels=40, kernel_size=3, stride=2),  # N x 40 x 17 x 17
            Norm(40),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=40, out_channels=20, kernel_size=5, stride=2),  # N x 20 x 37 x 37
            Norm(20),
            Activation(),
            nn.Dropout(dropout_prob),
            nn.ConvTranspose2d(in_channels=20, out_channels=1, kernel_size=8, stride=2),  # N x 1 x 80 x 80
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded, x

    def train_loss(self, data):
        criterion = nn.MSELoss()
        y_hat, y = self.forward(data)
        loss = criterion(y_hat, y)
        return loss

    def test_loss(self, data):
        y_hat, y = self.forward(data)
        preds = torch.sum((y_hat - y) ** 2, dim=tuple(range(1, y_hat.dim())))
        return preds

    def init_weights(self):
        def weight_init(m):
            # As in SubSpecCAE, transposed convolutions keep the default init.
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                torch.nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    m.bias.data.fill_(0.01)
        self.apply(weight_init)
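

if __name__ == "__main__":
    # Smoke test (not part of the original training code): push random input
    # through each model and print the reconstruction shapes. The input sizes
    # are assumptions derived from the layer dimensions above.
    ae = AE(in_dim=400)
    ae.init_weights()
    recon, target = ae(torch.randn(4, 400))
    print("AE:", recon.shape)  # torch.Size([4, 400])

    full = FullSubSpecCAE()
    recon, target = full(torch.randn(4, 7, 20, 80))
    print("FullSubSpecCAE:", recon.shape)  # torch.Size([4, 7, 20, 80])

    glob = GlobalCAE()
    glob.init_weights()
    recon, target = glob(torch.randn(4, 80, 80))
    print("GlobalCAE:", recon.shape)  # torch.Size([4, 1, 80, 80])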