Transformer running

This commit is contained in:
Steffen Illium
2021-03-04 12:01:08 +01:00
parent b5e3e5aec1
commit f89f0f8528
14 changed files with 349 additions and 80 deletions

@@ -1,7 +1,5 @@
import inspect
from argparse import ArgumentParser
from functools import reduce
from matplotlib import pyplot as plt
from abc import ABC
from pathlib import Path
@@ -12,14 +10,77 @@ from pytorch_lightning.utilities import argparse_utils
from torch import nn
from torch.nn import functional as F, Unfold
from sklearn.metrics import ConfusionMatrixDisplay
# Utility - Modules
###################
from ..utils.model_io import ModelParameters
from ..utils.tools import locate_and_import_class, add_argparse_args
from ..utils.tools import add_argparse_args
try:
    import pytorch_lightning as pl

    class PLMetrics(pl.metrics.Metric):

        def __init__(self, n_classes, tag=''):
            super(PLMetrics, self).__init__()

            self.n_classes = n_classes
            self.tag = tag

            self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False)
            self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False)
            self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False)
            self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False)
            # self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
            # self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
            # self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False)
            self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
            self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)

        def __iter__(self):
            # Each metric above is registered as a submodule, so self._modules
            # already yields every (name, metric) pair without extra bookkeeping.
            return iter((name, metric) for name, metric in self._modules.items())

        def update(self, preds, target) -> None:
            for _, metric in self:
                metric.update(preds, target)

        def reset(self) -> None:
            for _, metric in self:
                metric.reset()

        def compute(self) -> dict:
            tag = f'{self.tag}_' if self.tag else ''
            return {f'{tag}{metric_name}_score': metric.compute() for metric_name, metric in self}
        def compute_and_prepare(self):
            pl_metrics = self.compute()
            images_from_metrics = dict()
            for metric_name in list(pl_metrics.keys()):
                if 'curve' in metric_name:
                    # Curve metrics are not plotted yet; drop them so no raw
                    # tensors leak into the scalar results.
                    pl_metrics.pop(metric_name)
                elif 'matrix' in metric_name:
                    matrix = pl_metrics.pop(metric_name)
                    fig1, ax1 = plt.subplots(dpi=96)
                    disp = ConfusionMatrixDisplay(confusion_matrix=matrix.cpu().numpy(),
                                                  display_labels=[i for i in range(self.n_classes)]
                                                  )
                    disp.plot(include_values=True, ax=ax1)
                    images_from_metrics[metric_name] = fig1
                elif 'ROC' in metric_name:
                    # ROC plotting is likewise not implemented yet; drop it as well.
                    pl_metrics.pop(metric_name)
                else:
                    pl_metrics[metric_name] = pl_metrics[metric_name].cpu().item()
            return pl_metrics, images_from_metrics
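    # A quick sketch of the intended call pattern (illustrative only, not part
    # of the diff): update() accumulates batch predictions, compute_and_prepare()
    # then returns plain-float scalars (keyed roughly f'{tag}_{attribute}_score')
    # plus matplotlib figures, currently just the confusion matrix:
    #
    #   metrics = PLMetrics(n_classes=10, tag='val')
    #   metrics.update(preds, targets)
    #   scalars, figures = metrics.compute_and_prepare()
    #   metrics.reset()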
    class LightningBaseModule(pl.LightningModule, ABC):

        @classmethod
@@ -49,6 +110,9 @@ try:
            self._weight_init = weight_init
            self.params = ModelParameters(model_parameters)
            self.metrics = PLMetrics(self.params.n_classes, tag='PL')

        def size(self):
            return self.shape
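
As a usage note: the new `self.metrics` container is presumably driven from the module's validation hooks. A minimal sketch of that wiring, assuming a TensorBoard logger; the hook bodies below are hypothetical and not part of this commit:

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            self.metrics.update(y_hat, y)  # accumulate only; compute_on_step=False defers results
            return F.cross_entropy(y_hat, y)

        def validation_epoch_end(self, outputs):
            scalars, figures = self.metrics.compute_and_prepare()
            for name, value in scalars.items():
                self.log(name, value)                          # plain floats, safe to log
            for name, fig in figures.items():
                self.logger.experiment.add_figure(name, fig)   # TensorBoard SummaryWriter call
                plt.close(fig)
            self.metrics.reset()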