From d3e7bf7efbd333828d7894013852800b2e97eba3 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Mon, 22 Mar 2021 16:43:18 +0100 Subject: [PATCH] adjustment fot CCS, notebook folder --- metrics/multi_class_classification.py | 1 - modules/util.py | 5 +++-- utils/config.py | 4 ---- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/metrics/multi_class_classification.py b/metrics/multi_class_classification.py index 6404821..d4391a9 100644 --- a/metrics/multi_class_classification.py +++ b/metrics/multi_class_classification.py @@ -2,7 +2,6 @@ from itertools import cycle import numpy as np import torch -from pytorch_lightning.metrics import Recall from sklearn.metrics import f1_score, roc_curve, auc, roc_auc_score, ConfusionMatrixDisplay, confusion_matrix, \ recall_score diff --git a/modules/util.py b/modules/util.py index 8b27fd0..edc6874 100644 --- a/modules/util.py +++ b/modules/util.py @@ -35,8 +35,9 @@ try: # self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False) # self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True) # self.roc = pl.metrics.ROC(self.n_classes, compute_on_step=False) - self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False) - self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False) + if self.n_classes > 2: + self.fbeta = pl.metrics.FBeta(self.n_classes, average='macro', compute_on_step=False) + self.f1 = pl.metrics.F1(self.n_classes, average='macro', compute_on_step=False) def __iter__(self): return iter(((name, metric) for name, metric in self._modules.items())) diff --git a/utils/config.py b/utils/config.py index 87b3b17..18ccc04 100644 --- a/utils/config.py +++ b/utils/config.py @@ -71,10 +71,6 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): auto_lr_find=not args['debug'], weights_summary='top', check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1), - limit_train_batches = 2.0, - limit_val_batches = 2.0, - limit_test_batches = 2.0, - limit_predict_batches = 2.0, ) if overrides is not None and isinstance(overrides, (Mapping, Dict)):