Debugging und hparam als Mapping
This commit is contained in:
@ -7,7 +7,6 @@ from configparser import ConfigParser
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from ml_lib.utils.model_io import ModelParameters
|
from ml_lib.utils.model_io import ModelParameters
|
||||||
from ml_lib.utils.transforms import AsArray
|
|
||||||
|
|
||||||
|
|
||||||
def is_jsonable(x):
|
def is_jsonable(x):
|
||||||
@ -116,7 +115,7 @@ class Config(ConfigParser, ABC):
|
|||||||
def build_model(self):
|
def build_model(self):
|
||||||
return self.model_class(self.model_paramters)
|
return self.model_class(self.model_paramters)
|
||||||
|
|
||||||
def build_and_init_model(self, in_shape, weight_init_function):
|
def build_and_init_model(self, weight_init_function):
|
||||||
model = self.build_model()
|
model = self.build_model()
|
||||||
model.init_weights(weight_init_function)
|
model.init_weights(weight_init_function)
|
||||||
return model
|
return model
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from collections import Mapping
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -7,7 +8,19 @@ from torch import nn
|
|||||||
|
|
||||||
|
|
||||||
# Hyperparamter Object
|
# Hyperparamter Object
|
||||||
class ModelParameters(Namespace):
|
class ModelParameters(Mapping, Namespace):
|
||||||
|
|
||||||
|
def __getitem__(self, k):
|
||||||
|
# k: _KT -> _VT_co
|
||||||
|
return self.__dict__[k]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
# -> int
|
||||||
|
return len(self.__dict__.keys())
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
# -> Iterator[_T_co]
|
||||||
|
return iter(list(self.__dict__.keys()))
|
||||||
|
|
||||||
_activations = dict(
|
_activations = dict(
|
||||||
leaky_relu=nn.LeakyReLU,
|
leaky_relu=nn.LeakyReLU,
|
||||||
|
Reference in New Issue
Block a user