from abc import ABC

from torch import Tensor
from torch.nn.functional import mse_loss
from torch.optim import Adam

from .modules import AbstractNeuralNetwork, Decoder, Encoder, Repeater


#######################
# Basic autoencoder (AE) implementation
class AutoEncoder(AbstractNeuralNetwork, ABC):

    def __init__(self, latent_dim: int = 0, features: int = 0, use_norm=True,
                 train_on_predictions=False, **kwargs):
        assert latent_dim and features, 'latent_dim and features must be non-zero'
        super().__init__()
        self.train_on_predictions = train_on_predictions
        self.latent_dim = latent_dim
        self.features = features
        self.encoder = Encoder(self.latent_dim, use_norm=use_norm)
        self.decoder = Decoder(self.latent_dim, self.features, use_norm=use_norm)

    def forward(self, batch: Tensor):
        # Encoder: map the input sequence (batch, timesteps, features)
        # to a latent code z
        z = self.encoder(batch)
        # Decoder: repeat the latent code along the time dimension so the
        # decoder receives one copy of z per timestep
        # (see the note on Repeater below this class)
        z_repeated = Repeater((batch.shape[0], batch.shape[1], -1))(z)
        x_hat = self.decoder(z_repeated)
        return z, x_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        # forward returns (z, x_hat); only the reconstruction is needed here
        _, x_hat = self.forward(x)
        # mse_loss takes (input, target): reconstruct the targets y if
        # train_on_predictions is set, otherwise the inputs x
        loss = mse_loss(x_hat, y) if self.train_on_predictions else mse_loss(x_hat, x)
        return {'loss': loss}

    def configure_optimizers(self):
        return [Adam(self.parameters(), lr=0.02)]
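

# Note: Repeater (imported from .modules) is assumed here to broadcast the
# per-sequence latent code z from (batch, latent_dim) to
# (batch, timesteps, latent_dim), roughly equivalent to
#   z.unsqueeze(1).expand(batch_size, timesteps, -1)
# so the decoder receives one copy of the latent code per timestep.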

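# Usage sketch (hypothetical shapes, assuming Encoder consumes
# (batch, timesteps, features) tensors and AbstractNeuralNetwork is a
# pytorch_lightning.LightningModule):
#
#   import torch
#   model = AutoEncoder(latent_dim=16, features=8)
#   x = torch.randn(32, 100, 8)   # (batch, timesteps, features)
#   z, x_hat = model(x)           # latent code and reconstruction
#   # with Lightning, training would then roughly be:
#   #   trainer = pl.Trainer(max_epochs=10)
#   #   trainer.fit(model, train_dataloader)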

if __name__ == '__main__':
    raise PermissionError('Get out of here - never run this module')