validation written

This commit is contained in:
Si11ium
2020-03-09 17:17:43 +01:00
parent e7ccfb7947
commit 4ae333fe5d
5 changed files with 133 additions and 24 deletions

View File

@@ -102,7 +102,7 @@ class TrajData(object):
def _load_datasets(self):
map_files = list(self.maps_root.glob('*.bmp'))
equal_split = int(self.length // len(map_files))
equal_split = int(self.length // len(map_files)) or 1
# find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))

View File

@@ -1,3 +1,5 @@
from random import choice
import torch
from functools import reduce
from operator import mul
@@ -6,9 +8,12 @@ from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
class CNNRouteGeneratorModel(LightningBaseModule):
@@ -33,14 +38,54 @@ class CNNRouteGeneratorModel(LightningBaseModule):
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing
kld_loss /= self.in_shape
loss = (kld_loss + discriminated_bce_loss) / 2
return dict(loss=loss, log=dict(loss=loss,
discriminated_bce_loss=discriminated_bce_loss,
kld_loss=kld_loss)
)
def test_step(self, *args, **kwargs):
pass
def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
def validation_step(self, *args):
return self._test_val_step(*args)
def validation_epoch_end(self, outputs):
evaluation = ROCEvaluation(plot_roc=True)
predictions = torch.cat([x['prediction'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
losses = torch.cat([x['discriminated_bce_loss'] for x in outputs]).unsqueeze(1)
mean_losses = losses.mean()
# Sci-py call ROC eval call is eval(true_label, prediction)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
plt.clf()
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output_E{self.current_epoch}', fig)
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
def test_step(self, *args):
return self._test_val_step(*args)
@property
def discriminator(self):
@@ -57,12 +102,14 @@ class CNNRouteGeneratorModel(LightningBaseModule):
super(CNNRouteGeneratorModel, self).__init__(*params)
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route')
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
length=self.hparams.train_param.batch_size * 1000)
# Additional Attributes
self.in_shape = self.dataset.map_shapes_max
# Todo: Better naming and size in Parameters
self.feature_dim = 10
self.lat_dim = self.feature_dim + self.feature_dim + 1
self._disc = None
# NN Nodes
@@ -70,6 +117,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
#
# Utils
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
self.criterion = nn.MSELoss()
#
@@ -113,8 +161,8 @@ class CNNRouteGeneratorModel(LightningBaseModule):
#
# Variational Bottleneck
self.mu = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
self.logvar = nn.Linear(self.feature_dim + self.feature_dim + 1, self.hparams.model_param.lat_dim)
self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
self.logvar = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
#
# Alternative Generator
@@ -139,6 +187,32 @@ class CNNRouteGeneratorModel(LightningBaseModule):
#
# Encode
z, mu, logvar = self.encode(map_array, trajectory, label)
#
# Generate
alt_tensor = self.generate(z)
return alt_tensor, z, mu, logvar
@staticmethod
def reparameterize(mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def generate(self, z):
alt_tensor = self.alt_lin_1(z)
alt_tensor = self.alt_lin_2(alt_tensor)
alt_tensor = self.reshape_to_map(alt_tensor)
alt_tensor = self.alt_deconv_1(alt_tensor)
alt_tensor = self.alt_deconv_2(alt_tensor)
alt_tensor = self.alt_deconv_3(alt_tensor)
alt_tensor = self.alt_deconv_out(alt_tensor)
alt_tensor = self.sigmoid(alt_tensor)
return alt_tensor
def encode(self, map_array, trajectory, label):
map_tensor = self.map_conv_0(map_array)
map_tensor = self.map_res_1(map_tensor)
map_tensor = self.map_conv_1(map_tensor)
@@ -157,27 +231,19 @@ class CNNRouteGeneratorModel(LightningBaseModule):
mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
mixed_tensor = self.relu(mixed_tensor)
mixed_tensor = self.mixed_lin(mixed_tensor)
mixed_tensor = self.relu(mixed_tensor)
#
# Parameter and Sampling
mu = self.mu(mixed_tensor)
logvar = self.logvar(mixed_tensor)
z = self.reparameterize(mu, logvar)
return z, mu, logvar
#
# Generate
alt_tensor = self.alt_lin_1(z)
alt_tensor = self.alt_lin_2(alt_tensor)
alt_tensor = self.reshape_to_map(alt_tensor)
alt_tensor = self.alt_deconv_1(alt_tensor)
alt_tensor = self.alt_deconv_2(alt_tensor)
alt_tensor = self.alt_deconv_3(alt_tensor)
alt_tensor = self.alt_deconv_out(alt_tensor)
return alt_tensor, z, mu, logvar
@staticmethod
def reparameterize(mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def generate_random(self, n=6):
maps = [self.map_storage[choice(self.map_storage.keys())] for _ in range(n)]
trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
maps = torch.stack([x.as_2d_array for x in maps] * 2)
labels = torch.as_tensor([0] * n + [1] * n)
return maps, trajectories, labels, self._test_val_step(maps, trajectories, labels)

View File

@@ -146,6 +146,7 @@ class Map(object):
img = Image.new('L', (self.height, self.width), 0)
draw = ImageDraw.Draw(img)
draw.polygon(polyline, outline=self.white, fill=self.white)
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.black, 1, 0)).sum()

View File

@@ -0,0 +1,43 @@
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.axisartist.axes_grid import ImageGrid
from tqdm import tqdm
from typing import List
class GeneratorVisualizer(object):
def __init__(self, maps, trajectories, labels, val_result_dict):
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
self.generated_alternatives = val_result_dict['generated_alternative']
self.pred_labels = val_result_dict['pred_label']
self.labels = labels
self.trajectories = trajectories
self.maps = maps
self.column_dict_list = self._build_column_dict_list()
def _build_column_dict_list(self):
dict_list = []
for idx in range(self.maps):
image = self.maps[idx] + self.trajectories[idx] + self.generated_alternatives
label = self.labels[idx]
dict_list.append(dict(image=image, label=label))
half_size = int(len(dict_list) // 2)
return dict_list[:half_size], dict_list[half_size:]
def draw(self):
fig = plt.figure()
grid = ImageGrid(fig, 111, # similar to subplot(111)
nrows_ncols=(len(self.column_dict_list[0]), len(self.column_dict_list)),
axes_pad=0.2, # pad between axes in inch.
)
for idx in grid.axes_all:
row, col = divmod(idx, len(self.column_dict_list))
current_image = self.column_dict_list[col]['image'][row]
current_label = self.column_dict_list[col]['label'][row]
grid[idx].imshow(current_image)
grid[idx].title.set_text(current_label)
fig.cbar_mode = 'single'
return fig

View File

@@ -33,7 +33,6 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_batchsize", type=int, default=100, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")