Lightning integration: basic AE, dataloaders and dataset
103  networks/modules.py  Normal file
@@ -0,0 +1,103 @@
import torch
import pytorch_lightning as pl
from torch.nn import Module

from abc import ABC, abstractmethod


######################
# Abstract network class following the Lightning syntax
class LightningModule(pl.LightningModule, ABC):

    def __init__(self):
        super(LightningModule, self).__init__()

    @abstractmethod
    def forward(self, x):
        raise NotImplementedError

    @abstractmethod
    def training_step(self, batch, batch_nb):
        # REQUIRED
        raise NotImplementedError

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        pass

    def validation_end(self, outputs):
        # OPTIONAL
        pass

    @abstractmethod
    def configure_optimizers(self):
        # REQUIRED
        raise NotImplementedError

    @pl.data_loader
    def tng_dataloader(self):
        # REQUIRED
        raise NotImplementedError
        # return DataLoader(MNIST(os.getcwd(), train=True, download=True,
        #                         transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        pass

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        pass
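
# ----------------------------------------------------------------------
# Illustrative sketch only: a minimal, hypothetical concrete subclass of
# the abstract class above. `DummyNet`, the layer sizes, and the MNIST
# loader are placeholder assumptions, not definitions from this repo,
# and the loader would additionally need the os, DataLoader, MNIST and
# transforms imports hinted at in tng_dataloader above.
#
# class DummyNet(LightningModule):
#     def __init__(self):
#         super(DummyNet, self).__init__()
#         self.layer = torch.nn.Linear(28 * 28, 10)
#
#     def forward(self, x):
#         return self.layer(x.view(x.size(0), -1))
#
#     def training_step(self, batch, batch_nb):
#         x, y = batch
#         loss = torch.nn.functional.cross_entropy(self.forward(x), y)
#         return {'loss': loss}
#
#     def configure_optimizers(self):
#         return [torch.optim.Adam(self.parameters(), lr=1e-3)]
#
#     @pl.data_loader
#     def tng_dataloader(self):
#         return DataLoader(MNIST(os.getcwd(), train=True, download=True,
#                                 transform=transforms.ToTensor()), batch_size=32)
# ----------------------------------------------------------------------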


#######################
# Utility Modules
class TimeDistributed(Module):
    # Applies `module` to every timestep by folding the time axis into the
    # batch axis, running the module once, and unfolding the result.
    def __init__(self, module, batch_first=True):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if len(x.size()) <= 2:
            # No time axis present; apply the module directly.
            return self.module(x)

        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(-1, x.size(-1))  # (samples * timesteps, input_size)

        y = self.module(x_reshape)

        # Reshape y back into a sequence
        if self.batch_first:
            y = y.contiguous().view(x.size(0), -1, y.size(-1))  # (samples, timesteps, output_size)
        else:
            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)

        return y
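
# Illustrative usage sketch (hypothetical wrapped module and shapes):
#   td = TimeDistributed(torch.nn.Linear(32, 16))
#   out = td(torch.randn(8, 10, 32))  # -> (8, 10, 16), Linear applied per timestep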


class Repeater(Module):
    # Inserts a timestep axis and broadcasts the input up to `shape`.
    def __init__(self, shape):
        super(Repeater, self).__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor):
        # Unsqueeze out of place so the caller's tensor is not mutated.
        x = x.unsqueeze(-2)
        return x.expand(self.shape)
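
# Illustrative usage sketch (hypothetical shapes):
#   rep = Repeater((4, 7, 16))
#   out = rep(torch.randn(4, 16))  # (4, 16) -> (4, 1, 16) -> expanded to (4, 7, 16)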


class RNNOutputFilter(Module):
    # Unpacks the (outputs, hidden) tuple returned by torch RNN layers and
    # keeps only one of the two, e.g. so an RNN composes inside nn.Sequential.
    def __init__(self, return_output=True):
        super(RNNOutputFilter, self).__init__()
        self.return_output = return_output

    def forward(self, x: tuple):
        outputs, hidden = x
        return outputs if self.return_output else hidden
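
# Illustrative usage sketch (hypothetical layer and shapes):
#   gru = torch.nn.GRU(input_size=16, hidden_size=32, batch_first=True)
#   net = torch.nn.Sequential(gru, RNNOutputFilter(return_output=True))
#   out = net(torch.randn(8, 10, 16))  # -> (8, 10, 32); the hidden state is dropped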


if __name__ == '__main__':
    raise PermissionError('Get out of here - never run this module')