from abc import ABC

from torch.nn import Module, Sequential, Linear, GRU

from data.dataset import DataContainer
from .modules import *


#######################
# Basic AE-Implementation

class BasicAE(Module, ABC):

    def __init__(self, dataParams, **kwargs):
        super(BasicAE, self).__init__()
        self.dataParams = dataParams
        self.latent_dim = kwargs.get('latent_dim', 2)
        self.encoder = self._build_encoder()
        self.decoder = self._build_decoder()

    def _build_encoder(self):
        # Per-timestep feature extractor; the input width of 6 is hard-coded here.
        encoder = Sequential()
        encoder.add_module('EncoderLinear_1', Linear(6, 10, bias=True))
        encoder.add_module('EncoderLinear_2', Linear(10, 10, bias=True))

        # Apply the linear stack to every timestep, then compress the sequence
        # with a GRU. batch_first=True matches the (batch, size, features)
        # layout assumed in forward().
        gru = Sequential()
        gru.add_module('Encoder', TimeDistributed(encoder))
        gru.add_module('GRU', GRU(10, self.latent_dim, batch_first=True))
        return gru

    def _build_decoder(self):
        # Per-timestep reconstruction head mapping back to the input features.
        decoder = Sequential()
        decoder.add_module('DecoderLinear_1', Linear(10, 10, bias=True))
        decoder.add_module('DecoderLinear_2', Linear(10, self.dataParams['features'], bias=True))

        gru = Sequential()
        # The leading 1 is a placeholder batch size; forward() overwrites it
        # with the actual batch size of the incoming tensor.
        gru.add_module('Repeater', Repeater((1, self.dataParams['size'], -1)))
        gru.add_module('GRU', GRU(self.latent_dim, 10, batch_first=True))
        gru.add_module('GRU Filter', RNNOutputFilter())
        gru.add_module('Decoder', TimeDistributed(decoder))
        return gru

    def forward(self, batch):
        batch_size = batch.shape[0]
        # Propagate the current batch size to the Repeater before decoding.
        self.decoder.Repeater.shape = (batch_size, ) + self.decoder.Repeater.shape[-2:]

        # outputs, hidden -> (Batch, Timesteps aka. Size, Features / Latent Dim Size)
        outputs, _ = self.encoder(batch)
        # Use the encoder output at the last timestep as the latent code.
        z = outputs[:, -1]

        x_hat = self.decoder(z)
        return z, x_hat


if __name__ == '__main__':
    raise PermissionError('Get out of here - never run this module')
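
# --------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): a minimal example of
# how BasicAE might be driven. It assumes `dataParams` carries the 'features'
# and 'size' keys used above, that batches arrive as (batch, size, features)
# tensors, and that TimeDistributed, Repeater and RNNOutputFilter from
# .modules behave accordingly; the concrete numbers are hypothetical.
def _usage_sketch():
    import torch
    params = {'features': 6, 'size': 100}            # hypothetical data parameters
    model = BasicAE(params, latent_dim=2)            # two-dimensional latent space
    batch = torch.randn(32, params['size'], params['features'])
    z, x_hat = model(batch)                          # latent codes and reconstruction
    # Expected (roughly): z -> (32, 2), x_hat -> (32, 100, 6)
    return z, x_hat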