paper preperations and notebooks, optuna callbacks
This commit is contained in:
@ -3,8 +3,6 @@ from argparse import Namespace
|
||||
|
||||
from torch import nn
|
||||
|
||||
from ml_lib.metrics.binary_class_classifictaion import BinaryScores
|
||||
from ml_lib.metrics.multi_class_classification import MultiClassScores
|
||||
from ml_lib.modules.blocks import LinearModule
|
||||
from ml_lib.modules.model_parts import CNNEncoder
|
||||
from ml_lib.modules.util import (LightningBaseModule)
|
||||
@ -52,9 +50,3 @@ class CNNBaseline(CombinedModelMixins,
|
||||
|
||||
tensor = self.classifier(tensor)
|
||||
return Namespace(main_out=tensor)
|
||||
|
||||
def additional_scores(self, outputs):
|
||||
if self.params.n_classes > 2:
|
||||
return MultiClassScores(self)(outputs)
|
||||
else:
|
||||
return BinaryScores(self)(outputs)
|
||||
|
Reference in New Issue
Block a user