Model Training

This commit is contained in:
Si11ium
2020-05-03 18:00:49 +02:00
parent 3e75d73a6b
commit 6d8fbd7184
4 changed files with 80 additions and 52 deletions

View File

@ -1,5 +1,6 @@
from argparse import Namespace
from collections import Mapping
from copy import deepcopy
from pathlib import Path
import torch
@ -8,12 +9,12 @@ from torch import nn
# Hyperparamter Object
class ModelParameters(Mapping, Namespace):
class ModelParameters(Namespace, Mapping):
@property
def module_paramters(self):
paramter_mapping = dict()
paramter_mapping.update(self.model_param.__dict__)
def module_kwargs(self):
paramter_mapping = deepcopy(self.__dict__)
paramter_mapping.update(
dict(
@ -21,26 +22,37 @@ class ModelParameters(Mapping, Namespace):
)
)
del paramter_mapping['in_shape']
return paramter_mapping
@property
def test_activation(self):
try:
return self._activations[self.model.activation]
except KeyError:
return nn.ReLU
def __getitem__(self, k):
# k: _KT -> _VT_co
return self.__dict__[k]
def __len__(self):
# -> int
return len(self.__dict__.keys())
return len(self.__dict__)
def __iter__(self):
# -> Iterator[_T_co]
return iter(list(self.__dict__.keys()))
def __delitem__(self, key):
self.__dict__.__delitem__(key)
self.__delattr__(key)
return True
def __getattribute__(self, name):
if name == 'activation':
return self._activations[self['activation']]
else:
return super(ModelParameters, self).__getattribute__(name)
_activations = dict(
leaky_relu=nn.LeakyReLU,
relu=nn.ReLU,
@ -48,22 +60,8 @@ class ModelParameters(Mapping, Namespace):
tanh=nn.Tanh
)
def __init__(self, model_param, train_param, data_param):
self.model_param = model_param
self.train_param = train_param
self.data_param = data_param
kwargs = vars(model_param)
kwargs.update(vars(train_param))
kwargs.update(vars(data_param))
super(ModelParameters, self).__init__(**kwargs)
def __getattribute__(self, item):
if item == 'activation':
try:
return self._activations[item]
except KeyError:
return nn.ReLU
return super(ModelParameters, self).__getattribute__(item)
def __init__(self, parameter_mapping):
super(ModelParameters, self).__init__(**parameter_mapping)
class SavedLightningModels(object):