Lightning integration basic ae, dataloaders and dataset
This commit is contained in:
52
networks/basic_ae.py
Normal file
52
networks/basic_ae.py
Normal file
@ -0,0 +1,52 @@
|
||||
from torch.nn import Sequential, Linear, GRU
|
||||
from data.dataset import DataContainer
|
||||
|
||||
from .modules import *
|
||||
|
||||
|
||||
#######################
|
||||
# Basic AE-Implementation
|
||||
class BasicAE(Module, ABC):
|
||||
|
||||
def __init__(self, dataParams, **kwargs):
|
||||
super(BasicAE, self).__init__()
|
||||
self.dataParams = dataParams
|
||||
self.latent_dim = kwargs.get('latent_dim', 2)
|
||||
self.encoder = self._build_encoder()
|
||||
self.decoder = self._build_decoder()
|
||||
|
||||
|
||||
def _build_encoder(self):
|
||||
encoder = Sequential()
|
||||
encoder.add_module(f'EncoderLinear_{1}', Linear(6, 10, bias=True))
|
||||
encoder.add_module(f'EncoderLinear_{2}', Linear(10, 10, bias=True))
|
||||
gru = Sequential()
|
||||
gru.add_module('Encoder', TimeDistributed(encoder))
|
||||
gru.add_module('GRU', GRU(10, self.latent_dim))
|
||||
return gru
|
||||
|
||||
def _build_decoder(self):
|
||||
decoder = Sequential()
|
||||
decoder.add_module(f'DecoderLinear_{1}', Linear(10, 10, bias=True))
|
||||
decoder.add_module(f'DecoderLinear_{2}', Linear(10, self.dataParams['features'], bias=True))
|
||||
|
||||
gru = Sequential()
|
||||
# There needs to be ab propper bat
|
||||
gru.add_module('Repeater', Repeater((1, self.dataParams['size'], -1)))
|
||||
gru.add_module('GRU', GRU(self.latent_dim, 10))
|
||||
gru.add_module('GRU Filter', RNNOutputFilter())
|
||||
gru.add_module('Decoder', TimeDistributed(decoder))
|
||||
return gru
|
||||
|
||||
def forward(self, batch):
|
||||
batch_size = batch.shape[0]
|
||||
self.decoder.Repeater.shape = (batch_size, ) + self.decoder.Repeater.shape[-2:]
|
||||
# outputs, hidden (Batch, Timesteps aka. Size, Features / Latent Dim Size)
|
||||
outputs, _ = self.encoder(batch)
|
||||
z = outputs[:, -1]
|
||||
x_hat = self.decoder(z)
|
||||
return z, x_hat
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise PermissionError('Get out of here - never run this module')
|
103
networks/modules.py
Normal file
103
networks/modules.py
Normal file
@ -0,0 +1,103 @@
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from torch.nn import Module
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
######################
|
||||
# Abstract Network class following the Lightning Syntax
|
||||
class LightningModule(pl.LightningModule, ABC):
|
||||
|
||||
def __init__(self):
|
||||
super(LightningModule, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self, batch, batch_nb):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
|
||||
def validation_step(self, batch, batch_nb):
|
||||
# OPTIONAL
|
||||
pass
|
||||
|
||||
def validation_end(self, outputs):
|
||||
# OPTIONAL
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def configure_optimizers(self):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
|
||||
@pl.data_loader
|
||||
def tng_dataloader(self):
|
||||
# REQUIRED
|
||||
raise NotImplementedError
|
||||
# return DataLoader(MNIST(os.getcwd(), train=True, download=True,
|
||||
# transform=transforms.ToTensor()), batch_size=32)
|
||||
|
||||
@pl.data_loader
|
||||
def val_dataloader(self):
|
||||
# OPTIONAL
|
||||
pass
|
||||
|
||||
@pl.data_loader
|
||||
def test_dataloader(self):
|
||||
# OPTIONAL
|
||||
pass
|
||||
|
||||
|
||||
#######################
|
||||
# Utility Modules
|
||||
class TimeDistributed(Module):
|
||||
def __init__(self, module, batch_first=True):
|
||||
super(TimeDistributed, self).__init__()
|
||||
self.module = module
|
||||
self.batch_first = batch_first
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
if len(x.size()) <= 2:
|
||||
return self.module(x)
|
||||
|
||||
# Squash samples and timesteps into a single axis
|
||||
x_reshape = x.contiguous().view(-1, x.size(-1)) # (samples * timesteps, input_size)
|
||||
|
||||
y = self.module(x_reshape)
|
||||
|
||||
# We have to reshape Y
|
||||
if self.batch_first:
|
||||
y = y.contiguous().view(x.size(0), -1, y.size(-1)) # (samples, timesteps, output_size)
|
||||
else:
|
||||
y = y.view(-1, x.size(1), y.size(-1)) # (timesteps, samples, output_size)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class Repeater(Module):
|
||||
def __init__(self, shape):
|
||||
super(Repeater, self).__init__()
|
||||
self.shape = shape
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x.unsqueeze_(-2)
|
||||
return x.expand(self.shape)
|
||||
|
||||
class RNNOutputFilter(Module):
|
||||
|
||||
def __init__(self, return_output=True):
|
||||
super(RNNOutputFilter, self).__init__()
|
||||
self.return_output = return_output
|
||||
|
||||
def forward(self, x: tuple):
|
||||
outputs, hidden = x
|
||||
return outputs if self.return_output else hidden
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise PermissionError('Get out of here - never run this module')
|
Reference in New Issue
Block a user