bugs fixed, binary datasets working

This commit is contained in:
Steffen 2021-03-27 18:23:51 +01:00
parent 1d1b154460
commit abe870d106

View File

@ -28,9 +28,11 @@ try:
self.n_classes = n_classes self.n_classes = n_classes
self.tag = tag self.tag = tag
self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False) self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False,)
self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False) self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False,
self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False) is_multiclass=True)
self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False,
is_multiclass=True)
self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False) self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False)
# self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False) # self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
# self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True) # self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
@ -46,14 +48,14 @@ try:
for _, metric in self: for _, metric in self:
try: try:
if self.n_classes <= 2: if self.n_classes <= 2:
metric.update(preds.unsqueeze(-1), target.unsqueeze(-1)) metric.update(preds, target)
else: else:
metric.update(preds, target) metric.update(preds, target)
except ValueError: except ValueError:
print(f'error was: {ValueError}') print(f'error was: {ValueError}')
print(f'Metric is: {metric}') print(f'Metric is: {metric}')
print(f'Shape is: preds - {preds.unsqueeze(-1).shape}, target - {target.shape}') print(f'Shape is: preds - {preds.squeeze().shape}, target - {target.shape}')
metric.update(preds.unsqueeze(-1), target) metric.update(preds.squeeze(), target)
except AssertionError: except AssertionError:
print(f'error was: {AssertionError}') print(f'error was: {AssertionError}')
print(f'Metric is: {metric}') print(f'Metric is: {metric}')