From abe870d106beeccd5c17261da641971ac16c2f4b Mon Sep 17 00:00:00 2001
From: Steffen
Date: Sat, 27 Mar 2021 18:23:51 +0100
Subject: [PATCH] bugs fixed, binary datasets working

---
 modules/util.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/modules/util.py b/modules/util.py
index 0dd9847..71b7e61 100644
--- a/modules/util.py
+++ b/modules/util.py
@@ -28,9 +28,11 @@ try:
             self.n_classes = n_classes
             self.tag = tag
 
-            self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False)
-            self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False)
-            self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False)
+            self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False,)
+            self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False,
+                                                  is_multiclass=True)
+            self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False,
+                                            is_multiclass=True)
             self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False)
             # self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
             # self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
@@ -46,14 +48,14 @@ try:
             for _, metric in self:
                 try:
                     if self.n_classes <= 2:
-                        metric.update(preds.unsqueeze(-1), target.unsqueeze(-1))
+                        metric.update(preds, target)
                     else:
                         metric.update(preds, target)
                 except ValueError:
                     print(f'error was: {ValueError}')
                     print(f'Metric is: {metric}')
-                    print(f'Shape is: preds - {preds.unsqueeze(-1).shape}, target - {target.shape}')
-                    metric.update(preds.unsqueeze(-1), target)
+                    print(f'Shape is: preds - {preds.squeeze().shape}, target - {target.shape}')
+                    metric.update(preds.squeeze(), target)
                 except AssertionError:
                     print(f'error was: {AssertionError}')
                     print(f'Metric is: {metric}')
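
Note (not part of the commit): below is a minimal standalone sketch of what
this change enables, assuming the pre-torchmetrics pl.metrics API that the
patch itself uses (pytorch_lightning ~1.2, where Precision/Recall accept
is_multiclass). With is_multiclass=True, binary (N,)-shaped preds/targets are
treated as a 2-class problem, so update() no longer needs the old
unsqueeze(-1) workaround for n_classes <= 2. The tensor values and shapes are
illustrative only, not taken from the repository.

    import torch
    import pytorch_lightning as pl

    n_classes = 2

    # Metrics configured as in the patched __init__ above.
    precision = pl.metrics.Precision(num_classes=n_classes, average='macro',
                                     compute_on_step=False, is_multiclass=True)
    recall = pl.metrics.Recall(num_classes=n_classes, average='macro',
                               compute_on_step=False, is_multiclass=True)

    preds = torch.tensor([0.9, 0.2, 0.7, 0.4])  # per-sample probabilities, shape (N,)
    target = torch.tensor([1, 0, 1, 1])         # binary labels, shape (N,)

    # Plain (N,)-shaped inputs now pass through update() directly,
    # matching the simplified branch for n_classes <= 2 in the diff.
    precision.update(preds, target)
    recall.update(preds, target)
    print(precision.compute(), recall.compute())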