eval written
This commit is contained in:
@ -1,29 +1,24 @@
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
|
||||
|
||||
class ROCEvaluation(object):
|
||||
|
||||
BINARY_PROBLEM = 2
|
||||
linewidth = 2
|
||||
|
||||
def __init__(self, save_fig=True):
|
||||
def __init__(self, prepare_figure=False):
|
||||
self.prepare_figure = prepare_figure
|
||||
self.epoch = 0
|
||||
pass
|
||||
|
||||
def __call__(self, prediction, label, prepare_fig=True):
|
||||
def __call__(self, prediction, label, plotting=False):
|
||||
|
||||
# Compute ROC curve and ROC area
|
||||
fpr, tpr, _ = roc_curve(prediction, label)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
|
||||
if prepare_fig:
|
||||
fig = self._prepare_fig()
|
||||
fig.plot(fpr, tpr, color='darkorange',
|
||||
lw=2, label=f'ROC curve (area = {roc_auc})')
|
||||
self._prepare_fig()
|
||||
return roc_auc
|
||||
if plotting:
|
||||
fig = plt.gcf()
|
||||
fig.plot(fpr, tpr, color='darkorange', lw=self.linewidth, label=f'ROC curve (area = {roc_auc})')
|
||||
return roc_auc, fpr, tpr
|
||||
|
||||
def _prepare_fig(self):
|
||||
fig = plt.gcf()
|
||||
@ -32,6 +27,6 @@ class ROCEvaluation(object):
|
||||
fig.ylim([0.0, 1.05])
|
||||
fig.xlabel('False Positive Rate')
|
||||
fig.ylabel('True Positive Rate')
|
||||
|
||||
fig.legend(loc="lower right")
|
||||
|
||||
return fig
|
||||
|
@ -1,4 +1,4 @@
|
||||
from dataset.dataset import TrajPairData
|
||||
from datasets.paired_dataset import TrajPairData
|
||||
from lib.modules.blocks import ConvModule
|
||||
from lib.modules.utils import LightningBaseModule
|
||||
|
||||
|
@ -5,8 +5,10 @@ import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from datasets.trajectory_dataset import TrajData
|
||||
from lib.evaluation.classification import ROCEvaluation
|
||||
from lib.modules.utils import LightningBaseModule, Flatten
|
||||
from lib.modules.blocks import ConvModule, ResidualModule
|
||||
|
||||
@ -24,6 +26,22 @@ class ConvHomDetector(LightningBaseModule):
|
||||
loss = F.binary_cross_entropy(pred_y, batch_y.float())
|
||||
return {'loss': loss, 'log': dict(loss=loss)}
|
||||
|
||||
def test_step(self, batch_xy, **kwargs):
|
||||
batch_x, batch_y = batch_xy
|
||||
pred_y = self(batch_x)
|
||||
return dict(prediction=pred_y, label=batch_y)
|
||||
|
||||
def test_end(self, outputs):
|
||||
evaluation = ROCEvaluation()
|
||||
predictions = torch.stack([x['prediction'] for x in outputs])
|
||||
labels = torch.stack([x['label'] for x in outputs])
|
||||
|
||||
scores = evaluation(predictions.numpy(), labels.numpy())
|
||||
self.logger.log_metrics()
|
||||
|
||||
|
||||
pass
|
||||
|
||||
def __init__(self, *params):
|
||||
super(ConvHomDetector, self).__init__(*params)
|
||||
|
||||
@ -70,6 +88,26 @@ class ConvHomDetector(LightningBaseModule):
|
||||
self.classifier = nn.Linear(self.hparams.model_param.classes * 10, 1) # self.hparams.model_param.classes)
|
||||
self.out_activation = nn.Sigmoid() # nn.Softmax
|
||||
|
||||
# Dataloaders
|
||||
# ================================================================================
|
||||
# Train Dataloader
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Test Dataloader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
# Validation Dataloader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self.map_conv_0(x)
|
||||
tensor = self.map_res_1(tensor)
|
||||
|
@ -105,24 +105,6 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
||||
torch.save(self.__class__, f)
|
||||
return True
|
||||
|
||||
@pl.data_loader
|
||||
def train_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.train_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
@pl.data_loader
|
||||
def test_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.test_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
@pl.data_loader
|
||||
def val_dataloader(self):
|
||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||
batch_size=self.hparams.data_param.batchsize,
|
||||
num_workers=self.hparams.data_param.worker)
|
||||
|
||||
@property
|
||||
def data_len(self):
|
||||
return len(self.dataset.train_dataset)
|
||||
|
@ -79,20 +79,37 @@ class Config(ConfigParser):
|
||||
super(Config, self).__init__(**kwargs)
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _sort_combined_section_key_mapping(dict_obj):
|
||||
sorted_dict = defaultdict(dict)
|
||||
for key in dict_obj:
|
||||
section, *attr_name = key.split('_')
|
||||
attr_name = '_'.join(attr_name)
|
||||
value = str(dict_obj[key])
|
||||
|
||||
sorted_dict[section][attr_name] = value
|
||||
# noinspection PyTypeChecker
|
||||
return dict(sorted_dict)
|
||||
|
||||
@classmethod
|
||||
def read_namespace(cls, namespace: Namespace):
|
||||
|
||||
space_dict = defaultdict(dict)
|
||||
for key in namespace.__dict__:
|
||||
section, *attr_name = key.split('_')
|
||||
attr_name = '_'.join(attr_name)
|
||||
value = str(namespace.__getattribute__(key))
|
||||
|
||||
space_dict[section][attr_name] = value
|
||||
sorted_dict = cls._sort_combined_section_key_mapping(namespace.__dict__)
|
||||
new_config = cls()
|
||||
new_config.read_dict(space_dict)
|
||||
new_config.read_dict(sorted_dict)
|
||||
return new_config
|
||||
|
||||
def update(self, mapping):
|
||||
sorted_dict = self._sort_combined_section_key_mapping(mapping)
|
||||
for section in sorted_dict:
|
||||
if self.has_section(section):
|
||||
pass
|
||||
else:
|
||||
self.add_section(section)
|
||||
for option, value in sorted_dict[section].items():
|
||||
self.set(section, option, value)
|
||||
return self
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
item = super(Config, self).get(*args, **kwargs)
|
||||
try:
|
||||
@ -108,5 +125,4 @@ class Config(ConfigParser):
|
||||
|
||||
with path.open('w') as configfile:
|
||||
super().write(configfile)
|
||||
|
||||
return True
|
||||
|
@ -78,6 +78,14 @@ class Logger(LightningLoggerBase):
|
||||
def log_config_as_ini(self):
|
||||
self.config.write(self.log_dir)
|
||||
|
||||
def log_metric(self, metric_name, metric_value, **kwargs):
|
||||
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
|
||||
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
|
||||
|
||||
def log_image(self, name, image, **kwargs):
|
||||
self.neptunelogger.log_image(name, image, **kwargs)
|
||||
image.savefig(self.log_dir / name)
|
||||
|
||||
def save(self):
|
||||
self.testtubelogger.save()
|
||||
self.neptunelogger.save()
|
||||
|
Reference in New Issue
Block a user