import torch
import pytorch_lightning as pl
from torch.nn import Module

from abc import ABC, abstractmethod


######################
# Abstract base network class following the PyTorch Lightning interface
class LightningModule(pl.LightningModule, ABC):
    """Abstract base class for networks written against the PyTorch Lightning hooks.

    Subclasses must override ``forward``, ``training_step`` and
    ``configure_optimizers`` (all abstract) and should override
    ``tng_dataloader``, which raises ``NotImplementedError`` unless replaced.
    ``validation_step``, ``validation_end``, ``val_dataloader`` and
    ``test_dataloader`` are optional and return ``None`` unless overridden.

    NOTE(review): ``tng_dataloader``, ``validation_end`` and ``pl.data_loader``
    are early-Lightning hook names; later releases renamed or removed them.
    Confirm against the Lightning version this project pins.
    """

    def __init__(self):
        super().__init__()

    # ----- required hooks ---------------------------------------------------

    @abstractmethod
    def forward(self, x):
        """Run the network on ``x``. Must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def training_step(self, batch, batch_nb):
        """Training hook for ``batch`` (batch index ``batch_nb``). Must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def configure_optimizers(self):
        """Lightning hook for optimizer setup. Must be overridden."""
        raise NotImplementedError

    @pl.data_loader
    def tng_dataloader(self):
        """Training data loader hook. Raises until a subclass overrides it.

        Example override::

            return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                    transform=transforms.ToTensor()),
                              batch_size=32)
        """
        raise NotImplementedError

    # ----- optional hooks (no-ops by default) -------------------------------

    def validation_step(self, batch, batch_nb):
        """Optional validation hook for ``batch``; does nothing by default."""

    def validation_end(self, outputs):
        """Optional hook taking ``outputs``; does nothing by default."""

    @pl.data_loader
    def val_dataloader(self):
        """Optional validation data loader hook; returns ``None`` by default."""

    @pl.data_loader
    def test_dataloader(self):
        """Optional test data loader hook; returns ``None`` by default."""


#######################
# Utility Modules
class TimeDistributed(Module):
    """Apply ``module`` independently at every position of a sequence.

    All leading axes of the input are flattened into a single axis, ``module``
    is run once on the resulting ``(N, input_size)`` tensor, and the input's
    leading axes are restored on the output. Inputs of rank <= 2 are handed to
    ``module`` unchanged.

    Args:
        module: module applied to the flattened 2-D tensor.
        batch_first: kept for backward compatibility. Because the output is
            reshaped to the input's own leading shape, the result is correct
            for both ``(batch, time, ...)`` and ``(time, batch, ...)`` layouts.
    """

    def __init__(self, module, batch_first=True):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if x.dim() <= 2:
            return self.module(x)

        leading_shape = x.shape[:-1]
        # Squash all leading axes into one: (prod(leading_shape), input_size)
        flat_out = self.module(x.reshape(-1, x.size(-1)))

        # Restore the exact leading shape rather than inferring one axis with -1:
        # an inferred axis fails on empty batches (view(0, -1, k) is ambiguous)
        # and merges/scrambles axes for inputs of rank > 3.
        return flat_out.reshape(*leading_shape, *flat_out.shape[1:])


class Repeater(Module):
    """Repeat a tensor along a new second-to-last axis.

    A singleton axis is inserted at position -2 and the tensor is then
    broadcast to ``shape`` with ``Tensor.expand``. Per ``expand`` semantics,
    an entry of -1 keeps that axis' existing size and the result is a view
    (no data is copied).

    Args:
        shape: target shape passed to ``Tensor.expand``.
    """

    def __init__(self, shape):
        super(Repeater, self).__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor):
        # Out-of-place unsqueeze: the previous unsqueeze_ changed the shape of
        # the caller's tensor as a side effect.
        return x.unsqueeze(-2).expand(self.shape)

class RNNOutputFilter(Module):
    """Pick one element of an RNN-style ``(outputs, hidden)`` pair.

    Args:
        return_output: when True select ``outputs``, otherwise ``hidden``.
        only_last: when True, index the selected tensor with ``[:, -1, :]``,
            i.e. the last position along axis 1 (a batch-first layout is
            presumably assumed — TODO confirm with callers).

    NOTE(review): with ``return_output=False`` and ``only_last=True`` the slice
    is applied to ``hidden``. For ``torch.nn.RNN``/``GRU`` that tensor is
    ``(num_layers * directions, batch, hidden_size)``, so axis 1 is the batch;
    for ``torch.nn.LSTM`` ``hidden`` is an ``(h, c)`` tuple that cannot be
    sliced this way. Confirm this combination is intended.
    """

    def __init__(self, return_output=True, only_last=False):
        super().__init__()
        self.return_output = return_output
        self.only_last = only_last

    def forward(self, x: tuple):
        # Unpacking enforces exactly two items.
        rnn_outputs, rnn_hidden = x
        picked = rnn_outputs if self.return_output else rnn_hidden
        if self.only_last:
            return picked[:, -1, :]
        return picked


if __name__ == '__main__':
    # This file only defines reusable classes; direct execution is refused on
    # purpose by raising immediately.
    raise PermissionError('Get out of here - never run this module')