eval written

This commit is contained in:
Si11ium
2020-03-05 16:58:23 +01:00
parent 8d06c179c9
commit 1f25bf599b
12 changed files with 127 additions and 74 deletions

View File

@ -1,29 +1,24 @@
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
class ROCEvaluation(object):
BINARY_PROBLEM = 2
linewidth = 2
def __init__(self, save_fig=True):
def __init__(self, prepare_figure=False):
self.prepare_figure = prepare_figure
self.epoch = 0
pass
def __call__(self, prediction, label, prepare_fig=True):
def __call__(self, prediction, label, plotting=False):
# Compute ROC curve and ROC area
fpr, tpr, _ = roc_curve(prediction, label)
roc_auc = auc(fpr, tpr)
if prepare_fig:
fig = self._prepare_fig()
fig.plot(fpr, tpr, color='darkorange',
lw=2, label=f'ROC curve (area = {roc_auc})')
self._prepare_fig()
return roc_auc
if plotting:
fig = plt.gcf()
fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
return roc_auc, fpr, tpr
def _prepare_fig(self):
fig = plt.gcf()
@ -32,6 +27,6 @@ class ROCEvaluation(object):
fig.ylim([0.0, 1.05])
fig.xlabel('False Positive Rate')
fig.ylabel('True Positive Rate')
fig.legend(loc="lower right")
return fig

View File

@ -1,4 +1,4 @@
from dataset.dataset import TrajPairData
from datasets.paired_dataset import TrajPairData
from lib.modules.blocks import ConvModule
from lib.modules.utils import LightningBaseModule

View File

@ -5,8 +5,10 @@ import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.utils import LightningBaseModule, Flatten
from lib.modules.blocks import ConvModule, ResidualModule
@ -24,6 +26,22 @@ class ConvHomDetector(LightningBaseModule):
loss = F.binary_cross_entropy(pred_y, batch_y.float())
return {'loss': loss, 'log': dict(loss=loss)}
def test_step(self, batch_xy, **kwargs):
batch_x, batch_y = batch_xy
pred_y = self(batch_x)
return dict(prediction=pred_y, label=batch_y)
def test_end(self, outputs):
    """Aggregate all per-batch test outputs and log the ROC-AUC.

    :param outputs: list of dicts produced by ``test_step``
                    (each with 'prediction' and 'label' tensors).
    :return: dict carrying the computed ROC-AUC.

    BUG FIX: ``self.logger.log_metrics()`` was called without arguments, so
    the computed score was never logged (and the call itself would raise a
    TypeError); the evaluation result was also discarded.
    """
    evaluation = ROCEvaluation()
    predictions = torch.stack([x['prediction'] for x in outputs])
    labels = torch.stack([x['label'] for x in outputs])
    # detach/cpu so .numpy() is safe even if grads or a GPU device are involved
    scores = evaluation(predictions.detach().cpu().numpy(),
                        labels.detach().cpu().numpy())
    # ROCEvaluation.__call__ may return either the AUC alone or
    # (auc, fpr, tpr) — take the AUC in both cases (TODO confirm contract).
    roc_auc = scores[0] if isinstance(scores, tuple) else scores
    self.logger.log_metrics(dict(roc_auc=roc_auc))
    return dict(roc_auc=roc_auc)
def __init__(self, *params):
super(ConvHomDetector, self).__init__(*params)
@ -70,6 +88,26 @@ class ConvHomDetector(LightningBaseModule):
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
self.out_activation = nn.Sigmoid() # nn.Softmax
# Dataloaders
# ================================================================================
# Train Dataloader
def train_dataloader(self):
    """Return a shuffled DataLoader over the training split.

    Batch size and worker count are taken from the data hyperparameters.
    """
    data_conf = self.hparams.data_param
    return DataLoader(self.dataset.train_dataset,
                      shuffle=True,
                      batch_size=data_conf.batchsize,
                      num_workers=data_conf.worker)
# Test Dataloader
def test_dataloader(self):
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
batch_size=self.hparams.data_param.batchsize,
num_workers=self.hparams.data_param.worker)
# Validation Dataloader
def val_dataloader(self):
    """Return a DataLoader over the validation split.

    BUG FIX: the validation loader used ``shuffle=True``. Evaluation data
    should keep a deterministic order for reproducible validation runs;
    shuffling adds nothing to metric computation.
    """
    data_conf = self.hparams.data_param
    return DataLoader(self.dataset.val_dataset,
                      shuffle=False,
                      batch_size=data_conf.batchsize,
                      num_workers=data_conf.worker)
