Debugging: validation and testing

This commit is contained in:
Si11ium
2020-03-09 19:18:22 +01:00
parent 4ae333fe5d
commit 6b9696c98e
14 changed files with 28 additions and 116 deletions

View File

@@ -1,3 +1,5 @@
from statistics import mean
from random import choice
import torch
@@ -65,13 +67,12 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def validation_epoch_end(self, outputs):
evaluation = ROCEvaluation(plot_roc=True)
predictions = torch.cat([x['prediction'] for x in outputs])
pred_label = torch.cat([x['pred_label'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
losses = torch.cat([x['discriminated_bce_loss'] for x in outputs]).unsqueeze(1)
mean_losses = losses.mean()
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
# SciPy ROC evaluation call signature is eval(true_label, prediction)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
plt.clf()
@@ -103,7 +104,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
length=self.hparams.train_param.batch_size * 1000)
length=self.hparams.data_param.dataset_length)
# Additional Attributes
self.in_shape = self.dataset.map_shapes_max
@@ -159,6 +160,10 @@ class CNNRouteGeneratorModel(LightningBaseModule):
self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
#
# Mixed Encoder
self.mixed_lin = nn.Linear(self.lat_dim, self.lat_dim)
#
# Variational Bottleneck
self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
@@ -242,7 +247,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
return z, mu, logvar
def generate_random(self, n=6):
maps = [self.map_storage[choice(self.map_storage.keys())] for _ in range(n)]
maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)]
trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
maps = torch.stack([x.as_2d_array for x in maps] * 2)
labels = torch.as_tensor([0] * n + [1] * n)

View File

@@ -57,7 +57,8 @@ class ConvHomDetector(LightningBaseModule):
# Model Parameters
self.in_shape = self.dataset.map_shapes_max
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
self.criterion = nn.BCEWithLogitsLoss()
self.criterion = nn.BCELoss()
self.sigmoid = nn.Sigmoid()
# NN Nodes
# ============================
@@ -100,4 +101,5 @@ class ConvHomDetector(LightningBaseModule):
tensor = self.flatten(tensor)
tensor = self.linear(tensor)
tensor = self.classifier(tensor)
tensor = self.sigmoid(tensor)
return tensor

View File

@@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Data loading
# =============================================================================
# Map Object
self.map_storage = MapStorage(self.hparams.data_param.map_root)
self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True)
def size(self):
return self.shape
@@ -143,19 +143,19 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Train Dataloader
def train_dataloader(self):
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
batch_size=self.hparams.data_param.batchsize,
batch_size=self.hparams.train_param.batch_size,
num_workers=self.hparams.data_param.worker)

View File

@@ -167,6 +167,10 @@ class Map(object):
class MapStorage(object):
@property
def keys(self):
return list(self.data.keys())
def __init__(self, map_root, load_all=False):
self.data = dict()
self.map_root = Path(map_root)
@@ -175,11 +179,11 @@ class MapStorage(object):
_ = self[map_file.name]
def __getitem__(self, item):
if item in hasattr(self, item):
return self.__getattribute__(item)
if item in self.data.keys():
return self.data.get(item)
else:
with shelve.open(self.map_root / f'{item}.pik', flag='r') as d:
self.__setattr__(item, d['map']['map'])
current_map = Map().from_image(self.map_root / item)
self.data.__setitem__(item, np.asarray(current_map))
return self[item]