adjustment fot CCS, notebook folder

This commit is contained in:
Steffen Illium
2021-03-22 16:43:19 +01:00
parent 78b3139d1a
commit c12f3866c8
6 changed files with 156 additions and 29 deletions

View File

@ -3,6 +3,7 @@ from argparse import Namespace
from torch import nn
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import LinearModule
from ml_lib.modules.model_parts import CNNEncoder
@ -36,9 +37,11 @@ class CNNBaseline(CombinedModelMixins,
# Modules with Parameters
self.encoder = CNNEncoder(in_shape=self.in_shape, **self.params.module_kwargs)
# Make Decision between binary and Multiclass Classification
logits = n_classes if n_classes > 2 else 1
module_kwargs = self.params.module_kwargs
module_kwargs.update(activation=nn.Softmax)
self.classifier = LinearModule(self.encoder.shape, n_classes, **module_kwargs)
module_kwargs.update(activation=(nn.Softmax if logits > 1 else nn.Sigmoid))
self.classifier = LinearModule(self.encoder.shape, logits, **module_kwargs)
def forward(self, x, mask=None, return_attn_weights=False):
"""
@ -52,4 +55,7 @@ class CNNBaseline(CombinedModelMixins,
return Namespace(main_out=tensor)
def additional_scores(self, outputs):
return MultiClassScores(self)(outputs)
if self.params.n_classes > 2:
return MultiClassScores(self)(outputs)
else:
return BinaryScores(self)(outputs)