# hom_traj_gen/lib/models/generators/cnn_discriminated.py
# 2020-03-13 21:52:33 +01:00
#
# 117 lines
# 4.5 KiB
# Python

from random import choices, seed
import numpy as np
import torch
from functools import reduce
from operator import mul
from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
    """Route generator whose output is scored by an externally supplied discriminator.

    The generated alternative is concatenated with the map and the input
    trajectory along ``dim=1`` and passed to the discriminator; the binary
    cross-entropy between the discriminator output and the batch label is
    combined with the KL divergence of the latent distribution to form the
    training loss.

    The discriminator is not created here. It has to be attached once via
    :meth:`set_discriminator` before any step method is called.
    """

    name = 'CNNRouteGeneratorDiscriminated'

    def __init__(self, *params):
        # NOTE(review): the constructor raises unconditionally, so everything
        # below this line is currently unreachable. Presumably the class has
        # been disabled on purpose -- confirm before removing the raise.
        raise NotImplementedError
        super().__init__(*params, issubclassed=True)

        # Filled in later through ``set_discriminator``.
        self._disc = None

        self.criterion = nn.BCELoss()

        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
                                length=self.hparams.data_param.dataset_length, normalized=True)

    @property
    def discriminator(self):
        """Return the attached discriminator.

        Raises:
            RuntimeError: if no discriminator has been set yet.
        """
        if self._disc is None:
            raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)"')
        return self._disc

    def set_discriminator(self, disc_model):
        """Attach ``disc_model`` as the discriminator; allowed only once.

        Raises:
            RuntimeError: if a discriminator has already been attached.
        """
        if self._disc is not None:
            raise RuntimeError('Discriminator has already been set... What are you trying to do?')
        self._disc = disc_model

    def _discriminate(self, batch_xy):
        """Run generator and discriminator on one batch.

        Shared by the training and the validation/test step so both use the
        exact same forward pipeline.

        Args:
            batch_xy: pair ``(batch_x, label)`` where ``batch_x`` itself
                unpacks into ``(map_array, trajectory)``.

        Returns:
            Tuple ``(bce_loss, pred_label, label, generated_alternative, mu, logvar)``.
        """
        batch_x, label = batch_xy
        generated_alternative, _z, mu, logvar = self(batch_x)
        map_array, trajectory = batch_x

        # The discriminator sees map, original trajectory and generated
        # alternative stacked along dim 1.
        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)

        # The label gets a trailing dimension before it is compared with the
        # prediction.
        bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
        return bce_loss, pred_label, label, generated_alternative, mu, logvar

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the training loss for one batch.

        Returns:
            dict with the total ``loss`` and a ``log`` dict holding the total
            loss and its two components.
        """
        discriminated_bce_loss, _pred, _label, _generated, mu, logvar = self._discriminate(batch_xy)

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        # Dimensional resizing: scale by the number of elements in ``in_shape``.
        kld_loss = kld_loss / reduce(mul, self.in_shape)

        # Both terms contribute equally.
        loss = (kld_loss + discriminated_bce_loss) / 2
        return dict(loss=loss,
                    log=dict(loss=loss,
                             discriminated_bce_loss=discriminated_bce_loss,
                             kld_loss=kld_loss)
                    )

    def _test_val_step(self, batch_xy, batch_nb, *args):
        """Evaluate one batch; shared by validation and test.

        Returns:
            dict with the BCE loss, the batch number, predicted and true
            labels and the generated alternative.
        """
        discriminated_bce_loss, pred_label, label, generated_alternative, _mu, _logvar = \
            self._discriminate(batch_xy)
        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)

    def validation_step(self, *args):
        """Validation step; delegates to :meth:`_test_val_step`."""
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        """Aggregate the validation step outputs; delegates to :meth:`_test_val_epoch_end`."""
        return self._test_val_epoch_end(outputs)

    def _test_val_epoch_end(self, outputs, test=False):
        """Aggregate step outputs into mean loss and ROC AUC.

        Args:
            outputs: list of dicts as returned by :meth:`_test_val_step`.
            test: when true, additionally log the ROC curve figure and a
                visualisation of randomly generated outputs.

        Returns:
            dict with ``mean_losses``, ``roc_auc`` and the current ``epoch``.
        """
        evaluation = ROCEvaluation(plot_roc=True)
        pred_label = torch.cat([x['pred_label'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()

        # Sci-py call ROC eval call is eval(true_label, prediction)
        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )

        # NOTE(review): the figure is only cleared in the ``test`` branch; if
        # ROCEvaluation draws onto the current pyplot figure, validation runs
        # may accumulate curves -- confirm against ROCEvaluation.
        if test:
            # self.logger.log_metrics(score_dict)
            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
            plt.clf()

            # Distinct names so the ground-truth ``labels`` above is not shadowed.
            maps, trajectories, generated_labels, val_result_dict = self.generate_random()

            # Imported locally, as in the original code (presumably to avoid
            # an import cycle -- confirm).
            from lib.visualization.generator_eval import GeneratorVisualizer
            g = GeneratorVisualizer(maps, trajectories, generated_labels, val_result_dict)
            fig = g.draw()
            self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
            plt.clf()

        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

    def test_step(self, *args):
        """Test step; delegates to :meth:`_test_val_step`."""
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        """Aggregate the test step outputs and log the evaluation figures."""
        return self._test_val_epoch_end(outputs, test=True)