import os
from abc import ABC, abstractmethod
from operator import mul
from functools import reduce

import torch
from torch import randn
import pytorch_lightning as pl
from torch.nn import Module, Linear, ReLU, Sigmoid, Dropout, GRU, Tanh

#######################
# Abstract NN Class & Lightning Module
from torch.utils.data import DataLoader
from dataset import DataContainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class LightningModuleOverrides:

    @property
    def name(self):
        return self.__class__.__name__

    @pl.data_loader
    def train_dataloader(self):
        # The concrete LightningModule is expected to define self.size and self.step.
        num_workers = 0  # os.cpu_count() // 2
        return DataLoader(DataContainer(os.path.join('data', 'training'), self.size, self.step),
                          shuffle=True, batch_size=10000, num_workers=num_workers)


class AbstractNeuralNetwork(Module):

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self):
        super(AbstractNeuralNetwork, self).__init__()

    def forward(self, batch):
        pass


#######################
# Utility Modules
class TimeDistributed(Module):
    """Applies a module to every timestep of a (batch, time, features) tensor."""

    def __init__(self, module, batch_first=True):
        super(TimeDistributed, self).__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        if len(x.size()) <= 2:
            return self.module(x)
        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(-1, x.size(-1))  # (samples * timesteps, input_size)
        y = self.module(x_reshape)
        # Reshape Y back to a sequence
        if self.batch_first:
            y = y.contiguous().view(x.size(0), -1, y.size(-1))  # (samples, timesteps, output_size)
        else:
            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)
        return y


class Repeater(Module):
    """Tiles a latent vector along a new timestep axis so it can feed a recurrent decoder."""

    def __init__(self, shape):
        super(Repeater, self).__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor):
        x = x.unsqueeze(-2)
        return x.expand(self.shape)


class RNNOutputFilter(Module):
    """Unpacks the (outputs, hidden) tuple of an RNN, optionally keeping only the last timestep."""

    def __init__(self, return_output=True, only_last=False):
        super(RNNOutputFilter, self).__init__()
        self.only_last = only_last
        self.return_output = return_output

    def forward(self, x: tuple):
        outputs, hidden = x
        out = outputs if self.return_output else hidden
        return out if not self.only_last else out[:, -1, :]


class AvgDimPool(Module):
    """Averages over the timestep axis of a (batch, time, features) tensor."""

    def __init__(self):
        super(AvgDimPool, self).__init__()

    def forward(self, x):
        return x.mean(-2)


#######################
# Network Modules
# Generators, Decoders, Encoders, Discriminators
class Discriminator(Module):
    def __init__(self, latent_dim, features, dropout=.0, activation=ReLU, use_norm=False):
        super(Discriminator, self).__init__()
        self.features = features
        self.latent_dim = latent_dim
        self.l1 = Linear(self.latent_dim, self.features * 10)
        self.norm1 = torch.nn.BatchNorm1d(self.features * 10) if use_norm else False
        self.l2 = Linear(self.features * 10, self.features * 20)
        self.norm2 = torch.nn.BatchNorm1d(self.features * 20) if use_norm else False
        self.lout = Linear(self.features * 20, 1)
        self.dropout = Dropout(dropout)
        self.activation = activation()
        self.sigmoid = Sigmoid()

    def forward(self, x, **kwargs):
        tensor = self.l1(x)
        tensor = self.dropout(tensor)
        if self.norm1:
            tensor = self.norm1(tensor)
        tensor = self.activation(tensor)
        tensor = self.l2(tensor)
        tensor = self.dropout(tensor)
        if self.norm2:
            tensor = self.norm2(tensor)
        tensor = self.activation(tensor)
        tensor = self.lout(tensor)
        tensor = self.sigmoid(tensor)
        return tensor
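# Shape-check sketch (an illustration added alongside the originals, never called from this
# module): it assumes toy sizes of 4 sequences x 8 timesteps x 6 features and a latent size of 2,
# and only demonstrates how TimeDistributed applies a per-timestep module and how Discriminator
# maps a latent code to a sigmoid score.
def _shape_check_sketch():
    batch, timesteps, features, lat_dim = 4, 8, 6, 2              # assumed toy sizes
    x = torch.randn(batch, timesteps, features)
    per_step = TimeDistributed(Linear(features, 3))               # same Linear applied to every timestep
    assert per_step(x).shape == (batch, timesteps, 3)
    disc = Discriminator(latent_dim=lat_dim, features=features)
    z = torch.randn(batch, lat_dim)
    assert disc(z).shape == (batch, 1)                             # one probability per sample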
class DecoderLinearStack(Module):
    def __init__(self, out_shape, use_norm=True):
        super(DecoderLinearStack, self).__init__()
        self.l1 = Linear(10, 100, bias=True)
        self.norm1 = torch.nn.BatchNorm1d(100) if use_norm else False
        self.l2 = Linear(100, out_shape, bias=True)
        self.norm2 = torch.nn.BatchNorm1d(out_shape) if use_norm else False
        self.activation = ReLU()
        self.activation_out = Tanh()

    def forward(self, x):
        tensor = self.l1(x)
        if self.norm1:
            tensor = self.norm1(tensor)
        tensor = self.activation(tensor)
        tensor = self.l2(tensor)
        if self.norm2:
            tensor = self.norm2(tensor)
        tensor = self.activation_out(tensor)
        return tensor


class EncoderLinearStack(Module):

    @property
    def shape(self):
        # Probe the stack with a small dummy batch (>1 sample, so BatchNorm1d works) to get its output shape.
        x = randn(self.features).unsqueeze(0)
        x = torch.cat((x, x, x, x, x))
        output = self(x)
        return output.shape[1:]

    def __init__(self, features=6, factor=10, use_bias=True, use_norm=True):
        super(EncoderLinearStack, self).__init__()
        # FixMe: Get the hardcoded sizes out of here
        self.features = features
        self.l1 = Linear(self.features, self.features * factor, bias=use_bias)
        self.l2 = Linear(self.features * factor, self.features * factor // 2, bias=use_bias)
        self.l3 = Linear(self.features * factor // 2, factor, bias=use_bias)
        self.norm1 = torch.nn.BatchNorm1d(self.features * factor) if use_norm else False
        self.norm2 = torch.nn.BatchNorm1d(self.features * factor // 2) if use_norm else False
        self.norm3 = torch.nn.BatchNorm1d(factor) if use_norm else False
        self.activation = ReLU()

    def forward(self, x):
        tensor = self.l1(x)
        if self.norm1:
            tensor = self.norm1(tensor)
        tensor = self.activation(tensor)
        tensor = self.l2(tensor)
        if self.norm2:
            tensor = self.norm2(tensor)
        tensor = self.activation(tensor)
        tensor = self.l3(tensor)
        if self.norm3:
            tensor = self.norm3(tensor)
        tensor = self.activation(tensor)
        return tensor


class Encoder(Module):
    def __init__(self, lat_dim, variational=False, use_dense=True, features=6, use_norm=True):
        self.lat_dim = lat_dim
        self.features = features
        self.lstm_cells = 10
        self.variational = variational
        super(Encoder, self).__init__()

        self.l_stack = TimeDistributed(EncoderLinearStack(features, use_norm=use_norm)) if use_dense else False
        self.gru = GRU(10 if use_dense else self.features, self.lstm_cells, batch_first=True)
        self.filter = RNNOutputFilter(only_last=True)
        self.norm = torch.nn.BatchNorm1d(self.lstm_cells) if use_norm else False
        if variational:
            self.mu = Linear(self.lstm_cells, self.lat_dim)
            self.logvar = Linear(self.lstm_cells, self.lat_dim)
        else:
            self.lat_dim_layer = Linear(self.lstm_cells, self.lat_dim)

    def forward(self, x):
        if self.l_stack:
            x = self.l_stack(x)
        tensor = self.gru(x)
        tensor = self.filter(tensor)
        if self.norm:
            tensor = self.norm(tensor)
        if self.variational:
            tensor = self.mu(tensor), self.logvar(tensor)
        else:
            tensor = self.lat_dim_layer(tensor)
        return tensor


class AttentionEncoder(Module):
    def __init__(self):
        super(AttentionEncoder, self).__init__()
        self.l_stack = TimeDistributed(EncoderLinearStack())

    def forward(self, x):
        tensor = self.l_stack(x)
        # TODO: Add attention here (e.g. a torch.bmm of query and key projections)
        return tensor


class PoolingEncoder(Module):
    def __init__(self, lat_dim, variational=False, use_norm=True):
        self.lat_dim = lat_dim
        self.variational = variational
        super(PoolingEncoder, self).__init__()

        self.p = AvgDimPool()
        self.l = EncoderLinearStack(use_norm=use_norm)
        if variational:
            self.mu = Linear(reduce(mul, self.l.shape), self.lat_dim)
            self.logvar = Linear(reduce(mul, self.l.shape), self.lat_dim)
        else:
            self.lat_dim_layer = Linear(reduce(mul, self.l.shape), self.lat_dim)

    def forward(self, x):
        tensor = self.p(x)
        tensor = self.l(tensor)
        if self.variational:
            tensor = self.mu(tensor), self.logvar(tensor)
        else:
            tensor = self.lat_dim_layer(tensor)
        return tensor
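# Encoder usage sketch (an illustration added alongside the originals, never called from this
# module): it assumes toy sizes of 4 sequences x 8 timesteps x 6 features and a latent size of 2,
# and shows that the plain encoders return a single latent tensor while the variational variants
# return a (mu, logvar) tuple.
def _encoder_usage_sketch():
    batch, timesteps, features, lat_dim = 4, 8, 6, 2              # assumed toy sizes
    x = torch.randn(batch, timesteps, features)
    enc = Encoder(lat_dim, features=features, use_norm=False)
    assert enc(x).shape == (batch, lat_dim)                        # one latent code per sequence
    var_enc = PoolingEncoder(lat_dim, variational=True, use_norm=False)
    mu, logvar = var_enc(x)
    assert mu.shape == logvar.shape == (batch, lat_dim)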
class Decoder(Module):
    def __init__(self, latent_dim, *args, lstm_cells=10, use_norm=True, variational=False):
        self.variational = variational
        super(Decoder, self).__init__()

        self.gru = GRU(latent_dim, lstm_cells, batch_first=True)
        # Only build the TimeDistributed BatchNorm when it is requested, so that
        # `if self.norm:` can actually skip it when use_norm is False.
        self.norm = TimeDistributed(torch.nn.BatchNorm1d(lstm_cells)) if use_norm else False
        self.filter = RNNOutputFilter()
        self.l_stack = TimeDistributed(DecoderLinearStack(*args, use_norm=use_norm))

    def forward(self, x):
        tensor = self.gru(x)
        tensor = self.filter(tensor)
        if self.norm:
            tensor = self.norm(tensor)
        tensor = self.l_stack(tensor)
        return tensor


if __name__ == '__main__':
    raise PermissionError('Get out of here - never run this module')
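# End-to-end sketch of one plausible way to wire the pieces above into a sequence autoencoder
# (Encoder -> Repeater -> Decoder). The actual composition lives in the models that import this
# module, so the wiring and the toy sizes below are assumptions for illustration only; the
# function is never called from here (see the PermissionError guard above).
def _autoencoder_roundtrip_sketch():
    batch, timesteps, features, lat_dim = 4, 8, 6, 2              # assumed toy sizes
    x = torch.randn(batch, timesteps, features)
    enc = Encoder(lat_dim, features=features, use_norm=False)
    rep = Repeater((batch, timesteps, lat_dim))                    # tile the latent code over time
    dec = Decoder(lat_dim, features, use_norm=False)               # DecoderLinearStack(out_shape=features)
    z = enc(x)                                                     # (batch, lat_dim)
    x_hat = dec(rep(z))                                            # (batch, timesteps, features)
    assert x_hat.shape == x.shape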