from functools import reduce
from operator import mul
from random import choice
from statistics import mean

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.optim import Adam

from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten


class CNNRouteGeneratorModel(LightningBaseModule):
    """Conditional VAE-style generator for alternative route trajectories.

    Encodes a (map, trajectory, label) triple into a latent code, samples via
    the reparameterization trick, and decodes an alternative trajectory map.
    The training loss combines a KL-divergence term with the loss produced by
    an externally attached discriminator (see :meth:`set_discriminator`).
    """

    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        """Single Adam optimizer over all model parameters."""
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """One training step: generate an alternative and score it.

        ``batch_xy`` is ``(batch_x, label)`` where ``batch_x`` is a list of
        ``[map_array, trajectory]`` tensors (each with channel dim 1).
        """
        batch_x, label = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x + [label, ])
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)
        # NOTE(review): self.criterion is nn.MSELoss, so the "bce" in this
        # name is misleading -- confirm which loss is actually intended.
        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional resizing.
        # NOTE(review): this only works if self.in_shape is a scalar; it comes
        # from dataset.map_shapes_max, so if that is a shape tuple this raises
        # at runtime -- confirm (reduce(mul, self.in_shape) may be intended).
        kld_loss /= self.in_shape

        loss = (kld_loss + discriminated_bce_loss) / 2
        return dict(loss=loss,
                    log=dict(loss=loss,
                             discriminated_bce_loss=discriminated_bce_loss,
                             kld_loss=kld_loss)
                    )

    def _test_val_step(self, batch_xy, batch_nb, *args):
        """Shared evaluation step for validation and test (no KLD term)."""
        batch_x, label = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x + [label, ])
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)

        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
                    pred_label=pred_label, label=label,
                    generated_alternative=generated_alternative)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs):
        """Aggregate validation outputs: ROC/AUC, mean loss, sample images."""
        evaluation = ROCEvaluation(plot_roc=True)
        pred_label = torch.cat([x['pred_label'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()

        # Sci-py call ROC eval call is eval(true_label, prediction)
        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
        # self.logger.log_metrics(score_dict)
        self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
        plt.clf()

        maps, trajectories, labels, val_result_dict = self.generate_random()

        # Imported lazily to avoid a matplotlib-heavy import at module load.
        from lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output_E{self.current_epoch}', fig)

        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

    def test_step(self, *args):
        return self._test_val_step(*args)

    @property
    def discriminator(self):
        """The externally attached discriminator model.

        Raises:
            RuntimeError: if :meth:`set_discriminator` was never called.
        """
        if self._disc is None:
            # Fix: error message previously had an unterminated quote.
            raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)"')
        return self._disc

    def set_discriminator(self, disc_model):
        """Attach the discriminator; may only be called once."""
        if self._disc is not None:
            raise RuntimeError('Discriminator has already been set... What are trying to do?')
        self._disc = disc_model

    def __init__(self, *params):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset
        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
                                length=self.hparams.data_param.dataset_length)

        # Additional Attributes
        self.in_shape = self.dataset.map_shapes_max
        # Todo: Better naming and size in Parameters
        self.feature_dim = 10
        # Latent input to the mixed encoder: map features + traj features + label scalar.
        self.lat_dim = self.feature_dim + self.feature_dim + 1
        self._disc = None

        # NN Nodes
        ###################################################
        #
        # Utils
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.MSELoss()

        #
        # Map Encoder
        self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                     conv_filters=self.hparams.model_param.filters[0])
        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[0])
        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[1])
        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[1])
        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[2])
        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[2])
        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                     conv_filters=self.hparams.model_param.filters[2] * 2)
        self.map_flat = Flatten(self.map_conv_3.shape)
        self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)

        #
        # Trajectory Encoder
        self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[0])
        self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[0])
        self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[0])
        self.traj_flat = Flatten(self.traj_conv_3.shape)
        self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)

        #
        # Mixed Encoder
        self.mixed_lin = nn.Linear(self.lat_dim, self.lat_dim)

        #
        # Variational Bottleneck
        self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
        self.logvar = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)

        #
        # Alternative Generator
        self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
        self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, self.traj_conv_3.shape))

        self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)

        self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
                                         conv_padding=0, conv_kernel=5, conv_stride=1)
        self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
                                         conv_padding=0, conv_kernel=3, conv_stride=1)
        self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
                                         conv_padding=1, conv_kernel=3, conv_stride=1)
        self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
                                           conv_padding=1, conv_kernel=3, conv_stride=1)

    def forward(self, batch_x):
        """Full pass: encode (map, trajectory, label), sample, and decode."""
        #
        # Sorting the Input
        map_array, trajectory, label = batch_x

        #
        # Encode
        z, mu, logvar = self.encode(map_array, trajectory, label)

        #
        # Generate
        alt_tensor = self.generate(z)
        return alt_tensor, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def generate(self, z):
        """Decode a latent sample into an alternative trajectory map in [0, 1]."""
        alt_tensor = self.alt_lin_1(z)
        alt_tensor = self.alt_lin_2(alt_tensor)
        alt_tensor = self.reshape_to_map(alt_tensor)
        alt_tensor = self.alt_deconv_1(alt_tensor)
        alt_tensor = self.alt_deconv_2(alt_tensor)
        alt_tensor = self.alt_deconv_3(alt_tensor)
        alt_tensor = self.alt_deconv_out(alt_tensor)
        alt_tensor = self.sigmoid(alt_tensor)
        return alt_tensor

    def encode(self, map_array, trajectory, label):
        """Encode inputs to ``(z, mu, logvar)`` of the variational bottleneck."""
        map_tensor = self.map_conv_0(map_array)
        map_tensor = self.map_res_1(map_tensor)
        map_tensor = self.map_conv_1(map_tensor)
        map_tensor = self.map_res_2(map_tensor)
        map_tensor = self.map_conv_2(map_tensor)
        map_tensor = self.map_res_3(map_tensor)
        map_tensor = self.map_conv_3(map_tensor)
        map_tensor = self.map_flat(map_tensor)
        map_tensor = self.map_lin(map_tensor)

        traj_tensor = self.traj_conv_1(trajectory)
        traj_tensor = self.traj_conv_2(traj_tensor)
        traj_tensor = self.traj_conv_3(traj_tensor)
        traj_tensor = self.traj_flat(traj_tensor)
        traj_tensor = self.traj_lin(traj_tensor)

        # Label is concatenated as an extra scalar feature per sample.
        mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
        mixed_tensor = self.relu(mixed_tensor)
        mixed_tensor = self.mixed_lin(mixed_tensor)
        mixed_tensor = self.relu(mixed_tensor)

        #
        # Parameter and Sampling
        mu = self.mu(mixed_tensor)
        logvar = self.logvar(mixed_tensor)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate_random(self, n=6):
        """Build ``2 * n`` random (map, trajectory) samples and evaluate them.

        The first ``n`` samples are labeled 0, the second ``n`` labeled 1.
        Returns ``(maps, trajectories, labels, eval_result_dict)``.
        """
        # NOTE(review): self.map_storage is not set in __init__; assumed to be
        # provided by the base class or dataset, with .keys an indexable
        # sequence usable by random.choice -- confirm.
        maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)]

        trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
        maps = torch.stack([x.as_2d_array for x in maps] * 2)
        labels = torch.as_tensor([0] * n + [1] * n)
        # Fix: _test_val_step expects (batch_x, label) with batch_x the list
        # [maps, trajectories]; the previous positional call
        # _test_val_step(maps, trajectories, labels) bound `maps` to batch_xy
        # and failed on unpacking. Use -1 as a sentinel batch index.
        return maps, trajectories, labels, self._test_val_step(([maps, trajectories], labels), -1)