from random import choice

import torch
from functools import reduce
from operator import mul
from torch import nn
from torch.optim import Adam

from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten

import matplotlib.pyplot as plt


class CNNRouteGeneratorModel(LightningBaseModule):

    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, alternative = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        element_wise_loss = self.criterion(generated_alternative, alternative)
        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
        # kld_loss /= reduce(mul, self.in_shape)
        kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100

        loss = (kld_loss + element_wise_loss) / 2
        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
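    # Note on the KLD weighting above: dataset_length / batch_size is the
    # number of batches per epoch, so the KLD term is up-weighted by
    # (batches per epoch) * 100. With hypothetical values of
    # dataset_length = 100_000 and batch_size = 100, that is a factor of
    # 1_000 * 100 = 100_000, which is presumably why the TODO above asks
    # for a sanity check.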
    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, _ = batch_xy
        map_array, trajectory, label = batch_x

        generated_alternative, z, mu, logvar = self(batch_x)
        return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1)

    def _test_val_epoch_end(self, outputs, test=False):
        maps, trajectories, labels, val_result_dict = self.generate_random()

        from lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

        return dict(epoch=self.current_epoch)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    def __init__(self, *params, issubclassed=False):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        if not issubclassed:
            # Dataset
            self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays',
                                    length=self.hparams.data_param.dataset_length, normalized=True)
            self.criterion = nn.MSELoss()

        # Additional Attributes
        self.in_shape = self.dataset.map_shapes_max
        # TODO: Better naming and size in parameters
        self.feature_dim = self.hparams.model_param.lat_dim * 10

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        #
        # Map Encoder
        self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                     conv_filters=self.hparams.model_param.filters[0],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)
        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[1],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)
        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[2],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)
        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
                                        conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=11, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[2],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)
        self.map_flat = Flatten(self.map_conv_3.shape)
        self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)

        #
        # Mixed Encoder
        self.mixed_lin = nn.Linear(self.feature_dim, self.feature_dim)
        self.mixed_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x

        #
        # Variational Bottleneck
        self.mu = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)
        self.logvar = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)

        #
        # Alternative Generator
        self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)

        # TODO: Fix this hack!
        reshape_shape = (1, self.map_conv_3.shape[1], self.map_conv_3.shape[2])

        self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, reshape_shape))
        self.reshape_to_map = Flatten(reduce(mul, reshape_shape), reshape_shape)

        self.alt_deconv_1 = DeConvModule(reshape_shape, self.hparams.model_param.filters[2],
                                         conv_padding=0, conv_kernel=13, conv_stride=1,
                                         use_norm=self.hparams.model_param.use_norm)
        self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
                                         conv_padding=0, conv_kernel=7, conv_stride=1,
                                         use_norm=self.hparams.model_param.use_norm)
        self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
                                         conv_padding=1, conv_kernel=5, conv_stride=1,
                                         use_norm=self.hparams.model_param.use_norm)
        self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
                                           conv_padding=1, conv_kernel=3, conv_stride=1,
                                           use_norm=self.hparams.model_param.use_norm)

    def forward(self, batch_x):
        #
        # Sorting the Input
        map_array, trajectory, label = batch_x

        #
        # Encode
        z, mu, logvar = self.encode(map_array, trajectory, label)

        #
        # Generate
        alt_tensor = self.generate(z)
        return alt_tensor, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        # Reparameterization trick: sample z = mu + sigma * eps with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def generate(self, z):
        alt_tensor = self.alt_lin_1(z)
        alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.alt_lin_2(alt_tensor)
        alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.reshape_to_map(alt_tensor)
        alt_tensor = self.alt_deconv_1(alt_tensor)
        alt_tensor = self.alt_deconv_2(alt_tensor)
        alt_tensor = self.alt_deconv_3(alt_tensor)
        alt_tensor = self.alt_deconv_out(alt_tensor)
        # alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.sigmoid(alt_tensor)
        return alt_tensor
    def encode(self, map_array, trajectory, label):
        # Broadcast each scalar label to a constant, map-sized channel so it
        # can be stacked with the map and trajectory arrays.
        label_array = torch.cat([torch.full((1, 1, self.in_shape[1], self.in_shape[2]), x.item())
                                 for x in label], dim=0)
        label_array = self._move_to_model_device(label_array)
        combined_tensor = torch.cat((map_array, trajectory, label_array), dim=1)

        combined_tensor = self.map_conv_0(combined_tensor)
        combined_tensor = self.map_res_1(combined_tensor)
        combined_tensor = self.map_conv_1(combined_tensor)
        combined_tensor = self.map_res_2(combined_tensor)
        combined_tensor = self.map_conv_2(combined_tensor)
        combined_tensor = self.map_res_3(combined_tensor)
        combined_tensor = self.map_conv_3(combined_tensor)
        combined_tensor = self.map_flat(combined_tensor)
        combined_tensor = self.map_lin(combined_tensor)

        # The mixed encoder block is applied twice with shared weights.
        combined_tensor = self.mixed_lin(combined_tensor)
        combined_tensor = self.mixed_norm(combined_tensor)
        combined_tensor = self.activation(combined_tensor)
        combined_tensor = self.mixed_lin(combined_tensor)
        combined_tensor = self.mixed_norm(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        #
        # Parameter and Sampling
        mu = self.mu(combined_tensor)
        logvar = self.logvar(combined_tensor)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate_random(self, n=6):
        # Draw n random maps and one random trajectory per map; the lists are
        # duplicated so every (map, trajectory) pair appears once per label.
        maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]

        trajectories = [x.get_random_trajectory() for x in maps]
        trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
        trajectories = self._move_to_model_device(torch.stack(trajectories))

        maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
        maps = self._move_to_model_device(torch.stack(maps))

        labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
        return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)
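
# The variant below keeps the VAE encoder/decoder but replaces the MSE
# reconstruction loss with a BCE loss on the prediction of an external
# discriminator network, which must be attached via set_discriminator()
# before training. It also switches the dataset to mode='just_route'.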
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):

    name = 'CNNRouteGeneratorDiscriminated'

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, label = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)
        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional resizing
        kld_loss /= reduce(mul, self.in_shape)

        loss = (kld_loss + discriminated_bce_loss) / 2
        return dict(loss=loss, log=dict(loss=loss,
                                        discriminated_bce_loss=discriminated_bce_loss,
                                        kld_loss=kld_loss)
                    )

    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, label = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)

        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def _test_val_epoch_end(self, outputs, test=False):
        evaluation = ROCEvaluation(plot_roc=True)
        pred_label = torch.cat([x['pred_label'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()

        # The ROC evaluation is called as evaluation(true_label, prediction)
        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy())
        if test:
            # self.logger.log_metrics(score_dict)
            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
            plt.clf()

        maps, trajectories, labels, val_result_dict = self.generate_random()

        from lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    @property
    def discriminator(self):
        if self._disc is None:
            raise RuntimeError('Set the discriminator first: "set_discriminator(disc_model)"')
        return self._disc

    def set_discriminator(self, disc_model):
        if self._disc is not None:
            raise RuntimeError('Discriminator has already been set... What are you trying to do?')
        self._disc = disc_model

    def __init__(self, *params):
        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)

        self._disc = None
        self.criterion = nn.BCELoss()
        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
                                length=self.hparams.data_param.dataset_length, normalized=True)
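
if __name__ == '__main__':
    # Illustrative, self-contained sanity check of the VAE math used above
    # (the reparameterization trick and the KLD term). It only needs torch;
    # building the full models requires project-specific hparams and
    # datasets, so none are instantiated here.
    batch_size, lat_dim = 4, 16
    mu = torch.zeros(batch_size, lat_dim)
    logvar = torch.zeros(batch_size, lat_dim)  # log(sigma^2) = 0, i.e. sigma = 1

    z = CNNRouteGeneratorModel.reparameterize(mu, logvar)
    assert z.shape == (batch_size, lat_dim)

    # For mu = 0 and sigma = 1 the KLD term is exactly zero.
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    assert kld_loss.item() == 0.0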