2020-04-08 14:50:16 +02:00

245 lines
8.9 KiB
Python

from functools import reduce
from operator import mul
from random import choice
import torch
from torch import nn
from torch.optim import Adam
from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from ml_lib.modules.blocks import ConvModule, DeConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
import variables as V
from generator_eval import GeneratorVisualizer
class CNNRouteGeneratorModel(LightningBaseModule):
    """Convolutional variational auto-encoder trained on ``TrajData`` maps.

    The encoder maps an input batch to ``(mu, logvar)``. A latent sample drawn
    with the reparameterization trick is decoded back to an image. That image
    is scored against the input with a summed BCE loss plus the analytic KL
    divergence.
    """

    # NOTE(review): this executes once, at import time, and enables autograd
    # anomaly detection process-wide. It is a debugging aid that slows every
    # backward pass; consider removing it (or gating it on a flag) for real runs.
    torch.autograd.set_detect_anomaly(True)

    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr ``train_param.lr``."""
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the VAE loss (reconstruction + KL divergence) for one batch.

        The second element of ``batch_xy`` is ignored; the input is its own
        reconstruction target. Returns a Lightning result dict holding ``loss``
        and a ``log`` dict with the individual loss components.
        """
        batch_x, _ = batch_xy
        reconstruction, z, mu, logvar = self(batch_x)
        recon_loss = self.criterion(reconstruction, batch_x)
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) with logvar = log(sigma^2),
        # summed over all elements to match the reduction='sum' criterion.
        kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = recon_loss + kldivergence
        return dict(loss=loss, log=dict(reconstruction_loss=recon_loss, loss=loss, kld_loss=kldivergence))

    def _test_val_step(self, batch_xy, batch_nb, *args):
        """Shared body of ``validation_step`` and ``test_step``.

        Encodes the batch and draws a latent sample ``z``. The reconstruction
        is decoded from the mean ``mu`` rather than from ``z``. Returns a dict
        of tensors that ``_test_val_epoch_end`` feeds to the visualizer.
        """
        batch_x, _ = batch_xy
        mu, logvar = self.encoder(batch_x)
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decoder(mu)
        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=reconstruction, z=z, mu=mu, logvar=logvar)
        # One label per sample, all equal to V.ANY; presumably the visualizer
        # requires a 'labels' entry -- TODO confirm against GeneratorVisualizer.
        labels = torch.full((batch_x.shape[0], 1), V.ANY)
        return_dict.update(labels=self._move_to_model_device(labels))
        return return_dict

    def _test_val_epoch_end(self, outputs, test=False):
        """Log an input/output figure and a latent figure for one random batch.

        ``outputs`` is the list of dicts produced by ``_test_val_step``. If it
        is empty, no figures are logged; ``random.choice`` would otherwise
        raise ``IndexError``. ``test`` is accepted for symmetry with
        ``test_epoch_end`` but is currently unused.
        """
        plt.close('all')
        if outputs:
            g = GeneratorVisualizer(choice(outputs))
            fig = g.draw_io_bundle()
            self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
            plt.clf()
            fig = g.draw_latent()
            self.logger.log_image(f'{self.name}_Latent', fig, step=self.global_step)
            plt.clf()
        return dict(epoch=self.current_epoch)

    def validation_step(self, *args):
        """Delegate to the shared validation/test step."""
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        """Delegate to the shared validation/test epoch-end hook."""
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        """Delegate to the shared validation/test step."""
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        """Delegate to the shared epoch-end hook with ``test=True``."""
        return self._test_val_epoch_end(outputs, test=True)

    def __init__(self, *params, issubclassed=False):
        """Build the dataset, loss, encoder and decoder from ``self.hparams``.

        ``issubclassed`` is accepted but not used in this class.
        """
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset
        self.dataset = TrajData(self.hparams.data_param.map_root,
                                mode=self.hparams.data_param.mode,
                                preprocessed=self.hparams.data_param.use_preprocessed,
                                length=self.hparams.data_param.dataset_length)

        # Summed BCE; the decoder's last layer is built with a Sigmoid activation.
        self.criterion = nn.BCELoss(reduction='sum')

        # Additional Attributes
        ###################################################
        self.in_shape = self.dataset.map_shapes_max
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        # Generator modes reconstruct a single channel; otherwise every input channel.
        self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]

        # NN Nodes
        ###################################################
        self.encoder = Encoder(self.in_shape, self.hparams)
        self.decoder = Decoder(self.out_channels, self.encoder.last_conv_shape, self.hparams)

    def forward(self, batch_x):
        """Encode, sample and decode; return ``(reconstruction, z, mu, logvar)``."""
        # Encode
        mu, logvar = self.encoder(batch_x)
        # Bottleneck
        z = self.reparameterize(mu, logvar)
        # Decode
        reconstruction = self.decoder(z)
        return reconstruction, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        """Draw ``z ~ N(mu, exp(logvar))`` via the reparameterization trick.

        ``logvar`` is the log-variance, the same convention as the KL term in
        ``training_step``. The standard deviation is therefore
        ``exp(logvar / 2)``. The previous ``0.5 * exp(logvar)`` was half the
        variance, not the std.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        return mu + std * eps
class Encoder(nn.Module):
    """Convolutional encoder that outputs ``(mu, logvar)`` for a batch.

    The network is built from four ``ConvModule`` layers with kernels 3, 3, 5
    and 7, all stride 1, with only the first padded. These are followed by a
    flatten, a linear layer of width ``lat_dim * 10`` with the configured
    activation, and two linear heads of width ``lat_dim``.
    """

    def __init__(self, in_shape, hparams):
        super(Encoder, self).__init__()
        self.hparams = hparams
        model_param = self.hparams.model_param

        self.in_shape = in_shape
        self.use_res_net = model_param.use_res_net
        self.lat_dim = model_param.lat_dim
        self.feature_dim = self.lat_dim * 10

        self.activation = self.hparams.activation()

        # Keyword arguments shared by every conv layer below.
        common = dict(conv_stride=1,
                      use_norm=model_param.use_norm,
                      use_bias=model_param.use_bias)
        filters = model_param.filters

        # Each layer's input shape is the previous layer's output shape.
        self.conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_padding=1,
                                 conv_filters=filters[0], **common)
        self.conv_1 = ConvModule(self.conv_0.shape, conv_kernel=3, conv_padding=0,
                                 conv_filters=filters[0], **common)
        self.conv_2 = ConvModule(self.conv_1.shape, conv_kernel=5, conv_padding=0,
                                 conv_filters=filters[1], **common)
        self.conv_3 = ConvModule(self.conv_2.shape, conv_kernel=7, conv_padding=0,
                                 conv_filters=filters[2], **common)

        # Exposed so the decoder can mirror the final feature-map shape.
        self.last_conv_shape = self.conv_3.shape
        self.flat = Flatten(in_shape=self.last_conv_shape)
        self.lin = nn.Linear(self.flat.shape, self.feature_dim)

        # Two heads over the shared feature vector.
        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)

    def forward(self, batch_x):
        """Return the pair ``(mu, logvar)`` computed from ``batch_x``."""
        hidden = batch_x
        for stage in (self.conv_0, self.conv_1, self.conv_2, self.conv_3, self.flat, self.lin):
            hidden = stage(hidden)
        hidden = self.activation(hidden)
        return self.mu(hidden), self.logvar(hidden)
class Decoder(nn.Module):
    """Decoder that maps a latent vector back to an image.

    A linear layer with the configured activation expands ``z`` to
    ``prod(last_conv_shape)`` features, which are reshaped to
    ``last_conv_shape``. Four ``DeConvModule`` layers (kernels 7, 5, 3, 3) then
    follow. The last of them emits ``out_channels`` channels and is built with
    ``activation=nn.Sigmoid``.
    """

    def __init__(self, out_channels, last_conv_shape, hparams):
        super(Decoder, self).__init__()
        self.hparams = hparams
        model_param = self.hparams.model_param

        self.use_res_net = model_param.use_res_net
        self.lat_dim = model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = out_channels

        self.activation = self.hparams.activation()

        # Total element count of the feature map the deconvolutions start from.
        n_features = reduce(mul, last_conv_shape)
        self.lin = nn.Linear(self.lat_dim, n_features)
        self.reshape = Flatten(in_shape=n_features, to=last_conv_shape)

        # NOTE(review): unlike the encoder's ConvModules, use_bias is not
        # forwarded here -- confirm whether that is intentional.
        norm = model_param.use_norm
        filters = model_param.filters
        self.deconv_1 = DeConvModule(last_conv_shape, filters[2],
                                     conv_padding=0, conv_kernel=7, conv_stride=1,
                                     use_norm=norm)
        self.deconv_2 = DeConvModule(self.deconv_1.shape, filters[1],
                                     conv_padding=1, conv_kernel=5, conv_stride=1,
                                     use_norm=norm)
        self.deconv_3 = DeConvModule(self.deconv_2.shape, filters[0],
                                     conv_padding=0, conv_kernel=3, conv_stride=1,
                                     use_norm=norm)
        self.deconv_out = DeConvModule(self.deconv_3.shape, self.out_channels, activation=nn.Sigmoid,
                                       conv_padding=0, conv_kernel=3, conv_stride=1,
                                       use_norm=norm)

    def forward(self, z):
        """Decode latent batch ``z`` into a reconstruction."""
        hidden = self.reshape(self.activation(self.lin(z)))
        for deconv in (self.deconv_1, self.deconv_2, self.deconv_3):
            hidden = deconv(hidden)
        return self.deconv_out(hidden)