Lightning integration basic ae, dataloaders and dataset

Author: Si11ium
Date: 2019-08-14 16:42:33 +02:00
Parent: fbc776c359
Commit: fbe0600e24
9 changed files with 637 additions and 1 deletion

networks/basic_ae.py (new file, 52 lines)

@@ -0,0 +1,52 @@
from torch.nn import Sequential, Linear, GRU

from data.dataset import DataContainer
from .modules import *


#######################
# Basic AE-Implementation
class BasicAE(Module, ABC):

    def __init__(self, dataParams, **kwargs):
        super(BasicAE, self).__init__()
        self.dataParams = dataParams
        self.latent_dim = kwargs.get('latent_dim', 2)
        self.encoder = self._build_encoder()
        self.decoder = self._build_decoder()

    def _build_encoder(self):
        encoder = Sequential()
        encoder.add_module(f'EncoderLinear_{1}', Linear(6, 10, bias=True))
        encoder.add_module(f'EncoderLinear_{2}', Linear(10, 10, bias=True))
        gru = Sequential()
        gru.add_module('Encoder', TimeDistributed(encoder))
        gru.add_module('GRU', GRU(10, self.latent_dim))
        return gru

    def _build_decoder(self):
        decoder = Sequential()
        decoder.add_module(f'DecoderLinear_{1}', Linear(10, 10, bias=True))
        decoder.add_module(f'DecoderLinear_{2}', Linear(10, self.dataParams['features'], bias=True))

        gru = Sequential()
        # There needs to be a proper batch dimension here; the leading 1 is a placeholder that forward() overwrites with the actual batch size.
        gru.add_module('Repeater', Repeater((1, self.dataParams['size'], -1)))
        gru.add_module('GRU', GRU(self.latent_dim, 10))
        gru.add_module('GRU Filter', RNNOutputFilter())
        gru.add_module('Decoder', TimeDistributed(decoder))
        return gru

    def forward(self, batch):
        batch_size = batch.shape[0]
        self.decoder.Repeater.shape = (batch_size, ) + self.decoder.Repeater.shape[-2:]
        # outputs, hidden (Batch, Timesteps aka. Size, Features / Latent Dim Size)
        outputs, _ = self.encoder(batch)
        z = outputs[:, -1]
        x_hat = self.decoder(z)
        return z, x_hat


if __name__ == '__main__':
    raise PermissionError('Get out of here - never run this module')
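A minimal construction-and-encoding sketch (illustrative only, not part of the commit; it assumes the networks package layout from this diff, a hypothetical dataParams with 'size' and 'features' keys, and inputs with 6 features per timestep to match the encoder's first Linear layer):

import torch

from networks.basic_ae import BasicAE

dataParams = dict(size=100, features=6)    # hypothetical values for illustration
model = BasicAE(dataParams, latent_dim=2)
batch = torch.randn(16, dataParams['size'], dataParams['features'])   # (batch, timesteps, features)
outputs, _ = model.encoder(batch)          # the encoder Sequential ends in a GRU, which returns (outputs, hidden)
z = outputs[:, -1]                         # latent code per sample, mirroring what forward() does
print(z.shape)                             # torch.Size([16, 2])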

networks/modules.py (new file, 103 lines)

@@ -0,0 +1,103 @@
import torch
import pytorch_lightning as pl
from torch.nn import Module

from abc import ABC, abstractmethod


######################
# Abstract Network class following the Lightning Syntax
class LightningModule(pl.LightningModule, ABC):

    def __init__(self):
        super(LightningModule, self).__init__()

    @abstractmethod
    def forward(self, x):
        raise NotImplementedError

    @abstractmethod
    def training_step(self, batch, batch_nb):
        # REQUIRED
        raise NotImplementedError

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        pass

    def validation_end(self, outputs):
        # OPTIONAL
        pass

    @abstractmethod
    def configure_optimizers(self):
        # REQUIRED
        raise NotImplementedError

    @pl.data_loader
    def tng_dataloader(self):
        # REQUIRED
        raise NotImplementedError
        # return DataLoader(MNIST(os.getcwd(), train=True, download=True,
        #                         transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        pass

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        pass


#######################
# Utility Modules
class TimeDistributed(Module):
    def __init__(self, module, batch_first=True):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if len(x.size()) <= 2:
            return self.module(x)
        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(-1, x.size(-1))  # (samples * timesteps, input_size)
        y = self.module(x_reshape)
        # We have to reshape Y
        if self.batch_first:
            y = y.contiguous().view(x.size(0), -1, y.size(-1))  # (samples, timesteps, output_size)
        else:
            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)
        return y


class Repeater(Module):
    def __init__(self, shape):
        super(Repeater, self).__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor):
        x.unsqueeze_(-2)
        return x.expand(self.shape)


class RNNOutputFilter(Module):
    def __init__(self, return_output=True):
        super(RNNOutputFilter, self).__init__()
        self.return_output = return_output

    def forward(self, x: tuple):
        outputs, hidden = x
        return outputs if self.return_output else hidden


if __name__ == '__main__':
    raise PermissionError('Get out of here - never run this module')
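A small shape check for the utility modules above (illustrative sizes; assumes the networks.modules path from this diff):

import torch
from torch.nn import Linear, GRU

from networks.modules import TimeDistributed, Repeater, RNNOutputFilter

td = TimeDistributed(Linear(6, 10))
x = torch.randn(4, 25, 6)                  # (batch, timesteps, features)
print(td(x).shape)                         # torch.Size([4, 25, 10]) - Linear applied per timestep

rep = Repeater((4, 25, -1))
z = torch.randn(4, 2)                      # one latent vector per sample
z_rep = rep(z)                             # latent repeated across all timesteps
print(z_rep.shape)                         # torch.Size([4, 25, 2])

flt = RNNOutputFilter()                    # keeps the outputs, drops the hidden state
outs = flt(GRU(2, 10, batch_first=True)(z_rep.contiguous()))
print(outs.shape)                          # torch.Size([4, 25, 10])

And a sketch of how a concrete network might fill in the abstract LightningModule interface; DemoNet, its layer sizes, and the random TensorDataset are hypothetical, and the hooks simply mirror the pre-1.0 Lightning style (pl.data_loader, tng_dataloader, dict return from training_step) already used above, so details may differ between Lightning versions:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.nn import Linear
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

from networks.modules import LightningModule


class DemoNet(LightningModule):
    # Hypothetical subclass, only to illustrate which hooks a concrete net provides.

    def __init__(self):
        super(DemoNet, self).__init__()
        self.layer = Linear(6, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        return {'loss': F.mse_loss(self(x), y)}   # loss returned in a dict, early-Lightning style

    def configure_optimizers(self):
        return [Adam(self.parameters(), lr=1e-3)]

    @pl.data_loader
    def tng_dataloader(self):
        data = TensorDataset(torch.randn(64, 6), torch.randn(64, 1))   # random stand-in data
        return DataLoader(data, batch_size=32)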