Debugging Validation and testing
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from statistics import mean
|
||||
|
||||
from random import choice
|
||||
|
||||
import torch
|
||||
@@ -65,13 +67,12 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
evaluation = ROCEvaluation(plot_roc=True)
|
||||
predictions = torch.cat([x['prediction'] for x in outputs])
|
||||
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||
losses = torch.cat([x['discriminated_bce_loss'] for x in outputs]).unsqueeze(1)
|
||||
mean_losses = losses.mean()
|
||||
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
|
||||
|
||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
|
||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
||||
# self.logger.log_metrics(score_dict)
|
||||
self.logger.log_image(f'{self.name}_ROC-Curve_E{self.current_epoch}', plt.gcf())
|
||||
plt.clf()
|
||||
@@ -103,7 +104,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
# Dataset
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
|
||||
length=self.hparams.train_param.batch_size * 1000)
|
||||
length=self.hparams.data_param.dataset_length)
|
||||
|
||||
# Additional Attributes
|
||||
self.in_shape = self.dataset.map_shapes_max
|
||||
@@ -159,6 +160,10 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
|
||||
|
||||
#
|
||||
# Mixed Encoder
|
||||
self.mixed_lin = nn.Linear(self.lat_dim, self.lat_dim)
|
||||
|
||||
#
|
||||
# Variational Bottleneck
|
||||
self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
|
||||
@@ -242,7 +247,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
return z, mu, logvar
|
||||
|
||||
def generate_random(self, n=6):
|
||||
maps = [self.map_storage[choice(self.map_storage.keys())] for _ in range(n)]
|
||||
maps = [self.map_storage[choice(self.map_storage.keys)] for _ in range(n)]
|
||||
trajectories = torch.stack([x.get_random_trajectory() for x in maps] * 2)
|
||||
maps = torch.stack([x.as_2d_array for x in maps] * 2)
|
||||
labels = torch.as_tensor([0] * n + [1] * n)
|
||||
|
||||
@@ -57,7 +57,8 @@ class ConvHomDetector(LightningBaseModule):
|
||||
# Model Parameters
|
||||
self.in_shape = self.dataset.map_shapes_max
|
||||
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
||||
self.criterion = nn.BCEWithLogitsLoss()
|
||||
self.criterion = nn.BCELoss()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
# NN Nodes
|
||||
# ============================
|
||||
@@ -100,4 +101,5 @@ class ConvHomDetector(LightningBaseModule):
|
||||
tensor = self.flatten(tensor)
|
||||
tensor = self.linear(tensor)
|
||||
tensor = self.classifier(tensor)
|
||||
tensor = self.sigmoid(tensor)
|
||||
return tensor
|
||||
|
||||
@@ -90,7 +90,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
# Data loading
|
||||
# =============================================================================
|
||||
# Map Object
|
||||
self.map_storage = MapStorage(self.hparams.data_param.map_root)
|
||||
self.map_storage = MapStorage(self.hparams.data_param.map_root, load_all=True)
|
||||
|
||||
def size(self):
|
||||
return self.shape
|
||||
@@ -143,19 +143,19 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
# Train Dataloader
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
batch_size=self.hparams.train_param.batch_size,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
|
||||
|
||||
@@ -167,6 +167,10 @@ class Map(object):
|
||||
|
||||
class MapStorage(object):
|
||||
|
||||
@property
|
||||
def keys(self):
|
||||
return list(self.data.keys())
|
||||
|
||||
def __init__(self, map_root, load_all=False):
|
||||
self.data = dict()
|
||||
self.map_root = Path(map_root)
|
||||
@@ -175,11 +179,11 @@ class MapStorage(object):
|
||||
_ = self[map_file.name]
|
||||
|
||||
def __getitem__(self, item):
|
||||
if item in hasattr(self, item):
|
||||
return self.__getattribute__(item)
|
||||
if item in self.data.keys():
|
||||
return self.data.get(item)
|
||||
else:
|
||||
with shelve.open(self.map_root / f'{item}.pik', flag='r') as d:
|
||||
self.__setattr__(item, d['map']['map'])
|
||||
current_map = Map().from_image(self.map_root / item)
|
||||
self.data.__setitem__(item, np.asarray(current_map))
|
||||
return self[item]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user