Done: AE, VAE, AAE
ToDo: Double AAE, Visualization All Modularized
This commit is contained in:
53
networks/variational_auto_encoder.py
Normal file
53
networks/variational_auto_encoder.py
Normal file
@ -0,0 +1,53 @@
|
||||
from .modules import *
|
||||
from torch.nn.functional import mse_loss
|
||||
|
||||
|
||||
#######################
|
||||
# Basic AE-Implementation
|
||||
class VariationalAutoEncoder(Module, ABC):
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def __init__(self, dataParams, **kwargs):
|
||||
super(VariationalAutoEncoder, self).__init__()
|
||||
self.dataParams = dataParams
|
||||
self.latent_dim = kwargs.get('latent_dim', 2)
|
||||
self.encoder = Encoder(self.latent_dim, variational=True)
|
||||
self.decoder = Decoder(self.latent_dim, self.dataParams['features'], variational=True)
|
||||
|
||||
@staticmethod
|
||||
def reparameterize(mu, logvar):
|
||||
# Lambda Layer, add gaussian noise
|
||||
std = torch.exp(0.5*logvar)
|
||||
eps = torch.randn_like(std)
|
||||
return mu + eps*std
|
||||
|
||||
def forward(self, batch):
|
||||
mu, logvar = self.encoder(batch)
|
||||
z = self.reparameterize(mu, logvar)
|
||||
repeat = Repeater((batch.shape[0], self.dataParams['size'], -1))
|
||||
x_hat = self.decoder(repeat(z))
|
||||
return x_hat, mu, logvar
|
||||
|
||||
|
||||
class VariationalAutoEncoderLightningOverrides:
|
||||
|
||||
def forward(self, x):
|
||||
return self.network.forward(x)
|
||||
|
||||
def training_step(self, x, _):
|
||||
x_hat, logvar, mu = self.forward(x)
|
||||
BCE = mse_loss(x_hat, x, reduction='mean')
|
||||
|
||||
# see Appendix B from VAE paper:
|
||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
return {'loss': BCE + KLD}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise PermissionError('Get out of here - never run this module')
|
Reference in New Issue
Block a user