import torch
from functools import reduce
from operator import mul
from torch import nn

from datasets.trajectory_dataset import TrajData
from ml_lib.evaluation.classification import ROCEvaluation
from models.generators.cnn import CNNRouteGeneratorModel

import matplotlib.pyplot as plt


class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
    """CNN route generator whose generated routes are scored by an attached discriminator.

    A discriminator must be attached via ``set_discriminator`` before any step runs.
    Training minimises the mean of the discriminator's BCE loss and the (size-normalised)
    VAE KL-divergence term; validation/test report the BCE loss plus a ROC evaluation.
    """

    name = 'CNNRouteGeneratorDiscriminated'

    def __init__(self, *params):
        # NOTE(review): construction is deliberately disabled here; everything below this
        # raise is unreachable until it is removed.
        raise NotImplementedError
        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)

        self._disc = None
        self.criterion = nn.BCELoss()
        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
                                length=self.hparams.data_param.dataset_length, normalized=True)

    # ------------------------------------------------------------------ discriminator

    @property
    def discriminator(self):
        """Return the attached discriminator model.

        Raises:
            RuntimeError: if no discriminator has been attached yet.
        """
        # getattr with a default: ``_disc`` may not exist at all if ``__init__`` did not complete.
        disc = getattr(self, '_disc', None)
        if disc is None:
            raise RuntimeError('Set the Discriminator first: "set_discriminator(disc_model)"')
        return disc

    def set_discriminator(self, disc_model):
        """Attach ``disc_model`` as the discriminator. May only be called once.

        Raises:
            RuntimeError: if a discriminator has already been attached.
        """
        if getattr(self, '_disc', None) is not None:
            raise RuntimeError('Discriminator has already been set... What are you trying to do?')
        self._disc = disc_model

    # ------------------------------------------------------------------ shared forward pass

    def _discriminate(self, batch_xy):
        """Generate an alternative route for a batch and score it with the discriminator.

        Args:
            batch_xy: ``(batch_x, label)`` where ``batch_x`` unpacks to ``(map_array, trajectory)``.

        Returns:
            Tuple ``(generated_alternative, mu, logvar, pred_label, label, discriminated_bce_loss)``.
        """
        batch_x, label = batch_xy
        generated_alternative, _z, mu, logvar = self(batch_x)
        map_array, trajectory = batch_x

        # The discriminator sees map, original trajectory and generated alternative concatenated
        # along dim 1 (presumably the channel axis -- TODO confirm against the data layout).
        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)
        # BCELoss expects targets as floats with the same shape as pred_label.
        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
        return generated_alternative, mu, logvar, pred_label, label, discriminated_bce_loss

    # ------------------------------------------------------------------ training

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the training loss: mean of discriminator BCE loss and normalised KL divergence."""
        _, mu, logvar, _, _, discriminated_bce_loss = self._discriminate(batch_xy)

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional resizing: normalise by the number of elements described by self.in_shape.
        kld_loss /= reduce(mul, self.in_shape)

        loss = (kld_loss + discriminated_bce_loss) / 2
        return dict(loss=loss,
                    log=dict(loss=loss,
                             discriminated_bce_loss=discriminated_bce_loss,
                             kld_loss=kld_loss)
                    )

    # ------------------------------------------------------------------ validation / test

    def _test_val_step(self, batch_xy, batch_nb, *args):
        """Run one evaluation step and return the tensors needed at epoch end."""
        generated_alternative, _, _, pred_label, label, discriminated_bce_loss = self._discriminate(batch_xy)
        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def _test_val_epoch_end(self, outputs, test=False):
        """Aggregate step outputs into mean BCE loss and ROC-AUC.

        When ``test`` is True, additionally logs the ROC curve and a visualisation of
        randomly generated routes through ``self.logger``.
        """
        evaluation = ROCEvaluation(plot_roc=True)
        pred_label = torch.cat([x['pred_label'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()

        # ROCEvaluation argument order is (true_label, prediction).
        roc_auc, _tpr, _fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy())

        if test:
            # The ROC plot is drawn onto the current pyplot figure, so log that figure.
            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
        # Clear after every evaluation (not only at test time); otherwise curves from each
        # validation epoch accumulate on the shared figure and leak into later plots.
        plt.clf()

        if test:
            maps, trajectories, gen_labels, val_result_dict = self.generate_random()

            from generator_eval import GeneratorVisualizer
            g = GeneratorVisualizer(maps, trajectories, gen_labels, val_result_dict)
            fig = g.draw()
            self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
            plt.clf()

        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)