Fixed the Model classes, Visualization
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
from torch.optim import Adam
|
||||
|
||||
from networks.auto_encoder import AutoEncoder
|
||||
from torch.nn.functional import mse_loss
|
||||
from torch.nn import Sequential, Linear, ReLU, Dropout, Sigmoid
|
||||
from torch.distributions import Normal
|
||||
from networks.modules import *
|
||||
import torch
|
||||
|
||||
@ -23,14 +23,10 @@ class AdversarialAutoEncoder(AutoEncoder):
|
||||
return z, x_hat
|
||||
|
||||
|
||||
class AdversarialAELightningOverrides:
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def forward(self, x):
|
||||
return self.network.forward(x)
|
||||
class AdversarialAELightningOverrides(LightningModuleOverrides):
|
||||
|
||||
def __init__(self):
|
||||
super(AdversarialAELightningOverrides, self).__init__()
|
||||
|
||||
def training_step(self, batch, _, optimizer_i):
|
||||
if optimizer_i == 0:
|
||||
@ -67,5 +63,12 @@ class AdversarialAELightningOverrides:
|
||||
raise RuntimeError('This should not have happened, catch me if u can.')
|
||||
|
||||
|
||||
# This is Fucked up, why do i need to put an additional empty list here?
|
||||
def configure_optimizers(self):
|
||||
return [Adam(self.network.discriminator.parameters(), lr=0.02),
|
||||
Adam([*self.network.encoder.parameters(), *self.network.decoder.parameters()], lr=0.02)],\
|
||||
[]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise PermissionError('Get out of here - never run this module')
|
||||
|
@ -1,3 +1,5 @@
|
||||
from torch.optim import Adam
|
||||
|
||||
from .modules import *
|
||||
from torch.nn.functional import mse_loss
|
||||
from torch import Tensor
|
||||
@ -26,14 +28,10 @@ class AutoEncoder(AbstractNeuralNetwork, ABC):
|
||||
return z, x_hat
|
||||
|
||||
|
||||
class AutoEncoderLightningOverrides:
|
||||
class AutoEncoderLightningOverrides(LightningModuleOverrides):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def forward(self, x):
|
||||
return self.network.forward(x)
|
||||
def __init__(self):
|
||||
super(AutoEncoderLightningOverrides, self).__init__()
|
||||
|
||||
def training_step(self, x, batch_nb):
|
||||
# z, x_hat
|
||||
@ -41,6 +39,9 @@ class AutoEncoderLightningOverrides:
|
||||
loss = mse_loss(x, x_hat)
|
||||
return {'loss': loss}
|
||||
|
||||
def configure_optimizers(self):
|
||||
return [Adam(self.parameters(), lr=0.02)]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise PermissionError('Get out of here - never run this module')
|
||||
|
@ -1,11 +1,34 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU, AvgPool2d
|
||||
from pytorch_lightning import data_loader
|
||||
from torch.nn import Module, Linear, ReLU, Tanh, Sigmoid, Dropout, GRU
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
#######################
|
||||
# Abstract NN Class
|
||||
# Abstract NN Class & Lightning Module
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset import DataContainer
|
||||
|
||||
|
||||
class LightningModuleOverrides:
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def forward(self, x):
|
||||
return self.network.forward(x)
|
||||
|
||||
@data_loader
|
||||
def tng_dataloader(self):
|
||||
num_workers = os.cpu_count() // 2
|
||||
return DataLoader(DataContainer('data', self.size, self.step),
|
||||
shuffle=True, batch_size=100, num_workers=num_workers)
|
||||
|
||||
|
||||
class AbstractNeuralNetwork(Module):
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from torch.optim import Adam
|
||||
|
||||
from networks.auto_encoder import AutoEncoder
|
||||
from torch.nn.functional import mse_loss
|
||||
from networks.modules import *
|
||||
@ -7,16 +9,15 @@ import torch
|
||||
class SeperatingAdversarialAutoEncoder(Module):
|
||||
|
||||
def __init__(self, latent_dim, features, **kwargs):
|
||||
assert latent_dim % 2 == 0, f'Your latent space needs to be even, not odd, but was: "{latent_dim}"'
|
||||
super(SeperatingAdversarialAutoEncoder, self).__init__()
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.features = features
|
||||
self.spatial_encoder = PoolingEncoder(self.latent_dim // 2)
|
||||
self.temporal_encoder = Encoder(self.latent_dim // 2)
|
||||
self.spatial_encoder = PoolingEncoder(self.latent_dim)
|
||||
self.temporal_encoder = Encoder(self.latent_dim)
|
||||
self.decoder = Decoder(self.latent_dim, self.features)
|
||||
self.spatial_discriminator = Discriminator(self.latent_dim // 2, self.features)
|
||||
self.temporal_discriminator = Discriminator(self.latent_dim // 2, self.features)
|
||||
self.spatial_discriminator = Discriminator(self.latent_dim, self.features)
|
||||
self.temporal_discriminator = Discriminator(self.latent_dim, self.features)
|
||||
|
||||
def forward(self, batch):
|
||||
# Encoder
|
||||
@ -30,14 +31,10 @@ class SeperatingAdversarialAutoEncoder(Module):
|
||||
return z_spatial, z_temporal, x_hat
|
||||
|
||||
|
||||
class SeparatingAdversarialAELightningOverrides:
|
||||
class SeparatingAdversarialAELightningOverrides(LightningModuleOverrides):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def forward(self, x):
|
||||
return self.network.forward(x)
|
||||
def __init__(self):
|
||||
super(SeparatingAdversarialAELightningOverrides, self).__init__()
|
||||
|
||||
def training_step(self, batch, _, optimizer_i):
|
||||
spatial_latent_fake, temporal_latent_fake, batch_hat = self.network.forward(batch)
|
||||
@ -91,6 +88,17 @@ class SeparatingAdversarialAELightningOverrides:
|
||||
else:
|
||||
raise RuntimeError('This should not have happened, catch me if u can.')
|
||||
|
||||
# This is Fucked up, why do i need to put an additional empty list here?
|
||||
def configure_optimizers(self):
|
||||
return [Adam([*self.network.spatial_discriminator.parameters(), *self.network.spatial_encoder.parameters()]
|
||||
, lr=0.02),
|
||||
Adam([*self.network.temporal_discriminator.parameters(), *self.network.temporal_encoder.parameters()]
|
||||
, lr=0.02),
|
||||
Adam([*self.network.temporal_encoder.parameters(),
|
||||
*self.network.spatial_encoder.parameters(),
|
||||
*self.network.decoder.parameters()]
|
||||
, lr=0.02)], []
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise PermissionError('Get out of here - never run this module')
|
||||
|
@ -1,3 +1,5 @@
|
||||
from torch.optim import Adam
|
||||
|
||||
from .modules import *
|
||||
from torch.nn.functional import mse_loss
|
||||
|
||||
@ -33,14 +35,10 @@ class VariationalAutoEncoder(AbstractNeuralNetwork, ABC):
|
||||
return x_hat, mu, logvar
|
||||
|
||||
|
||||
class VariationalAutoEncoderLightningOverrides:
|
||||
class VariationalAutoEncoderLightningOverrides(LightningModuleOverrides):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.network.name
|
||||
|
||||
def forward(self, x):
|
||||
return self.network.forward(x)
|
||||
def __init__(self):
|
||||
super(VariationalAutoEncoderLightningOverrides, self).__init__()
|
||||
|
||||
def training_step(self, x, _):
|
||||
x_hat, logvar, mu = self.forward(x)
|
||||
@ -53,6 +51,9 @@ class VariationalAutoEncoderLightningOverrides:
|
||||
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
return {'loss': BCE + KLD}
|
||||
|
||||
def configure_optimizers(self):
|
||||
return [Adam(self.parameters(), lr=0.02)]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise PermissionError('Get out of here - never run this module')
|
||||
|
Reference in New Issue
Block a user