from abc import ABC, abstractmethod

import pytorch_lightning as pl
import torch
from torch.nn import Module


######################
# Abstract Network class following the Lightning Syntax
class LightningModule(pl.LightningModule, ABC):
    """Abstract base for networks written against the Lightning interface.

    Subclasses must implement ``forward``, ``training_step`` and
    ``configure_optimizers``. ``tng_dataloader`` is not declared abstract,
    but its default implementation raises ``NotImplementedError`` and so has
    to be overridden as well. The validation/test hooks are optional no-ops.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        """Compute the network output for ``x``; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def training_step(self, batch, batch_nb):
        """REQUIRED: process one training batch; must be overridden."""
        raise NotImplementedError

    def validation_step(self, batch, batch_nb):
        """OPTIONAL: process one validation batch; does nothing by default."""
        pass

    def validation_end(self, outputs):
        """OPTIONAL: consume the validation outputs; does nothing by default."""
        pass

    @abstractmethod
    def configure_optimizers(self):
        """REQUIRED: provide the optimizer(s); must be overridden."""
        raise NotImplementedError

    @pl.data_loader
    def tng_dataloader(self):
        """REQUIRED: provide the training data loader.

        Example implementation:
            return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                    transform=transforms.ToTensor()), batch_size=32)
        """
        raise NotImplementedError

    @pl.data_loader
    def val_dataloader(self):
        """OPTIONAL: provide the validation data loader; None by default."""
        pass

    @pl.data_loader
    def test_dataloader(self):
        """OPTIONAL: provide the test data loader; None by default."""
        pass


#######################
# Utility Modules
class TimeDistributed(Module):
    """Apply ``module`` to an input whose leading axes are folded into one.

    Inputs with at most two dimensions are passed to ``module`` unchanged.
    Otherwise all axes except the last are flattened into a single axis, the
    module is applied, and the result is reshaped back to three dimensions.

    Args:
        module: the wrapped module, applied to the flattened input.
        batch_first: if True the output keeps ``x.size(0)`` as its first
            axis; otherwise it keeps ``x.size(1)`` as its second axis.
    """

    def __init__(self, module, batch_first=True):
        super().__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if x.dim() <= 2:
            return self.module(x)

        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(-1, x.size(-1))  # (samples * timesteps, input_size)

        y = self.module(x_reshape)

        # Restore the leading axes. ``contiguous()`` is applied in both
        # branches so that a non-contiguous module output cannot break ``view``.
        if self.batch_first:
            y = y.contiguous().view(x.size(0), -1, y.size(-1))  # (samples, timesteps, output_size)
        else:
            y = y.contiguous().view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)
        return y


class Repeater(Module):
    """Insert a new second-to-last axis and expand the result to ``shape``.

    Args:
        shape: target shape handed to ``Tensor.expand``.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor):
        # Out-of-place ``unsqueeze``: the in-place variant would change the
        # caller's tensor, so calling this module twice on the same input
        # would keep adding axes.
        return x.unsqueeze(-2).expand(self.shape)


class RNNOutputFilter(Module):
    """Select one element of an ``(outputs, hidden)`` pair.

    Args:
        return_output: if True pick ``outputs``, otherwise pick ``hidden``.
        only_last: if True return only the last entry along dimension 1 of
            the picked element (``[:, -1, :]``), which therefore has to be
            indexable with three indices.
    """

    def __init__(self, return_output=True, only_last=False):
        super().__init__()
        self.only_last = only_last
        self.return_output = return_output

    def forward(self, x: tuple):
        outputs, hidden = x
        out = outputs if self.return_output else hidden
        return out if not self.only_last else out[:, -1, :]


if __name__ == '__main__':
    raise PermissionError('Get out of here - never run this module')