bug in metric calculation
This commit is contained in:
@ -44,7 +44,21 @@ try:
|
||||
|
||||
def update(self, preds, target) -> None:
|
||||
for _, metric in self:
|
||||
metric.update(preds, target)
|
||||
try:
|
||||
if self.n_classes <= 2:
|
||||
metric.update(preds.unsqueeze(-1), target.unsqueeze(-1))
|
||||
else:
|
||||
metric.update(preds, target)
|
||||
except ValueError:
|
||||
print(f'error was: {ValueError}')
|
||||
print(f'Metric is: {metric}')
|
||||
print(f'Shape is: preds - {preds.unsqueeze(-1).shape}, target - {target.shape}')
|
||||
metric.update(preds.unsqueeze(-1), target)
|
||||
except AssertionError:
|
||||
print(f'error was: {AssertionError}')
|
||||
print(f'Metric is: {metric}')
|
||||
print(f'Shape is: preds - {preds.shape}, target - {target.unsqueeze(-1).shape}')
|
||||
metric.update(preds, target.unsqueeze(-1))
|
||||
|
||||
def reset(self) -> None:
|
||||
for _, metric in self:
|
||||
|
Reference in New Issue
Block a user