bugs fixed, binary datasets working
This commit is contained in:
parent
1d1b154460
commit
abe870d106
@ -28,9 +28,11 @@ try:
|
|||||||
self.n_classes = n_classes
|
self.n_classes = n_classes
|
||||||
self.tag = tag
|
self.tag = tag
|
||||||
|
|
||||||
self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False)
|
self.accuracy_score = pl.metrics.Accuracy(compute_on_step=False,)
|
||||||
self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False)
|
self.precision = pl.metrics.Precision(num_classes=self.n_classes, average='macro', compute_on_step=False,
|
||||||
self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False)
|
is_multiclass=True)
|
||||||
|
self.recall = pl.metrics.Recall(num_classes=self.n_classes, average='macro', compute_on_step=False,
|
||||||
|
is_multiclass=True)
|
||||||
self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False)
|
self.confusion_matrix = pl.metrics.ConfusionMatrix(self.n_classes, normalize='true', compute_on_step=False)
|
||||||
# self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
|
# self.precision_recall_curve = pl.metrics.PrecisionRecallCurve(self.n_classes, compute_on_step=False)
|
||||||
# self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
|
# self.average_prec = pl.metrics.AveragePrecision(self.n_classes, compute_on_step=True)
|
||||||
@ -46,14 +48,14 @@ try:
|
|||||||
for _, metric in self:
|
for _, metric in self:
|
||||||
try:
|
try:
|
||||||
if self.n_classes <= 2:
|
if self.n_classes <= 2:
|
||||||
metric.update(preds.unsqueeze(-1), target.unsqueeze(-1))
|
metric.update(preds, target)
|
||||||
else:
|
else:
|
||||||
metric.update(preds, target)
|
metric.update(preds, target)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print(f'error was: {ValueError}')
|
print(f'error was: {ValueError}')
|
||||||
print(f'Metric is: {metric}')
|
print(f'Metric is: {metric}')
|
||||||
print(f'Shape is: preds - {preds.unsqueeze(-1).shape}, target - {target.shape}')
|
print(f'Shape is: preds - {preds.squeeze().shape}, target - {target.shape}')
|
||||||
metric.update(preds.unsqueeze(-1), target)
|
metric.update(preds.squeeze(), target)
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
print(f'error was: {AssertionError}')
|
print(f'error was: {AssertionError}')
|
||||||
print(f'Metric is: {metric}')
|
print(f'Metric is: {metric}')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user