paper preperations and notebooks, optuna callbacks

This commit is contained in:
Steffen Illium
2021-04-02 08:45:11 +02:00
parent 7c88602776
commit cec3a07d60
21 changed files with 3818 additions and 1059 deletions

View File

@ -3,8 +3,6 @@ from argparse import Namespace
from torch import nn
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
from ml_lib.metrics.multi_class_classification import MultiClassScores
from ml_lib.modules.blocks import LinearModule
from ml_lib.modules.model_parts import CNNEncoder
from ml_lib.modules.util import (LightningBaseModule)
@ -52,9 +50,3 @@ class CNNBaseline(CombinedModelMixins,
tensor = self.classifier(tensor)
return Namespace(main_out=tensor)
def additional_scores(self, outputs):
if self.params.n_classes > 2:
return MultiClassScores(self)(outputs)
else:
return BinaryScores(self)(outputs)