117 lines
4.5 KiB
Python
117 lines
4.5 KiB
Python
from random import choices, seed
|
|
import numpy as np
|
|
|
|
import torch
|
|
from functools import reduce
|
|
from operator import mul
|
|
|
|
from torch import nn
|
|
from torch.optim import Adam
|
|
|
|
from datasets.trajectory_dataset import TrajData
|
|
from lib.evaluation.classification import ROCEvaluation
|
|
from lib.models.generators.cnn import CNNRouteGeneratorModel
|
|
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
|
|
from lib.modules.utils import LightningBaseModule, Flatten
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
    """Route generator whose output is scored by an externally supplied discriminator.

    The inherited forward pass (``self(batch_x)``) proposes an alternative route.
    A separate discriminator model must be attached with ``set_discriminator``.
    That model scores the concatenation ``(map_array, trajectory,
    generated_alternative)``, and its prediction is compared to the batch label
    with ``nn.BCELoss``. During training, a normalised KL-divergence term on the
    latent distribution is averaged with that BCE loss.

    NOTE(review): ``__init__`` currently raises ``NotImplementedError``, so the
    class cannot be instantiated as written.
    """

    name = 'CNNRouteGeneratorDiscriminated'

    # ------------------------------------------------------------------ helpers

    def _generate_and_discriminate(self, batch_x):
        """Generate an alternative route for ``batch_x`` and score it with the discriminator.

        Args:
            batch_x: pair ``(map_array, trajectory)``. It is passed unchanged to
                the forward pass and then unpacked.

        Returns:
            tuple ``(generated_alternative, mu, logvar, pred_label)``.
        """
        # The forward pass also returns the latent sample ``z``, which is not needed here.
        generated_alternative, _z, mu, logvar = self(batch_x)
        map_array, trajectory = batch_x
        # Stack inputs and generated route along dim 1 (presumably channels — TODO confirm).
        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)
        return generated_alternative, mu, logvar, pred_label

    def _discriminator_bce(self, pred_label, label):
        """Return ``self.criterion`` between the discriminator output and the batch labels."""
        # Labels get a trailing singleton axis, presumably to match a (batch, 1) prediction.
        return self.criterion(pred_label, label.float().unsqueeze(-1))

    def _normalized_kld(self, mu, logvar):
        """Return the KL divergence to a standard normal, divided by the input size.

        See Appendix B of Kingma and Welling, "Auto-Encoding Variational Bayes",
        ICLR 2014 (https://arxiv.org/abs/1312.6114):
            KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        The sum runs over every element of ``mu``/``logvar``, then the result is
        divided by the product of ``self.in_shape``.
        """
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Out-of-place division: avoids mutating a tensor that is part of the autograd graph.
        return kld / reduce(mul, self.in_shape)

    # ----------------------------------------------------------------- training

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        """Compute the training loss: the mean of the normalised KLD and the discriminator BCE.

        Args:
            batch_xy: pair ``(batch_x, label)`` where ``batch_x`` is
                ``(map_array, trajectory)``.
            batch_nb: batch index (unused).

        Returns:
            dict with ``loss`` and a ``log`` sub-dict containing ``loss``,
            ``discriminated_bce_loss`` and ``kld_loss``.
        """
        batch_x, label = batch_xy
        _generated, mu, logvar, pred_label = self._generate_and_discriminate(batch_x)
        discriminated_bce_loss = self._discriminator_bce(pred_label, label)
        kld_loss = self._normalized_kld(mu, logvar)

        loss = (kld_loss + discriminated_bce_loss) / 2
        return dict(loss=loss, log=dict(loss=loss,
                                        discriminated_bce_loss=discriminated_bce_loss,
                                        kld_loss=kld_loss)
                    )

    # ------------------------------------------------------ validation / test

    def _test_val_step(self, batch_xy, batch_nb, *args):
        """Shared validation/test step: discriminator BCE plus the raw predictions.

        Returns:
            dict with ``discriminated_bce_loss``, ``batch_nb``, ``pred_label``,
            ``label`` and ``generated_alternative``.
        """
        batch_x, label = batch_xy
        generated_alternative, _mu, _logvar, pred_label = self._generate_and_discriminate(batch_x)
        discriminated_bce_loss = self._discriminator_bce(pred_label, label)
        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def _test_val_epoch_end(self, outputs, test=False):
        """Aggregate step outputs into a mean loss and ROC-AUC; on test, also log figures.

        Args:
            outputs: list of dicts produced by ``_test_val_step``.
            test: when True, log the ROC curve and a visualisation of routes from
                ``self.generate_random()`` through ``self.logger.log_image``.

        Returns:
            dict with ``mean_losses``, ``roc_auc`` and ``epoch``.
        """
        evaluation = ROCEvaluation(plot_roc=True)
        pred_label = torch.cat([x['pred_label'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()

        # The evaluator is called as eval(true_label, prediction).
        roc_auc, _tpr, _fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy())
        if test:
            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
        # Clear the ROC plot on every path. ROCEvaluation(plot_roc=True) presumably
        # draws on the current pyplot figure; previously only the test path cleared
        # it, so curves from validation epochs piled up on the same figure.
        plt.clf()

        if test:
            maps, trajectories, generated_labels, val_result_dict = self.generate_random()

            # Local import, as in the original (possibly to avoid an import cycle).
            from lib.visualization.generator_eval import GeneratorVisualizer
            visualizer = GeneratorVisualizer(maps, trajectories, generated_labels, val_result_dict)
            fig = visualizer.draw()
            self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
            plt.clf()

        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    # ------------------------------------------------------------ discriminator

    @property
    def discriminator(self):
        """The attached discriminator model.

        Raises:
            RuntimeError: if ``set_discriminator`` has not been called yet.
        """
        if self._disc is None:
            raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
        return self._disc

    def set_discriminator(self, disc_model):
        """Attach ``disc_model`` as the discriminator; it can only be set once.

        Raises:
            RuntimeError: if a discriminator has already been set.
        """
        if self._disc is not None:
            raise RuntimeError('Discriminator has already been set... What are trying to do?')
        self._disc = disc_model

    def __init__(self, *params):
        # NOTE(review): this raise makes the rest of __init__ unreachable, so the class
        # cannot currently be instantiated. The code below is the setup that would run
        # if the raise were removed.
        raise NotImplementedError

        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)

        # No discriminator until set_discriminator() is called.
        self._disc = None

        self.criterion = nn.BCELoss()

        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
                                length=self.hparams.data_param.dataset_length, normalized=True)