from functools import reduce
from operator import mul
from random import choice

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.optim import Adam

from datasets.trajectory_dataset import TrajData
from ml_lib.evaluation.classification import ROCEvaluation
from ml_lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from ml_lib.modules.utils import LightningBaseModule, Flatten


class CNNRouteGeneratorModel(LightningBaseModule):
    """Convolutional VAE-style generator for alternative routes.

    ``encode`` runs its input through a stack of ``ConvModule`` / ``ResidualModule`` layers and
    two linear layers into a latent Gaussian (``mu``, ``logvar``) and samples ``z`` with the
    reparameterization trick; ``generate`` maps ``z`` back through linear layers and finishes
    with a sigmoid. The training loss is the mean of ``self.criterion`` (MSE for this class)
    between generated and target alternative and the summed KL divergence of the latent code.

    NOTE(review): not runnable end to end from what is visible in this file -- ``generate``
    uses ``reshape_to_last_conv`` / ``gen_deconv_*`` and ``generate_random`` uses
    ``map_storage``, none of which are defined here. Confirm whether they live elsewhere.
    """

    name = 'CNNRouteGenerator'

    def __init__(self, *params, issubclassed=False):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset
        # The layer shapes below are derived from ``self.dataset``, so it must exist before they
        # are built. Subclasses cannot assign it any earlier (``self.hparams`` only exists after
        # the base ``__init__``), so they customise it by overriding ``_build_dataset``.
        self.dataset = self._build_dataset()
        if not issubclassed:
            # Subclasses install their own criterion after ``super().__init__`` returns.
            self.criterion = nn.MSELoss()

        # Additional Attributes #
        #######################################################
        self.map_shape = self.dataset.map_shapes_max
        self.trajectory_features = 4
        self.res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim * 10
        ########################################################

        filters = self.hparams.model_param.filters
        # Options shared by every conv / residual block below.
        conv_opts = dict(use_norm=self.hparams.model_param.use_norm,
                         use_bias=self.hparams.model_param.use_bias)

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        #
        # Map Encoder
        self.enc_conv_0 = ConvModule(self.map_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                     conv_filters=filters[0], **conv_opts)

        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=filters[0], **conv_opts)
        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=filters[1], **conv_opts)
        self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0,
                                      conv_filters=filters[1], **conv_opts)

        self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=filters[1], **conv_opts)
        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                      conv_filters=filters[2], **conv_opts)
        self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                      conv_filters=filters[2], **conv_opts)

        self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
                                        conv_padding=3, conv_filters=filters[2], **conv_opts)
        self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                      conv_filters=filters[2], **conv_opts)
        self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                      conv_filters=filters[2], **conv_opts)

        # Trajectory Encoder
        # NOTE(review): ``env_gru_1`` (perhaps meant ``enc_gru_1``) is built but never called in
        # ``encode``; the name is kept so existing state_dict keys still match.
        self.env_gru_1 = nn.GRU(input_size=self.trajectory_features, hidden_size=self.feature_dim,
                                num_layers=3, batch_first=True)

        self.enc_flat = Flatten(self.enc_conv_3b.shape)
        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)

        #
        # Mixed Encoder
        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
        # nn.Identity instead of a lambda keeps the module picklable; it has no parameters,
        # so the state_dict is the same as before.
        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else nn.Identity()

        #
        # Variational Bottleneck
        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)

        #
        # Alternative Generator
        self.gen_lin_1 = nn.Linear(self.lat_dim, self.feature_dim)
        self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
        # TODO: recurrent decoder. The former placeholder ``nn.GRU(None, None, batch_first=True)``
        # made ``__init__`` raise a TypeError (nn.GRU needs integer sizes) and was never used;
        # add it back once its input/hidden sizes are decided.

    def _build_dataset(self):
        """Create the dataset this model trains on; subclasses override this to change ``mode``."""
        return TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map',
                        length=self.hparams.data_param.dataset_length, normalized=True)

    def configure_optimizers(self):
        """Single Adam optimiser over all registered parameters, lr from ``hparams.train_param.lr``."""
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    @staticmethod
    def _kld_loss(mu, logvar):
        """KL divergence between N(mu, exp(logvar)) and N(0, I), summed over all elements.

        See Appendix B of Kingma and Welling, Auto-Encoding Variational Bayes, ICLR 2014
        (https://arxiv.org/abs/1312.6114): -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
        """
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Return the mean of reconstruction loss and (summed) KLD, plus both terms for logging."""
        batch_x, alternative = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        element_wise_loss = self.criterion(generated_alternative, alternative)

        kld_loss = self._kld_loss(mu, logvar)
        # TODO: Dimensional Resizing -- does this make sense? Sanity check it!
        #  The KLD is a sum while the MSE is a mean, so their scales differ. Candidates:
        # kld_loss /= reduce(mul, self.map_shape)
        # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100

        loss = (kld_loss + element_wise_loss) / 2
        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))

    def _test_val_step(self, batch_xy, batch_nb, *args):
        """Encode and generate for one batch; return inputs and outputs for epoch-end use.

        ``pred_label`` is the constant ``-1`` because this model has no classifier.
        """
        batch_x, alternative = batch_xy
        map_array = batch_x[0]
        trajectory = batch_x[1]
        label = batch_x[2].max()

        # NOTE(review): the whole ``batch_x`` (indexed above as map/trajectory/label) is handed to
        # ``encode``, whose first layer is a ConvModule -- confirm ``batch_x`` is really a tensor
        # here and what ``batch_x[0..2]`` are meant to select.
        z, _, _ = self.encode(batch_x)
        generated_alternative = self.generate(z)

        return dict(map_array=map_array, trajectory=trajectory, batch_nb=batch_nb, label=label,
                    generated_alternative=generated_alternative, pred_label=-1, alternative=alternative)

    def _log_generated_samples(self):
        """Draw samples with ``generate_random`` and log their visualisation as an image."""
        maps, trajectories, labels, val_result_dict = self.generate_random()

        from ml_lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        plt.clf()

    def _test_val_epoch_end(self, outputs, test=False):
        """Log a visualisation of freshly generated samples; ``outputs`` is not used."""
        self._log_generated_samples()
        return dict(epoch=self.current_epoch)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    def forward(self, batch_x):
        """Encode ``batch_x`` and generate from the sampled code.

        Returns ``(alt_tensor, z, mu, logvar)``.
        """
        z, mu, logvar = self.encode(batch_x)
        alt_tensor = self.generate(z)
        return alt_tensor, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, batch_x):
        """Map ``batch_x`` to latent parameters and a sample; returns ``(z, mu, logvar)``."""
        combined_tensor = self.enc_conv_0(batch_x)
        combined_tensor = self.enc_res_1(combined_tensor) if self.res_net else combined_tensor
        combined_tensor = self.enc_conv_1a(combined_tensor)
        combined_tensor = self.enc_conv_1b(combined_tensor)
        combined_tensor = self.enc_res_2(combined_tensor) if self.res_net else combined_tensor
        combined_tensor = self.enc_conv_2a(combined_tensor)
        combined_tensor = self.enc_conv_2b(combined_tensor)
        combined_tensor = self.enc_res_3(combined_tensor) if self.res_net else combined_tensor
        combined_tensor = self.enc_conv_3a(combined_tensor)
        combined_tensor = self.enc_conv_3b(combined_tensor)

        combined_tensor = self.enc_flat(combined_tensor)
        # Each linear layer is applied once, each followed by norm + activation. Previously
        # enc_lin_1 fed straight into enc_lin_2, and enc_lin_2 was then applied a second time.
        combined_tensor = self.enc_lin_1(combined_tensor)
        combined_tensor = self.enc_norm(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        combined_tensor = self.enc_lin_2(combined_tensor)
        # NOTE(review): the same ``enc_norm`` is shared by both linear stages, so with
        # BatchNorm its running statistics mix the two -- confirm this is intended.
        combined_tensor = self.enc_norm(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        #
        # Parameter and Sampling
        mu = self.mu(combined_tensor)
        logvar = self.logvar(combined_tensor)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate(self, z):
        """Decode latent ``z`` into an alternative; the output passes through a sigmoid."""
        alt_tensor = self.gen_lin_1(z)
        alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.gen_lin_2(alt_tensor)
        alt_tensor = self.activation(alt_tensor)
        # NOTE(review): ``reshape_to_last_conv`` and ``gen_deconv_*`` are not defined in this
        # file (``DeConvModule`` is imported but unused) -- this path fails until they exist.
        alt_tensor = self.reshape_to_last_conv(alt_tensor)
        alt_tensor = self.gen_deconv_1a(alt_tensor)
        alt_tensor = self.gen_deconv_1b(alt_tensor)
        alt_tensor = self.gen_deconv_2a(alt_tensor)
        alt_tensor = self.gen_deconv_2b(alt_tensor)
        alt_tensor = self.gen_deconv_3a(alt_tensor)
        alt_tensor = self.gen_deconv_3b(alt_tensor)
        alt_tensor = self.gen_deconv_out(alt_tensor)
        alt_tensor = self.sigmoid(alt_tensor)
        return alt_tensor

    def generate_random(self, n=6):
        """Build ``2 * n`` samples from ``n`` random maps and run ``_test_val_step`` on them.

        Each map and its drawn trajectory appear twice; the first ``n`` get label 0 and the
        last ``n`` label 1. Returns ``(maps, trajectories, labels, step_result_dict)``.
        """
        # NOTE(review): ``self.map_storage`` is not defined in this file (perhaps it is meant to
        # be ``self.dataset.map_storage``?) -- confirm.
        maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]

        trajectories = [x.get_random_trajectory() for x in maps]
        trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
        trajectories = self._move_to_model_device(torch.stack(trajectories))

        maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
        maps = self._move_to_model_device(torch.stack(maps))

        labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
        # NOTE(review): the step receives ``batch_x=(maps, trajectories, labels)`` and a ``None``
        # target; ``CNNRouteGeneratorDiscriminated._test_val_step`` unpacks ``batch_x`` into two
        # values and calls ``label.float()``, so this call fails for that subclass.
        return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)


class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
    """Generator trained against an externally supplied discriminator.

    The generated alternative is concatenated (dim 1) with the map and trajectory and scored by
    ``self.discriminator``; the loss is the mean of that BCE loss against ``label`` and the KLD
    normalised by the number of map elements. Call ``set_discriminator`` before training.
    """

    name = 'CNNRouteGeneratorDiscriminated'

    def __init__(self, *params):
        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)

        self._disc = None

        self.criterion = nn.BCELoss()

    def _build_dataset(self):
        """Use the ``just_route`` dataset mode instead of the base class's mode."""
        return TrajData(self.hparams.data_param.map_root, mode='just_route',
                        length=self.hparams.data_param.dataset_length, normalized=True)

    def _discriminate(self, batch_x, label):
        """Generate for ``batch_x``, score the result and compute the BCE loss against ``label``.

        Returns ``(generated_alternative, mu, logvar, pred_label, discriminated_bce_loss)``.
        """
        generated_alternative, _, mu, logvar = self(batch_x)
        # NOTE(review): ``batch_x`` is a (map, trajectory) pair here but is passed as a whole to
        # ``forward`` -> ``encode`` -> ConvModule above -- confirm that is what ``encode`` expects.
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)
        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
        return generated_alternative, mu, logvar, pred_label, discriminated_bce_loss

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Return the mean of discriminator BCE loss and normalised KLD, plus both for logging."""
        batch_x, label = batch_xy
        _, mu, logvar, _, discriminated_bce_loss = self._discriminate(batch_x, label)

        kld_loss = self._kld_loss(mu, logvar)
        # Dimensional Resizing: divide the summed KLD by the number of map elements
        # (assumes ``map_shape`` is a sequence of ints). ``self.in_shape`` was referenced here
        # but is not defined in this file; ``map_shape`` is the attribute ``__init__`` sets.
        kld_loss = kld_loss / reduce(mul, self.map_shape)

        loss = (kld_loss + discriminated_bce_loss) / 2
        return dict(loss=loss,
                    log=dict(loss=loss, discriminated_bce_loss=discriminated_bce_loss, kld_loss=kld_loss))

    def _test_val_step(self, batch_xy, batch_nb, *args):
        """Score one batch; return loss, predictions, labels and the generated alternative."""
        batch_x, label = batch_xy
        generated_alternative, _, _, pred_label, discriminated_bce_loss = self._discriminate(batch_x, label)
        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)

    def _test_val_epoch_end(self, outputs, test=False):
        """Aggregate ROC-AUC and mean BCE loss; log the ROC curve (test only) and sample images."""
        evaluation = ROCEvaluation(plot_roc=True)
        pred_label = torch.cat([x['pred_label'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()

        # ROCEvaluation takes (true_label, prediction), in that order.
        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy())
        if test:
            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
        # The ROC plot is drawn on the current pyplot figure (the test branch logs ``plt.gcf()``);
        # clear it every epoch so validation curves do not pile up on that figure.
        plt.clf()

        self._log_generated_samples()
        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

    # validation_step / validation_epoch_end / test_step / test_epoch_end are inherited: the base
    # versions dispatch to the ``_test_val_step`` / ``_test_val_epoch_end`` overrides above.

    @property
    def discriminator(self):
        """The discriminator set via ``set_discriminator``; raises RuntimeError if unset."""
        if self._disc is None:
            raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
        return self._disc

    def set_discriminator(self, disc_model):
        """Attach the discriminator once; raises RuntimeError if one is already set.

        NOTE(review): assigning an ``nn.Module`` attribute registers it as a submodule, so its
        weights are included in ``self.parameters()`` and therefore in the Adam optimiser from
        ``configure_optimizers`` if this is called first -- confirm the discriminator is meant
        to be trained together with the generator.
        """
        if self._disc is not None:
            raise RuntimeError('Discriminator has already been set... What are trying to do?')
        self._disc = disc_model