Steffen Illium 2a6100296f restructured
2020-03-11 21:58:08 +01:00

337 lines
15 KiB
Python

from random import choice
import torch
from functools import reduce
from operator import mul
from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
class CNNRouteGeneratorModel(LightningBaseModule):
    """Conditional VAE that generates an alternative route tensor from a
    (map, trajectory, label) input triple.

    Encoder: the three input planes are channel-concatenated and pushed
    through alternating Conv/Residual modules, flattened, and projected to
    ``feature_dim`` features, from which a latent Gaussian (mu, logvar) is
    parameterized.  Decoder (``generate``): linear layers expand a latent
    sample back to a single map-sized plane via DeConv modules and a final
    sigmoid.
    """

    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        # Single Adam optimizer over all parameters; lr comes from hparams.
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        # One VAE training step: element-wise reconstruction loss
        # (self.criterion, MSE unless a subclass overrides it) plus the KL
        # divergence of the latent posterior; both are averaged into `loss`.
        batch_x, alternative = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        element_wise_loss = self.criterion(generated_alternative, alternative)
        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
        # kld_loss /= reduce(mul, self.in_shape)
        # NOTE(review): this *multiplies* the KLD by roughly the number of
        # batches per epoch (dataset_length / batch_size) — confirm that
        # up-scaling, rather than normalizing, is the intent here.
        kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
        loss = (kld_loss + element_wise_loss) / 2
        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))

    def _test_val_step(self, batch_xy, batch_nb, *args):
        # Shared body of validation_step/test_step: run the generator and
        # pass its output through.  pred_label is a constant -1 placeholder,
        # since this model has no discriminator.
        batch_x, _ = batch_xy
        map_array, trajectory, label = batch_x
        generated_alternative, z, mu, logvar = self(batch_x)
        return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1)

    def _test_val_epoch_end(self, outputs, test=False):
        # Visual sanity check only: sample random maps/trajectories, run
        # them through the generator and log the resulting figure.  Note
        # that `outputs` is not aggregated here at all.
        maps, trajectories, labels, val_restul_dict = self.generate_random()
        # Imported locally (as in the original) — presumably to avoid a
        # circular import with the visualization package; confirm.
        from lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        return dict(epoch=self.current_epoch)

    def validation_step(self, *args):
        # Lightning hook; delegates to the shared val/test step.
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        # Lightning hook; delegates to the shared val/test epoch end.
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        # Lightning hook; delegates to the shared val/test step.
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        # Lightning hook; delegates with test=True (unused by this class,
        # but honored by the discriminated subclass).
        return self._test_val_epoch_end(outputs, test=True)

    def __init__(self, *params, issubclassed=False):
        """Build dataset, loss and all network modules.

        :param params: forwarded unchanged to LightningBaseModule
            (hparams container etc.).
        :param issubclassed: when True, skip dataset/criterion setup so a
            subclass can install its own afterwards (see
            CNNRouteGeneratorDiscriminated).
        """
        super(CNNRouteGeneratorModel, self).__init__(*params)
        if not issubclassed:
            # Dataset
            self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays',
                                    length=self.hparams.data_param.dataset_length, normalized=True)
            self.criterion = nn.MSELoss()

        # Additional Attributes
        # NOTE(review): in the subclassed path self.dataset is not assigned
        # above, yet it is read here — presumably the base class or some
        # earlier machinery provides it; verify the initialization order.
        self.in_shape = self.dataset.map_shapes_max
        # Todo: Better naming and size in Parameters
        self.feature_dim = self.hparams.model_param.lat_dim * 10

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        #
        # Map Encoder: Conv -> Residual -> Conv ... with growing kernels
        # (3, 5, 7, 11) and the filter counts taken from hparams.
        self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                     conv_filters=self.hparams.model_param.filters[0],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)

        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[1],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)

        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[2],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
                                        conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)

        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=11, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[2],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.map_flat = Flatten(self.map_conv_3.shape)
        self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)

        #
        # Mixed Encoder
        self.mixed_lin = nn.Linear(self.feature_dim, self.feature_dim)
        # Falls back to an identity callable when norm is disabled.
        self.mixed_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x

        #
        # Variational Bottleneck
        self.mu = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)
        self.logvar = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)

        #
        # Alternative Generator
        self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)

        # Todo Fix This Hack!!!!
        # Decoder seed plane: single channel, spatial size copied from the
        # encoder's last conv output.
        reshape_shape = (1, self.map_conv_3.shape[1], self.map_conv_3.shape[2])

        self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, reshape_shape))
        # Flatten used "in reverse" here: maps the flat vector back to the
        # (1, H, W) seed plane.
        self.reshape_to_map = Flatten(reduce(mul, reshape_shape), reshape_shape)

        self.alt_deconv_1 = DeConvModule(reshape_shape, self.hparams.model_param.filters[2],
                                         conv_padding=0, conv_kernel=13, conv_stride=1,
                                         use_norm=self.hparams.model_param.use_norm)
        self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
                                         conv_padding=0, conv_kernel=7, conv_stride=1,
                                         use_norm=self.hparams.model_param.use_norm)
        self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
                                         conv_padding=1, conv_kernel=5, conv_stride=1,
                                         use_norm=self.hparams.model_param.use_norm)
        self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
                                           conv_padding=1, conv_kernel=3, conv_stride=1,
                                           use_norm=self.hparams.model_param.use_norm)

    def forward(self, batch_x):
        """Encode the (map, trajectory, label) batch and decode a sample.

        :return: tuple (alt_tensor, z, mu, logvar).
        """
        #
        # Sorting the Input
        map_array, trajectory, label = batch_x

        #
        # Encode
        z, mu, logvar = self.encode(map_array, trajectory, label)

        #
        # Generate
        alt_tensor = self.generate(z)
        return alt_tensor, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        # Standard VAE reparameterization trick: z = mu + eps * sigma,
        # with eps ~ N(0, I), keeping sampling differentiable w.r.t. mu/logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def generate(self, z):
        """Decode latent sample ``z`` into a sigmoid-activated map plane."""
        alt_tensor = self.alt_lin_1(z)
        alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.alt_lin_2(alt_tensor)
        alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.reshape_to_map(alt_tensor)
        alt_tensor = self.alt_deconv_1(alt_tensor)
        alt_tensor = self.alt_deconv_2(alt_tensor)
        alt_tensor = self.alt_deconv_3(alt_tensor)
        alt_tensor = self.alt_deconv_out(alt_tensor)
        # alt_tensor = self.activation(alt_tensor)
        # Sigmoid keeps outputs in (0, 1) — required downstream by BCELoss
        # in the discriminated subclass.
        alt_tensor = self.sigmoid(alt_tensor)
        return alt_tensor

    def encode(self, map_array, trajectory, label):
        """Encode map + trajectory + broadcast label into (z, mu, logvar)."""
        # Broadcast each scalar label to a full constant plane so it can be
        # concatenated as an extra input channel.
        label_array = torch.cat([torch.full((1, 1, self.in_shape[1], self.in_shape[2]), x.item())
                                 for x in label], dim=0)
        label_array = self._move_to_model_device(label_array)
        combined_tensor = torch.cat((map_array, trajectory, label_array), dim=1)
        combined_tensor = self.map_conv_0(combined_tensor)
        combined_tensor = self.map_res_1(combined_tensor)
        combined_tensor = self.map_conv_1(combined_tensor)
        combined_tensor = self.map_res_2(combined_tensor)
        combined_tensor = self.map_conv_2(combined_tensor)
        combined_tensor = self.map_res_3(combined_tensor)
        combined_tensor = self.map_conv_3(combined_tensor)
        combined_tensor = self.map_flat(combined_tensor)
        combined_tensor = self.map_lin(combined_tensor)

        # NOTE(review): the same mixed_lin/mixed_norm pair is applied twice,
        # i.e. the two passes share one weight matrix — confirm this weight
        # sharing is intended rather than a missing second layer.
        combined_tensor = self.mixed_lin(combined_tensor)
        combined_tensor = self.mixed_norm(combined_tensor)
        combined_tensor = self.activation(combined_tensor)
        combined_tensor = self.mixed_lin(combined_tensor)
        combined_tensor = self.mixed_norm(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        #
        # Parameter and Sampling
        mu = self.mu(combined_tensor)
        logvar = self.logvar(combined_tensor)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate_random(self, n=6):
        """Sample ``n`` random maps/trajectories and run a val step on them.

        Each sample is duplicated (the ``* 2`` below) so the same inputs are
        evaluated once with label 0 and once with label 1, matching the
        ``[0] * n + [1] * n`` label tensor (2n entries total).

        :return: (maps, trajectories, labels, step_result_dict).
        """
        # NOTE(review): self.map_storage is not assigned anywhere in this
        # class — presumably provided by the dataset or base class; verify.
        maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]

        trajectories = [x.get_random_trajectory() for x in maps]
        trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
        trajectories = self._move_to_model_device(torch.stack(trajectories))

        maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
        maps = self._move_to_model_device(torch.stack(maps))

        labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
        # batch_nb=-9999 is a sentinel marking this as a synthetic batch.
        return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
    """Generator variant trained against an externally supplied
    discriminator: the VAE reconstruction loss is replaced by the BCE of
    the discriminator's verdict on (map, trajectory, generated) stacks.

    The discriminator must be injected via ``set_discriminator`` before
    any training/validation step runs.
    """

    name = 'CNNRouteGeneratorDiscriminated'

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        # KLD term as in the parent, but the element-wise reconstruction
        # loss is replaced by the discriminator's BCE on the stacked planes.
        batch_x, label = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        # NOTE(review): batch_x is unpacked into two values here, but the
        # inherited forward() unpacks it into three (map, trajectory,
        # label) — one of the two must be wrong for the 'just_route'
        # dataset mode; confirm the batch layout.
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)
        # BCELoss expects probabilities; generate() ends in a sigmoid.
        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional Resizing
        kld_loss /= reduce(mul, self.in_shape)

        loss = (kld_loss + discriminated_bce_loss) / 2
        return dict(loss=loss, log=dict(loss=loss,
                                        discriminated_bce_loss=discriminated_bce_loss,
                                        kld_loss=kld_loss)
                    )

    def _test_val_step(self, batch_xy, batch_nb, *args):
        # Same generator + discriminator pass as training_step, but returns
        # the raw pieces needed for ROC evaluation at epoch end.
        # NOTE(review): the inherited generate_random() calls this with a
        # three-element batch_x, which the two-value unpack below cannot
        # handle — verify generate_random is compatible with this subclass.
        batch_x, label = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)

        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)

    def validation_step(self, *args):
        # Lightning hook; delegates to the shared val/test step.
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        # Lightning hook; delegates to the shared val/test epoch end.
        return self._test_val_epoch_end(outputs)

    def _test_val_epoch_end(self, outputs, test=False):
        # Aggregate per-batch predictions into a ROC evaluation, log the
        # curve (test runs only) plus a figure of freshly generated samples.
        evaluation = ROCEvaluation(plot_roc=True)
        pred_label = torch.cat([x['pred_label'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()

        # Sci-py call ROC eval call is eval(true_label, prediction)
        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
        if test:
            # self.logger.log_metrics(score_dict)
            # ROCEvaluation drew onto the current pyplot figure; log & clear.
            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
            plt.clf()

        maps, trajectories, labels, val_restul_dict = self.generate_random()
        # Imported locally (as in the original) — presumably to avoid a
        # circular import with the visualization package; confirm.
        from lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

    def test_step(self, *args):
        # Lightning hook; delegates to the shared val/test step.
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        # Lightning hook; test=True additionally logs the ROC curve image.
        return self._test_val_epoch_end(outputs, test=True)

    @property
    def discriminator(self):
        # Guarded access: the discriminator is injected after construction.
        # :raises RuntimeError: if set_discriminator was never called.
        if self._disc is None:
            raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
        return self._disc

    def set_discriminator(self, disc_model):
        """Inject the discriminator model; may only be called once.

        :raises RuntimeError: if a discriminator is already set.
        """
        if self._disc is not None:
            raise RuntimeError('Discriminator has already been set... What are trying to do?')
        self._disc = disc_model

    def __init__(self, *params):
        """Build the generator (via the parent) with BCE loss and the
        'just_route' dataset variant.

        NOTE(review): the parent __init__ (issubclassed=True path) reads
        self.dataset before it is assigned below — verify the base class
        provides a dataset at that point, or this construction order fails.
        """
        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
        self._disc = None
        self.criterion = nn.BCELoss()
        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
                                length=self.hparams.data_param.dataset_length, normalized=True)