paper preperations and notebooks, optuna callbacks

This commit is contained in:
Steffen Illium
2021-04-02 08:45:11 +02:00
parent abe870d106
commit faa27c3cf9
5 changed files with 26 additions and 15 deletions

View File

@ -190,10 +190,6 @@ class BaseCNNEncoder(ShapeMixin, nn.Module):
kernels = kernels if not isinstance(kernels, int) else [kernels] * len(filters)
assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
# Optional Padding for odd image-sizes
# Obsolet, cdan be done by autopadding module on incoming tensors
# in_shape = [tensor+1 if tensor % 2 != 0 and idx else tensor for idx, tensor in enumerate(in_shape)]
# Parameters
self.lat_dim = lat_dim
self.in_shape = in_shape

View File

@ -14,6 +14,8 @@ from sklearn.metrics import ConfusionMatrixDisplay
# Utility - Modules
###################
from ..metrics.binary_class_classifictaion import BinaryScores
from ..metrics.multi_class_classification import MultiClassScores
from ..utils.model_io import ModelParameters
from ..utils.tools import add_argparse_args
@ -133,9 +135,6 @@ try:
def size(self):
return self.shape
def additional_scores(self, outputs):
raise NotImplementedError
def save_to_disk(self, model_path):
Path(model_path, exist_ok=True).mkdir(parents=True, exist_ok=True)
if not (model_path / 'model_class.obj').exists():
@ -174,6 +173,12 @@ try:
weight_initializer = WeightInit(in_place_init_function=self._weight_init)
self.apply(weight_initializer)
def additional_scores(self, outputs):
if self.params.n_classes > 2:
return MultiClassScores(self)(outputs)
else:
return BinaryScores(self)(outputs)
module_types = (LightningBaseModule, nn.Module,)
except ImportError: