bug in metric calculation

This commit is contained in:
Steffen Illium
2021-03-27 16:39:07 +01:00
parent 6816e423ff
commit 1d1b154460
5 changed files with 38 additions and 15 deletions

View File

@ -44,7 +44,21 @@ try:
def update(self, preds, target) -> None:
for _, metric in self:
metric.update(preds, target)
try:
if self.n_classes <= 2:
metric.update(preds.unsqueeze(-1), target.unsqueeze(-1))
else:
metric.update(preds, target)
except ValueError:
print(f'error was: {ValueError}')
print(f'Metric is: {metric}')
print(f'Shape is: preds - {preds.unsqueeze(-1).shape}, target - {target.shape}')
metric.update(preds.unsqueeze(-1), target)
except AssertionError:
print(f'error was: {AssertionError}')
print(f'Metric is: {metric}')
print(f'Shape is: preds - {preds.shape}, target - {target.unsqueeze(-1).shape}')
metric.update(preds, target.unsqueeze(-1))
def reset(self) -> None:
for _, metric in self: