from functools import reduce
from operator import mul

from random import choices, choice

import torch

from torch import nn
from torch.optim import Adam
from torchvision.datasets import MNIST

from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten

import matplotlib.pyplot as plt
import lib.variables as V
from lib.visualization.generator_eval import GeneratorVisualizer


class CNNRouteGeneratorModel(LightningBaseModule):
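    """Convolutional (variational) autoencoder that generates alternative route tensors.

    The bottleneck and the loss depend on ``hparams.data_param.mode``: in 'vae'
    mode a variational bottleneck with a KL-divergence term is used, otherwise
    a plain linear bottleneck with an MSE reconstruction loss.
    """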

    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
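        """Return the reconstruction loss, plus a KL-divergence term in 'vae' mode."""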
        batch_x, target = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        # In autoencoder mode the network reconstructs its input, so the input is the target.
        target = batch_x if 'ae' in self.hparams.data_param.mode else target
        element_wise_loss = self.criterion(generated_alternative, target)

        if 'vae' in self.hparams.data_param.mode:
            # See Appendix B of the VAE paper:
            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
            # https://arxiv.org/abs/1312.6114
            # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            # Dimensional resizing  TODO: Does this make sense? Sanity-check it!
            # kld_loss /= reduce(mul, self.in_shape)
            kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size

            loss = kld_loss + element_wise_loss
        else:
            loss = element_wise_loss
            kld_loss = 0
        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))

    def _test_val_step(self, batch_xy, batch_nb, *args):
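        """Encode a batch, decode from mu, and collect tensors plus mode-dependent labels for evaluation."""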
        batch_x, _ = batch_xy
        if 'vae' in self.hparams.data_param.mode:
            z, mu, logvar = self.encode(batch_x)
        else:
            z = self.encode(batch_x)
            mu, logvar = z, z

        generated_alternative = self.generate(mu)
        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative, z=z, mu=mu, logvar=logvar)

        if 'hom' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.HOMOTOPIC)
        elif 'alt' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.ALTERNATIVE)
        elif 'vae' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.ANY)
        elif 'ae' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.ANY)
        else:
            labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values

        return_dict.update(labels=self._move_to_model_device(labels))
        return return_dict

    def _test_val_epoch_end(self, outputs, test=False):
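        """Render input/output and latent-space figures for one randomly chosen batch and log them."""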
        plt.close('all')

        g = GeneratorVisualizer(choice(outputs))
        fig = g.draw_io_bundle()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        plt.clf()

        fig = g.draw_latent()
        self.logger.log_image(f'{self.name}_Latent', fig, step=self.global_step)
        plt.clf()

        return dict(epoch=self.current_epoch)

    def on_epoch_start(self):
        # self.dataset.seed(self.logger.version)
        # torch.random.manual_seed(self.logger.version)
        # np.random.seed(self.logger.version)
        pass

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    def __init__(self, *params, issubclassed=False):
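        """Set up the dataset, the loss criterion and all encoder / bottleneck / generator modules."""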
        super(CNNRouteGeneratorModel, self).__init__(*params)

        if False:
            # Trajectory dataset, currently disabled in favour of the MNIST stand-in below.
            self.dataset = TrajData(self.hparams.data_param.map_root,
                                    mode=self.hparams.data_param.mode,
                                    preprocessed=self.hparams.data_param.use_preprocessed,
                                    length=self.hparams.data_param.dataset_length, normalized=True)
        self.criterion = nn.MSELoss()

        self.dataset = MyMNIST()

        # Additional Attributes #
        #######################################################
        self.in_shape = self.dataset.map_shapes_max
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]
        ########################################################

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()

        #
        # Map Encoder
        self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                     conv_filters=self.hparams.model_param.filters[0],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[1],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        last_conv_shape = self.enc_conv_3a.shape
        self.enc_flat = Flatten(last_conv_shape)
        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)

        #
        # Mixed Encoder
        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x

        #
        # Variational Bottleneck
        if 'vae' in self.hparams.data_param.mode:
            self.mu = nn.Linear(self.feature_dim, self.lat_dim)
            self.logvar = nn.Linear(self.feature_dim, self.lat_dim)

        #
        # Linear Bottleneck
        else:
            self.z = nn.Linear(self.feature_dim, self.lat_dim)

        #
        # Alternative Generator
        self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)

        # self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)

        self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)

        self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
                                          conv_padding=0, conv_kernel=7, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)

        self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
                                          conv_padding=1, conv_kernel=5, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)

        self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
                                          conv_padding=0, conv_kernel=3, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)

        self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
                                           conv_padding=0, conv_kernel=3, conv_stride=1,
                                           use_norm=self.hparams.model_param.use_norm)

    def forward(self, batch_x):
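        """Encode the input, then decode the latent code; returns (reconstruction, z, mu, logvar)."""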
        #
        # Encode
        if 'vae' in self.hparams.data_param.mode:
            z, mu, logvar = self.encode(batch_x)
        else:
            z = self.encode(batch_x)
            mu, logvar = z, z

        #
        # Generate
        alt_tensor = self.generate(z)
        return alt_tensor, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
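        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I) and
        # sigma = exp(0.5 * logvar), which keeps the sampling step differentiable.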
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, batch_x):
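        """Map a batch to the latent space.

        Returns (z, mu, logvar) in 'vae' mode, otherwise only z.
        Note: enc_res_1, enc_res_2 and enc_norm are built in __init__ but are
        not applied in this pass.
        """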
        combined_tensor = self.enc_conv_0(batch_x)
        combined_tensor = self.enc_conv_1a(combined_tensor)
        combined_tensor = self.enc_conv_2a(combined_tensor)
        combined_tensor = self.enc_conv_3a(combined_tensor)

        combined_tensor = self.enc_flat(combined_tensor)
        combined_tensor = self.enc_lin_1(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        combined_tensor = self.enc_lin_2(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        #
        # Variational
        # Parameter and Sampling
        if 'vae' in self.hparams.data_param.mode:
            mu = self.mu(combined_tensor)
            logvar = self.logvar(combined_tensor)
            z = self.reparameterize(mu, logvar)
            return z, mu, logvar
        else:
            #
            # Linear Bottleneck
            z = self.z(combined_tensor)
            return z

    def generate(self, z):
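        """Decode a latent vector back to an image-shaped tensor through the deconvolution stack."""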
        alt_tensor = self.gen_lin_1(z)
        alt_tensor = self.activation(alt_tensor)

        alt_tensor = self.reshape_to_last_conv(alt_tensor)
        alt_tensor = self.gen_deconv_1a(alt_tensor)

        alt_tensor = self.gen_deconv_2a(alt_tensor)

        alt_tensor = self.gen_deconv_3a(alt_tensor)

        alt_tensor = self.gen_deconv_out(alt_tensor)
        # alt_tensor = self.activation(alt_tensor)
        # alt_tensor = self.sigmoid(alt_tensor)
        return alt_tensor