def forward(self, x):
tensor = self.map_conv_0(x)
tensor = self.map_res_1(tensor)

View File

@ -105,24 +105,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
torch.save(self.__class__, f)
return True
@pl.data_loader
def train_dataloader(self):
    """Shuffled loader over the training split, sized from the data hyperparameters."""
    data_conf = self.hparams.data_param
    return DataLoader(self.dataset.train_dataset,
                      shuffle=True,
                      batch_size=data_conf.batchsize,
                      num_workers=data_conf.worker)
@pl.data_loader
def test_dataloader(self):
    """Shuffled loader over the test split, sized from the data hyperparameters."""
    data_conf = self.hparams.data_param
    return DataLoader(self.dataset.test_dataset,
                      shuffle=True,
                      batch_size=data_conf.batchsize,
                      num_workers=data_conf.worker)
@pl.data_loader
def val_dataloader(self):
    """Shuffled loader over the validation split, sized from the data hyperparameters."""
    data_conf = self.hparams.data_param
    return DataLoader(self.dataset.val_dataset,
                      shuffle=True,
                      batch_size=data_conf.batchsize,
                      num_workers=data_conf.worker)
@property
def data_len(self):
    """Number of samples in the training split."""
    return len(self.dataset.train_dataset)

View File

@ -79,20 +79,37 @@ class Config(ConfigParser):
super(Config, self).__init__(**kwargs)
pass
@staticmethod
def _sort_combined_section_key_mapping(dict_obj):
    """Group a flat ``section_attr`` mapping into ``{section: {attr: str(value)}}``.

    The part of each key before the first underscore becomes the section;
    the remainder (which may itself contain underscores) becomes the option
    name. All values are stringified, as ConfigParser requires.
    """
    grouped = defaultdict(dict)
    for combined_key, raw_value in dict_obj.items():
        section, _, attr_name = combined_key.partition('_')
        grouped[section][attr_name] = str(raw_value)
    # plain dict so callers do not observe defaultdict auto-creation
    return dict(grouped)
@classmethod
def read_namespace(cls, namespace: Namespace):
    """Build a Config from an argparse Namespace of flat ``section_attr`` keys.

    NOTE(review): this span was diff residue interleaving the pre- and
    post-commit bodies of the method; this is the committed version, which
    delegates key grouping to ``_sort_combined_section_key_mapping``.

    :param namespace: parsed command-line arguments.
    :return: a new Config populated from the namespace attributes.
    """
    sorted_dict = cls._sort_combined_section_key_mapping(namespace.__dict__)
    new_config = cls()
    new_config.read_dict(sorted_dict)
    return new_config
def update(self, mapping):
    """Merge a flat ``section_attr`` mapping into this config.

    Missing sections are created on the fly; existing options are
    overwritten.

    :param mapping: dict with combined ``section_attr`` keys.
    :return: self, to allow chaining.
    """
    for section, options in self._sort_combined_section_key_mapping(mapping).items():
        if not self.has_section(section):
            self.add_section(section)
        for option_name, option_value in options.items():
            self.set(section, option_name, option_value)
    return self
def get(self, *args, **kwargs):
item = super(Config, self).get(*args, **kwargs)
try:
@ -108,5 +125,4 @@ class Config(ConfigParser):
with path.open('w') as configfile:
super().write(configfile)
return True

View File

@ -78,6 +78,14 @@ class Logger(LightningLoggerBase):
def log_config_as_ini(self):
    """Persist the current run configuration as an ini file in the log directory."""
    self.config.write(self.log_dir)
def log_metric(self, metric_name, metric_value, **kwargs):
    """Forward a scalar metric to both logging backends.

    :param metric_name: name under which the value is recorded.
    :param metric_value: scalar value to log.
    :param kwargs: passed through to the neptune backend (e.g. step/timestamp).

    BUG FIX: the test-tube payload was built with ``dict(metric_name=...)``,
    which logs every metric under the literal key 'metric_name' instead of
    the supplied name.
    """
    self.testtubelogger.log_metrics({metric_name: metric_value})
    self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
def log_image(self, name, image, **kwargs):
    """Send a figure to neptune and also save it under the log directory.

    :param name: file/metric name for the image.
    :param image: object exposing ``savefig`` (e.g. a matplotlib figure).
    """
    self.neptunelogger.log_image(name, image, **kwargs)
    target_path = self.log_dir / name
    image.savefig(target_path)
def save(self):
    """Flush both underlying logging backends, test-tube first."""
    self.testtubelogger.save()
    self.neptunelogger.save()