import torch
import pytorch_lightning as pl
from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU
from abc import ABC, abstractmethod


#######################
# Abstract NN Class
class AbstractNeuralNetwork(Module):

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self):
        super(AbstractNeuralNetwork, self).__init__()

    def forward(self, batch):
        raise NotImplementedError


######################
# Abstract network class following the Lightning syntax
class LightningModule(pl.LightningModule, ABC):

    def __init__(self):
        super(LightningModule, self).__init__()

    @abstractmethod
    def forward(self, x):
        raise NotImplementedError

    @abstractmethod
    def training_step(self, batch, batch_nb):
        # REQUIRED
        raise NotImplementedError

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        pass

    def validation_end(self, outputs):
        # OPTIONAL
        pass

    @abstractmethod
    def configure_optimizers(self):
        # REQUIRED
        raise NotImplementedError

    @pl.data_loader
    def tng_dataloader(self):
        # REQUIRED
        raise NotImplementedError
        # e.g.:
        # return DataLoader(MNIST(os.getcwd(), train=True, download=True,
        #                         transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        pass

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        pass


#######################
# Utility Modules
class TimeDistributed(Module):
    """Applies the wrapped module to every timestep of a sequence input."""

    def __init__(self, module, batch_first=True):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if len(x.size()) <= 2:
            return self.module(x)
        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(-1, x.size(-1))  # (samples * timesteps, input_size)
        y = self.module(x_reshape)
        # Reshape y back into a sequence
        if self.batch_first:
            y = y.contiguous().view(x.size(0), -1, y.size(-1))  # (samples, timesteps, output_size)
        else:
            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)
        return y


class Repeater(Module):
    """Repeats its input along a new timestep axis, expanding to the given shape."""

    def __init__(self, shape):
        super(Repeater, self).__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor):
        x = x.unsqueeze(-2)
        return x.expand(self.shape)


class RNNOutputFilter(Module):
    """Unpacks the (outputs, hidden) tuple returned by RNN layers,
    optionally keeping only the last timestep."""

    def __init__(self, return_output=True, only_last=False):
        super(RNNOutputFilter, self).__init__()
        self.only_last = only_last
        self.return_output = return_output

    def forward(self, x: tuple):
        outputs, hidden = x
        out = outputs if self.return_output else hidden
        return out if not self.only_last else out[:, -1, :]


class AvgDimPool(Module):
    """Mean-pools over the second-to-last (timestep) dimension."""

    def __init__(self):
        super(AvgDimPool, self).__init__()

    def forward(self, x):
        return x.mean(-2)


#######################
# Network Modules
# Generators, Decoders, Encoders, Discriminators
class Discriminator(Module):
    def __init__(self, latent_dim, features, dropout=0.0, activation=ReLU):
        super(Discriminator, self).__init__()
        self.features = features
        self.latent_dim = latent_dim
        self.l1 = Linear(self.latent_dim, self.features * 10)
        self.l2 = Linear(self.features * 10, self.features * 20)
        self.lout = Linear(self.features * 20, 1)
        self.dropout = Dropout(dropout)
        self.activation = activation()
        self.sigmoid = Sigmoid()

    def forward(self, x, **kwargs):
        tensor = self.l1(x)
        tensor = self.dropout(self.activation(tensor))
        tensor = self.l2(tensor)
        tensor = self.dropout(self.activation(tensor))
        tensor = self.lout(tensor)
        tensor = self.sigmoid(tensor)
        return tensor
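# Illustrative sketch (assumed sizes, not exercised anywhere in this module):
# scoring a batch of latent codes with the Discriminator. latent_dim=2 and
# features=3 are arbitrary demonstration values, not defaults.
#
#   d = Discriminator(latent_dim=2, features=3)
#   scores = d(torch.randn(16, 2))  # -> (16, 1), each score squashed into (0, 1)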
class DecoderLinearStack(Module):

    def __init__(self, out_shape):
        super(DecoderLinearStack, self).__init__()
        self.l1 = Linear(10, 100, bias=True)
        self.l2 = Linear(100, out_shape, bias=True)
        self.activation = ReLU()
        self.activation_out = Tanh()

    def forward(self, x):
        tensor = self.l1(x)
        tensor = self.activation(tensor)
        tensor = self.l2(tensor)
        tensor = self.activation_out(tensor)
        return tensor


class EncoderLinearStack(Module):

    def __init__(self):
        super(EncoderLinearStack, self).__init__()
        # FIXME: get the hard-coded layer sizes out of here
        self.l1 = Linear(6, 100, bias=True)
        self.l2 = Linear(100, 10, bias=True)
        self.activation = ReLU()

    def forward(self, x):
        tensor = self.l1(x)
        tensor = self.activation(tensor)
        tensor = self.l2(tensor)
        tensor = self.activation(tensor)
        return tensor


class Encoder(Module):
    def __init__(self, lat_dim, variational=False):
        super(Encoder, self).__init__()
        self.lat_dim = lat_dim
        self.variational = variational
        self.l_stack = TimeDistributed(EncoderLinearStack())
        self.gru = GRU(10, 10, batch_first=True)
        self.filter = RNNOutputFilter(only_last=True)
        if variational:
            self.mu = Linear(10, self.lat_dim)
            self.logvar = Linear(10, self.lat_dim)
        else:
            self.lat_dim_layer = Linear(10, self.lat_dim)

    def forward(self, x):
        tensor = self.l_stack(x)
        tensor = self.gru(tensor)
        tensor = self.filter(tensor)
        if self.variational:
            tensor = self.mu(tensor), self.logvar(tensor)
        else:
            tensor = self.lat_dim_layer(tensor)
        return tensor


class PoolingEncoder(Module):
    def __init__(self, lat_dim, variational=False):
        super(PoolingEncoder, self).__init__()
        self.lat_dim = lat_dim
        self.variational = variational
        self.p = AvgDimPool()
        self.l = EncoderLinearStack()
        if variational:
            self.mu = Linear(10, self.lat_dim)
            self.logvar = Linear(10, self.lat_dim)
        else:
            self.lat_dim_layer = Linear(10, self.lat_dim)

    def forward(self, x):
        tensor = self.p(x)
        tensor = self.l(tensor)
        if self.variational:
            tensor = self.mu(tensor), self.logvar(tensor)
        else:
            tensor = self.lat_dim_layer(tensor)
        return tensor


class Decoder(Module):
    def __init__(self, latent_dim, *args, variational=False):
        super(Decoder, self).__init__()
        self.variational = variational
        self.g = GRU(latent_dim, 10, batch_first=True)
        self.filter = RNNOutputFilter()
        self.l_stack = TimeDistributed(DecoderLinearStack(*args))

    def forward(self, x):
        tensor = self.g(x)
        tensor = self.filter(tensor)
        tensor = self.l_stack(tensor)
        return tensor


if __name__ == '__main__':
    raise PermissionError('Get out of here - never run this module')
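# Illustrative sketch (never executed; the guard above forbids running this
# module): wiring Encoder, Repeater and Decoder into a sequence-autoencoder
# round trip. Batch size 4, 25 timesteps and lat_dim=2 are assumptions for
# demonstration; feature size 6 and hidden size 10 are fixed by EncoderLinearStack.
#
#   enc = Encoder(lat_dim=2)
#   dec = Decoder(2, 6)            # latent_dim=2, out_shape=6
#   rep = Repeater((4, 25, 2))     # re-expand each latent code over the timesteps
#   x = torch.randn(4, 25, 6)      # (batch, timesteps, features)
#   z = enc(x)                     # -> (4, 2)
#   x_hat = dec(rep(z))            # -> (4, 25, 6), tanh-squashed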