Fixed the Model classes, Visualization

Si11ium committed 2019-08-23 13:10:47 +02:00
parent 0e879bfdb1
commit 7b0b96eaa3
16 changed files with 141 additions and 469 deletions

View File

@@ -1,7 +1,7 @@
+from torch.optim import Adam
 from networks.auto_encoder import AutoEncoder
 from torch.nn.functional import mse_loss
-from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid
 from torch.distributions import Normal
 from networks.modules import *
 import torch
@@ -23,14 +23,10 @@ class AdversarialAutoEncoder(AutoEncoder):
         return z, x_hat


-class AdversarialAELightningOverrides:
+class AdversarialAELightningOverrides(LightningModuleOverrides):

-    @property
-    def name(self):
-        return self.__class__.__name__
-
-    def forward(self, x):
-        return self.network.forward(x)
+    def __init__(self):
+        super(AdversarialAELightningOverrides, self).__init__()

     def training_step(self, batch, _, optimizer_i):
         if optimizer_i == 0:
@@ -67,5 +63,12 @@
             raise RuntimeError('This should not have happened, catch me if u can.')

+    # This is Fucked up, why do i need to put an additional empty list here?
+    def configure_optimizers(self):
+        return [Adam(self.network.discriminator.parameters(), lr=0.02),
+                Adam([*self.network.encoder.parameters(), *self.network.decoder.parameters()], lr=0.02)], \
+               []
+

 if __name__ == '__main__':
     raise PermissionError('Get out of here - never run this module')
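The "additional empty list" the comment complains about is the scheduler list: pytorch-lightning 0.x lets configure_optimizers return (optimizers, schedulers), and with several optimizers it calls training_step once per optimizer, passing optimizer_i so the step can branch. A minimal self-contained sketch of that contract, using hypothetical stand-in modules rather than the commit's networks:

    from torch import nn
    from torch.optim import Adam

    # Hypothetical stand-ins for the commit's discriminator and encoder/decoder pair.
    discriminator = nn.Linear(4, 1)
    autoencoder = nn.Sequential(nn.Linear(4, 2), nn.Linear(2, 4))

    def configure_optimizers():
        # (list of optimizers, list of schedulers) -- the trailing [] is the
        # scheduler list expected alongside multiple optimizers.
        return [Adam(discriminator.parameters(), lr=0.02),
                Adam(autoencoder.parameters(), lr=0.02)], []

    optimizers, schedulers = configure_optimizers()
    assert len(optimizers) == 2 and schedulers == []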

View File

@@ -1,3 +1,5 @@
+from torch.optim import Adam
 from .modules import *
 from torch.nn.functional import mse_loss
 from torch import Tensor
@@ -26,14 +28,10 @@ class AutoEncoder(AbstractNeuralNetwork, ABC):
         return z, x_hat


-class AutoEncoderLightningOverrides:
+class AutoEncoderLightningOverrides(LightningModuleOverrides):

-    @property
-    def name(self):
-        return self.__class__.__name__
-
-    def forward(self, x):
-        return self.network.forward(x)
+    def __init__(self):
+        super(AutoEncoderLightningOverrides, self).__init__()

     def training_step(self, x, batch_nb):
         # z, x_hat
@@ -41,6 +39,9 @@ class AutoEncoderLightningOverrides:
         loss = mse_loss(x, x_hat)
         return {'loss': loss}

+    def configure_optimizers(self):
+        return [Adam(self.parameters(), lr=0.02)]
+

 if __name__ == '__main__':
     raise PermissionError('Get out of here - never run this module')
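The autoencoder's training_step reduces to a plain MSE reconstruction loss returned in the dict form Lightning consumes. A toy numeric check, with a hypothetical noisy reconstruction standing in for the network output:

    import torch
    from torch.nn.functional import mse_loss

    x = torch.randn(100, 6)                 # a batch of inputs
    x_hat = x + 0.1 * torch.randn_like(x)   # hypothetical imperfect reconstruction
    loss = mse_loss(x, x_hat)               # mean squared error, roughly 0.01 here
    print({'loss': loss})                   # the dict shape training_step returns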

View File

@@ -1,11 +1,34 @@
 import os
 import torch
 import pytorch_lightning as pl
-from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU, AvgPool2d
+from pytorch_lightning import data_loader
+from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU
 from abc import ABC, abstractmethod

 #######################
-# Abstract NN Class
+# Abstract NN Class & Lightning Module
+from torch.utils.data import DataLoader
+from dataset import DataContainer
+
+
+class LightningModuleOverrides:
+
+    @property
+    def name(self):
+        return self.__class__.__name__
+
+    def forward(self, x):
+        return self.network.forward(x)
+
+    @data_loader
+    def tng_dataloader(self):
+        num_workers = os.cpu_count() // 2
+        return DataLoader(DataContainer('data', self.size, self.step),
+                          shuffle=True, batch_size=100, num_workers=num_workers)
+
+
 class AbstractNeuralNetwork(Module):
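The new LightningModuleOverrides mixin centralizes what each *LightningOverrides class used to duplicate: name, forward delegation to self.network, and tng_dataloader (the pytorch-lightning 0.x training-loader hook, cached via @data_loader). It assumes the concrete module defines self.network, self.size and self.step. A self-contained sketch of the same loader setup, with TensorDataset standing in for DataContainer:

    import os
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(1000, 6))   # stand-in for DataContainer('data', size, step)
    num_workers = (os.cpu_count() or 2) // 2        # os.cpu_count() can return None, so guard it
    loader = DataLoader(dataset, shuffle=True, batch_size=100, num_workers=num_workers)
    print(len(loader))                              # 10 batches of 100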

View File

@@ -1,3 +1,5 @@
+from torch.optim import Adam
 from networks.auto_encoder import AutoEncoder
 from torch.nn.functional import mse_loss
 from networks.modules import *
@@ -7,16 +9,15 @@ import torch
 class SeperatingAdversarialAutoEncoder(Module):

     def __init__(self, latent_dim, features, **kwargs):
-        assert latent_dim % 2 == 0, f'Your latent space needs to be even, not odd, but was: "{latent_dim}"'
         super(SeperatingAdversarialAutoEncoder, self).__init__()
         self.latent_dim = latent_dim
         self.features = features
-        self.spatial_encoder = PoolingEncoder(self.latent_dim // 2)
-        self.temporal_encoder = Encoder(self.latent_dim // 2)
+        self.spatial_encoder = PoolingEncoder(self.latent_dim)
+        self.temporal_encoder = Encoder(self.latent_dim)
         self.decoder = Decoder(self.latent_dim, self.features)
-        self.spatial_discriminator = Discriminator(self.latent_dim // 2, self.features)
-        self.temporal_discriminator = Discriminator(self.latent_dim // 2, self.features)
+        self.spatial_discriminator = Discriminator(self.latent_dim, self.features)
+        self.temporal_discriminator = Discriminator(self.latent_dim, self.features)

     def forward(self, batch):
         # Encoder
@@ -30,14 +31,10 @@ class SeperatingAdversarialAutoEncoder(Module):
         return z_spatial, z_temporal, x_hat


-class SeparatingAdversarialAELightningOverrides:
+class SeparatingAdversarialAELightningOverrides(LightningModuleOverrides):

-    @property
-    def name(self):
-        return self.__class__.__name__
-
-    def forward(self, x):
-        return self.network.forward(x)
+    def __init__(self):
+        super(SeparatingAdversarialAELightningOverrides, self).__init__()

     def training_step(self, batch, _, optimizer_i):
         spatial_latent_fake, temporal_latent_fake, batch_hat = self.network.forward(batch)
@@ -91,6 +88,17 @@ class SeparatingAdversarialAELightningOverrides:
         else:
             raise RuntimeError('This should not have happened, catch me if u can.')

+    # This is Fucked up, why do i need to put an additional empty list here?
+    def configure_optimizers(self):
+        return [Adam([*self.network.spatial_discriminator.parameters(),
+                      *self.network.spatial_encoder.parameters()], lr=0.02),
+                Adam([*self.network.temporal_discriminator.parameters(),
+                      *self.network.temporal_encoder.parameters()], lr=0.02),
+                Adam([*self.network.temporal_encoder.parameters(),
+                      *self.network.spatial_encoder.parameters(),
+                      *self.network.decoder.parameters()], lr=0.02)], []
+

 if __name__ == '__main__':
     raise PermissionError('Get out of here - never run this module')
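configure_optimizers here builds three groups: each discriminator is stepped together with its own encoder, and a third optimizer spans the full reconstruction path. A sketch of that grouping with tiny hypothetical modules (only the parameter grouping mirrors the diff; the layers do not):

    from torch import nn
    from torch.optim import Adam

    spatial_encoder, temporal_encoder = nn.Linear(4, 2), nn.Linear(4, 2)
    spatial_discriminator, temporal_discriminator = nn.Linear(2, 1), nn.Linear(2, 1)
    decoder = nn.Linear(4, 4)

    optimizers = [
        # each adversarial pair (discriminator + its encoder) shares an optimizer ...
        Adam([*spatial_discriminator.parameters(), *spatial_encoder.parameters()], lr=0.02),
        Adam([*temporal_discriminator.parameters(), *temporal_encoder.parameters()], lr=0.02),
        # ... and a third covers both encoders plus the decoder for reconstruction
        Adam([*temporal_encoder.parameters(), *spatial_encoder.parameters(),
              *decoder.parameters()], lr=0.02),
    ]
    schedulers = []  # the trailing empty scheduler list from the diff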

View File

@@ -1,3 +1,5 @@
+from torch.optim import Adam
 from .modules import *
 from torch.nn.functional import mse_loss
@@ -33,14 +35,10 @@ class VariationalAutoEncoder(AbstractNeuralNetwork, ABC):
         return x_hat, mu, logvar


-class VariationalAutoEncoderLightningOverrides:
+class VariationalAutoEncoderLightningOverrides(LightningModuleOverrides):

-    @property
-    def name(self):
-        return self.network.name
-
-    def forward(self, x):
-        return self.network.forward(x)
+    def __init__(self):
+        super(VariationalAutoEncoderLightningOverrides, self).__init__()

     def training_step(self, x, _):
         x_hat, mu, logvar = self.forward(x)  # forward returns (x_hat, mu, logvar); unpack in that order
@@ -53,6 +51,9 @@ class VariationalAutoEncoderLightningOverrides:
         KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
         return {'loss': BCE + KLD}

+    def configure_optimizers(self):
+        return [Adam(self.parameters(), lr=0.02)]
+

 if __name__ == '__main__':
     raise PermissionError('Get out of here - never run this module')
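The KLD term above is the closed-form KL divergence between the approximate posterior N(mu, sigma^2) and the standard-normal prior, summed over latent dimensions; per dimension it reduces to -0.5 * (1 + log sigma^2 - mu^2 - sigma^2). A quick numeric check with toy values:

    import torch

    mu = torch.tensor([0.0, 0.5])
    logvar = torch.zeros(2)   # log(sigma^2) = 0, i.e. unit variance
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    print(kld)                # tensor(0.1250): only the mu=0.5 dimension contributes, 0.5**2 / 2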