bug in metric calculation
This commit is contained in:
@@ -16,8 +16,7 @@ class CNNBaseline(CombinedModelMixins,
|
||||
):
|
||||
|
||||
def __init__(self, in_shape, n_classes, weight_init, activation,
|
||||
use_bias, use_norm, dropout, lat_dim, features,
|
||||
filters,
|
||||
use_bias, use_norm, dropout, lat_dim, filters,
|
||||
lr, weight_decay, sto_weight_avg, lr_warm_restart_epochs, opt_reset_interval,
|
||||
loss, scheduler, lr_scheduler_parameter
|
||||
):
|
||||
|
||||
@@ -6,6 +6,7 @@ from torch import nn
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
from ml_lib.modules.blocks import (TransformerModule, F_x)
|
||||
from ml_lib.modules.util import (LightningBaseModule, AutoPadToShape)
|
||||
@@ -128,4 +129,7 @@ class VisualTransformer(CombinedModelMixins,
|
||||
return Namespace(main_out=tensor, attn_weights=attn_weights)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
return MultiClassScores(self)(outputs)
|
||||
if self.params.n_classes <= 2:
|
||||
return BinaryScores(self)(outputs)
|
||||
else:
|
||||
return MultiClassScores(self)(outputs)
|
||||
|
||||
Reference in New Issue
Block a user