Debugging und hparam als Mapping

This commit is contained in:
Si11ium 2020-04-24 17:55:42 +02:00
parent f5c240f038
commit 8497857a57
2 changed files with 15 additions and 3 deletions

View File

@ -7,7 +7,6 @@ from configparser import ConfigParser
from pathlib import Path
from ml_lib.utils.model_io import ModelParameters
from ml_lib.utils.transforms import AsArray
def is_jsonable(x):
@ -116,7 +115,7 @@ class Config(ConfigParser, ABC):
def build_model(self):
return self.model_class(self.model_paramters)
def build_and_init_model(self, in_shape, weight_init_function):
def build_and_init_model(self, weight_init_function):
model = self.build_model()
model.init_weights(weight_init_function)
return model

View File

@ -1,4 +1,5 @@
from argparse import Namespace
from collections import Mapping
from pathlib import Path
import torch
@ -7,7 +8,19 @@ from torch import nn
# Hyperparamter Object
class ModelParameters(Namespace):
class ModelParameters(Mapping, Namespace):
def __getitem__(self, k):
# k: _KT -> _VT_co
return self.__dict__[k]
def __len__(self):
# -> int
return len(self.__dict__.keys())
def __iter__(self):
# -> Iterator[_T_co]
return iter(list(self.__dict__.keys()))
_activations = dict(
leaky_relu=nn.LeakyReLU,