from functools import reduce
from operator import mul
from random import choices, choice
import torch
from torch import nn
from torch.optim import Adam
from torchvision.datasets import MNIST
from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
import lib.variables as V
from lib.visualization.generator_eval import GeneratorVisualizer


class CNNRouteGeneratorModel(LightningBaseModule):
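    """Convolutional route generator used as a (variational) autoencoder.

    A stack of ConvModule/ResidualModule blocks encodes the input map into a
    `lat_dim`-dimensional code; DeConvModule blocks decode it back. Behaviour is
    switched by substrings in `hparams.data_param.mode`: 'vae' enables the
    variational bottleneck (mu/logvar + reparameterization), 'ae' reconstructs
    the input, and 'hom'/'alt' tag outputs with the corresponding trajectory
    labels from lib.variables.
    """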
    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)
    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, target = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        # In autoencoder ('ae') mode the network reconstructs its own input.
        target = batch_x if 'ae' in self.hparams.data_param.mode else target
        element_wise_loss = self.criterion(generated_alternative, target)
        if 'vae' in self.hparams.data_param.mode:
            # See Appendix B of the VAE paper:
            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
            # https://arxiv.org/abs/1312.6114
            # -D_KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2); the sign is flipped here
            # so the term can be minimised as a loss.
            kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            # Dimensional resizing. TODO: Does this make sense? Sanity-check it!
            # kld_loss /= reduce(mul, self.in_shape)
            # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
            loss = kld_loss + element_wise_loss
        else:
            loss = element_wise_loss
            kld_loss = 0
        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
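
    # Shared body for validation_step and test_step: encode the batch, decode from the
    # (mean) latent code, and collect inputs, outputs, latents and labels for plotting.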
    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, _ = batch_xy
        if 'vae' in self.hparams.data_param.mode:
            z, mu, logvar = self.encode(batch_x)
        else:
            z = self.encode(batch_x)
            mu, logvar = z, z
        generated_alternative = self.generate(mu)
        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative, z=z, mu=mu, logvar=logvar)

        if 'hom' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.HOMOTOPIC)
        elif 'alt' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.ALTERNATIVE)
        elif 'vae' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.ANY)
        elif 'ae' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.ANY)
        else:
            labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values

        return_dict.update(labels=self._move_to_model_device(labels))
        return return_dict
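
    # Shared epoch-end hook: renders the input/output bundle and the latent space of one
    # randomly chosen batch via GeneratorVisualizer and pushes the figures to the logger.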
    def _test_val_epoch_end(self, outputs, test=False):
        plt.close('all')
        g = GeneratorVisualizer(choice(outputs))
        fig = g.draw_io_bundle()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        plt.clf()
        fig = g.draw_latent()
        self.logger.log_image(f'{self.name}_Latent', fig, step=self.global_step)
        plt.clf()
        return dict(epoch=self.current_epoch)
    def on_epoch_start(self):
        # self.dataset.seed(self.logger.version)
        # torch.random.manual_seed(self.logger.version)
        # np.random.seed(self.logger.version)
        pass

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)
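
    # Model assembly: dataset, hparams-derived attributes, encoder, bottleneck and decoder.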
    def __init__(self, *params, issubclassed=False):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset
        # NOTE: The TrajData dataset is currently disabled via `if False`; MyMNIST serves as a
        #       drop-in stand-in so the generator can be trained and debugged end to end.
        if False:
            self.dataset = TrajData(self.hparams.data_param.map_root,
                                    mode=self.hparams.data_param.mode,
                                    preprocessed=self.hparams.data_param.use_preprocessed,
                                    length=self.hparams.data_param.dataset_length, normalized=True)
        self.criterion = nn.MSELoss()
        self.dataset = MyMNIST()

        # Additional Attributes #
        #######################################################
        self.in_shape = self.dataset.map_shapes_max
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]
        #######################################################
        # NN Nodes
        #######################################################
        #
        # Utils
        self.activation = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        #
        # Map Encoder
        self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                     conv_filters=self.hparams.model_param.filters[0],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)

        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[1],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[1],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)

        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        last_conv_shape = self.enc_conv_2b.shape
        self.enc_flat = Flatten(last_conv_shape)
        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)

        #
        # Mixed Encoder
        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
        # nn.Identity keeps the module picklable and device-aware, unlike a bare lambda.
        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else nn.Identity()
        #
        # Variational Bottleneck
        if 'vae' in self.hparams.data_param.mode:
            self.mu = nn.Linear(self.feature_dim, self.lat_dim)
            self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
        #
        # Linear Bottleneck
        else:
            self.z = nn.Linear(self.feature_dim, self.lat_dim)

        #
        # Alternative Generator
        self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)
        # self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)

        self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)

        self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
                                          conv_padding=1, conv_kernel=9, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)

        self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
                                          conv_padding=1, conv_kernel=7, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)

        self.gen_deconv_out = DeConvModule(self.gen_deconv_2a.shape, self.out_channels, activation=None,
                                           conv_padding=0, conv_kernel=3, conv_stride=1,
                                           use_norm=self.hparams.model_param.use_norm)
    def forward(self, batch_x):
        #
        # Encode
        if 'vae' in self.hparams.data_param.mode:
            z, mu, logvar = self.encode(batch_x)
        else:
            z = self.encode(batch_x)
            mu, logvar = z, z

        #
        # Generate
        alt_tensor = self.generate(z)
        return alt_tensor, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
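
    # Encoder: conv stack (optionally with residual blocks) -> flatten -> two linear
    # layers -> either the variational heads (mu/logvar) or a plain linear bottleneck.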
    def encode(self, batch_x):
        combined_tensor = self.enc_conv_0(batch_x)
        combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor
        combined_tensor = self.enc_conv_1a(combined_tensor)
        combined_tensor = self.enc_conv_1b(combined_tensor)
        combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
        combined_tensor = self.enc_conv_2a(combined_tensor)
        combined_tensor = self.enc_conv_2b(combined_tensor)
        # combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
        # combined_tensor = self.enc_conv_3a(combined_tensor)
        # combined_tensor = self.enc_conv_3b(combined_tensor)

        combined_tensor = self.enc_flat(combined_tensor)
        combined_tensor = self.enc_lin_1(combined_tensor)
        combined_tensor = self.enc_norm(combined_tensor)
        combined_tensor = self.activation(combined_tensor)
        combined_tensor = self.enc_lin_2(combined_tensor)
        combined_tensor = self.enc_norm(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        #
        # Variational
        # Parameter and Sampling
        if 'vae' in self.hparams.data_param.mode:
            mu = self.mu(combined_tensor)
            logvar = self.logvar(combined_tensor)
            z = self.reparameterize(mu, logvar)
            return z, mu, logvar
        else:
            #
            # Linear Bottleneck
            z = self.z(combined_tensor)
            return z
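
    # Decoder / generator: linear projection from the latent code, reshape back to the last
    # conv feature map, then three transposed-conv blocks; the output layer has no activation
    # (the sigmoid is left commented out below).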
    def generate(self, z):
        alt_tensor = self.gen_lin_1(z)
        alt_tensor = self.activation(alt_tensor)
        # alt_tensor = self.gen_lin_2(alt_tensor)
        # alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.reshape_to_last_conv(alt_tensor)
        alt_tensor = self.gen_deconv_1a(alt_tensor)
        alt_tensor = self.gen_deconv_2a(alt_tensor)
        # alt_tensor = self.gen_deconv_3a(alt_tensor)
        # alt_tensor = self.gen_deconv_3b(alt_tensor)
        alt_tensor = self.gen_deconv_out(alt_tensor)
        # alt_tensor = self.activation(alt_tensor)
        # alt_tensor = self.sigmoid(alt_tensor)
        return alt_tensor
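
# Rough usage sketch (an assumption, not verified against the rest of the repo): it assumes
# LightningBaseModule accepts the project's hparams object with `model_param`, `data_param`
# and `train_param` namespaces as accessed above, and that training is driven elsewhere by a
# PyTorch-Lightning-style trainer.
#     model = CNNRouteGeneratorModel(hparams)
#     reconstruction, z, mu, logvar = model(batch_x)   # batch_x: image-like (N, C, H, W) tensor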