train running dataset fixed

This commit is contained in:
steffen
2020-03-05 20:50:07 +01:00
parent 1f25bf599b
commit 05033bed75
11 changed files with 41 additions and 49 deletions

View File

@@ -11,6 +11,7 @@ from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.utils import LightningBaseModule, Flatten
from lib.modules.blocks import ConvModule, ResidualModule
import matplotlib.pyplot as plt
class ConvHomDetector(LightningBaseModule):
@@ -36,10 +37,9 @@ class ConvHomDetector(LightningBaseModule):
predictions = torch.stack([x['prediction'] for x in outputs])
labels = torch.stack([x['label'] for x in outputs])
scores = evaluation(predictions.numpy(), labels.numpy())
self.logger.log_metrics()
scores = evaluation(predictions.numpy(), labels.numpy(), )
self.logger.log_metrics({key:value for key, value in zip(['roc_auc', 'tpr', 'fpr'], scores)})
self.logger.log_image(f'{self.name}', plt.gcf())
pass
def __init__(self, *params):
@@ -88,6 +88,19 @@ class ConvHomDetector(LightningBaseModule):
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
self.out_activation = nn.Sigmoid() # nn.Softmax
def forward(self, x):
tensor = self.map_conv_0(x)
tensor = self.map_res_1(tensor)
tensor = self.map_conv_1(tensor)
tensor = self.map_res_2(tensor)
tensor = self.map_conv_2(tensor)
tensor = self.map_conv_3(tensor)
tensor = self.flatten(tensor)
tensor = self.linear(tensor)
tensor = self.classifier(tensor)
tensor = self.out_activation(tensor)
return tensor
# Dataloaders
# ================================================================================
# Train Dataloader
@@ -107,16 +120,3 @@ class ConvHomDetector(LightningBaseModule):
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
def forward(self, x):
tensor = self.map_conv_0(x)
tensor = self.map_res_1(tensor)
tensor = self.map_conv_1(tensor)
tensor = self.map_res_2(tensor)
tensor = self.map_conv_2(tensor)
tensor = self.map_conv_3(tensor)
tensor = self.flatten(tensor)
tensor = self.linear(tensor)
tensor = self.classifier(tensor)
tensor = self.out_activation(tensor)
return tensor

View File

@@ -45,7 +45,7 @@ class Map(object):
@property
def as_2d_array(self):
return self.map_array[1:]
return self.map_array.squeeze()
def __init__(self, name='', array_like_map_representation=None):
if array_like_map_representation is not None:
@@ -145,9 +145,9 @@ class Map(object):
img = Image.new('L', (self.height, self.width), 0)
draw = ImageDraw.Draw(img)
draw.polygon(polyline, outline=255, fill=255)
draw.polygon(polyline, outline=1, fill=1)
a = (np.asarray(img) * np.where(self.as_2d_array == self.white, 0, 1)).sum()
a = (np.where(np.asarray(img) == self.white, 1, 0) * np.where(self.as_2d_array == self.white, 1, 0)).sum()
if a:
return False # Non-Homotoph
@@ -159,7 +159,7 @@ class Map(object):
# The standard colormaps also all have reversed versions.
# They have the same names with _r tacked on to the end.
# https: // matplotlib.org / api / pyplot_summary.html?highlight = colormaps
img = ax.imshow(self.as_array, cmap='Greys_r')
img = ax.imshow(self.as_2d_array, cmap='Greys_r')
return dict(img=img, fig=fig, ax=ax)

View File

@@ -14,7 +14,7 @@ class Trajectory(object):
@property
def xy_vertices(self):
return [(x,y) for _, x,y in self._vertices]
return [(x, y) for _, y, x in self._vertices]
@property
def endpoints(self):
@@ -30,11 +30,11 @@ class Trajectory(object):
@property
def xs(self):
return [x[1] for x in self._vertices]
return [x[2] for x in self._vertices]
@property
def ys(self):
return [x[0] for x in self._vertices]
return [x[1] for x in self._vertices]
@property
def as_paired_list(self):
@@ -59,7 +59,7 @@ class Trajectory(object):
kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative')
if highlights:
kwargs.update(marker='bo')
kwargs.update(marker='o')
fig, ax = plt.gcf(), plt.gca()
img = plt.plot(self.xs, self.ys, **kwargs)
return dict(img=img, fig=fig, ax=ax)

View File

@@ -76,7 +76,7 @@ class Logger(LightningLoggerBase):
self.neptunelogger.close()
def log_config_as_ini(self):
self.config.write(self.log_dir)
self.config.write(self.log_dir / 'config.ini')
def log_metric(self, metric_name, metric_value, **kwargs):
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
@@ -91,8 +91,8 @@ class Logger(LightningLoggerBase):
self.neptunelogger.save()
def finalize(self, status):
self.testtubelogger.finalize()
self.neptunelogger.finalize()
self.testtubelogger.finalize(status)
self.neptunelogger.finalize(status)
self.log_config_as_ini()
def __enter__(self):