adjustment for CCS, notebook folder

Steffen Illium 2021-03-22 16:43:18 +01:00
parent ed260f1c2a
commit d3e7bf7efb
3 changed files with 3 additions and 7 deletions

View File

@@ -2,7 +2,6 @@ from itertools import cycle
 import numpy as np
 import torch
-from pytorch_lightning.metrics import Recall
 from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix, \
     recall_score
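
The removed line only drops the pytorch_lightning Recall import; recall stays available through the sklearn.metrics import the file keeps. A minimal sketch with made-up values (not from the repository) showing the retained recall_score in use:

import numpy as np
from sklearn.metrics import recall_score

y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1])
# two of the three positives are recovered -> recall of about 0.667
print(recall_score(y_true, y_pred))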

View File

@@ -35,8 +35,9 @@ try:
             # self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
             # self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
             # self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False)
-            self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
-            self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)
+            if self.n_classes > 2:
+                self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
+                self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)
 
         def __iter__(self):
             return iter(((name, metric) for name, metric in self._modules.items()))
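
A self-contained sketch of the pattern this hunk introduces (the class name and constructor are assumptions for illustration, not the repository's actual code): the macro-averaged FBeta/F1 metrics are only registered when there are more than two classes, using the pl.metrics API of the pytorch_lightning 1.2 era shown in the diff.

import torch
import pytorch_lightning as pl


class MetricContainer(torch.nn.Module):  # hypothetical container, for illustration only
    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        if self.n_classes > 2:
            # Macro averaging is only meaningful for genuinely multi-class problems,
            # so binary setups simply skip these two metrics.
            self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False)
            self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False)

    def __iter__(self):
        # Yield (name, metric) pairs for every registered metric sub-module.
        return iter((name, metric) for name, metric in self._modules.items())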

View File

@@ -71,10 +71,6 @@ def parse_comandline_args_add_defaults(filepath, overrides=None):
         auto_lr_find=not args['debug'],
         weights_summary='top',
         check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1),
-        limit_train_batches = 2.0,
-        limit_val_batches = 2.0,
-        limit_test_batches = 2.0,
-        limit_predict_batches = 2.0,
     )
 
     if overrides is not None and isinstance(overrides, (Mapping, Dict)):
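
For context, a hedged sketch of what remains after this hunk (the helper name and the args dict are assumptions; only the keyword arguments shown above come from the diff): dropping the hard-coded limit_*_batches entries lets pytorch_lightning.Trainer fall back to its default of 1.0 for each of them, i.e. all batches are used. The overrides check is simplified here to a plain Mapping test.

from collections.abc import Mapping


def build_trainer_defaults(args, overrides=None):  # hypothetical helper name
    trainer_defaults = dict(
        auto_lr_find=not args['debug'],
        weights_summary='top',
        check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1),
        # limit_train/val/test/predict_batches are no longer pinned to 2.0;
        # the Trainer defaults (1.0 = use every batch) apply instead.
    )
    if overrides is not None and isinstance(overrides, Mapping):
        trainer_defaults.update(overrides)
    return trainer_defaults


# Example: build_trainer_defaults({'debug': False}) keeps check_val_every_n_epoch at 1.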