adjustment fot CCS, notebook folder
This commit is contained in:
@ -3,6 +3,7 @@ from argparse import Namespace
|
||||
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
from ml_lib.modules.blocks import LinearModule
|
||||
from ml_lib.modules.model_parts import CNNEncoder
|
||||
@ -36,9 +37,11 @@ class CNNBaseline(CombinedModelMixins,
|
||||
# Modules with Parameters
|
||||
self.encoder = CNNEncoder(in_shape=self.in_shape, **self.params.module_kwargs)
|
||||
|
||||
# Make Decision between binary and Multiclass Classification
|
||||
logits = n_classes if n_classes > 2 else 1
|
||||
module_kwargs = self.params.module_kwargs
|
||||
module_kwargs.update(activation=nn.Softmax)
|
||||
self.classifier = LinearModule(self.encoder.shape, n_classes, **module_kwargs)
|
||||
module_kwargs.update(activation=(nn.Softmax if logits > 1 else nn.Sigmoid))
|
||||
self.classifier = LinearModule(self.encoder.shape, logits, **module_kwargs)
|
||||
|
||||
def forward(self, x, mask=None, return_attn_weights=False):
|
||||
"""
|
||||
@ -52,4 +55,7 @@ class CNNBaseline(CombinedModelMixins,
|
||||
return Namespace(main_out=tensor)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
return MultiClassScores(self)(outputs)
|
||||
if self.params.n_classes > 2:
|
||||
return MultiClassScores(self)(outputs)
|
||||
else:
|
||||
return BinaryScores(self)(outputs)
|
||||
|
Reference in New Issue
Block a user