Model Training
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
from argparse import Namespace
|
||||
from collections import Mapping
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@ -8,12 +9,12 @@ from torch import nn
|
||||
|
||||
|
||||
# Hyperparamter Object
|
||||
class ModelParameters(Mapping, Namespace):
|
||||
class ModelParameters(Namespace, Mapping):
|
||||
|
||||
@property
|
||||
def module_paramters(self):
|
||||
paramter_mapping = dict()
|
||||
paramter_mapping.update(self.model_param.__dict__)
|
||||
def module_kwargs(self):
|
||||
|
||||
paramter_mapping = deepcopy(self.__dict__)
|
||||
|
||||
paramter_mapping.update(
|
||||
dict(
|
||||
@ -21,26 +22,37 @@ class ModelParameters(Mapping, Namespace):
|
||||
)
|
||||
)
|
||||
|
||||
del paramter_mapping['in_shape']
|
||||
|
||||
return paramter_mapping
|
||||
|
||||
@property
|
||||
def test_activation(self):
|
||||
try:
|
||||
return self._activations[self.model.activation]
|
||||
except KeyError:
|
||||
return nn.ReLU
|
||||
|
||||
def __getitem__(self, k):
|
||||
# k: _KT -> _VT_co
|
||||
return self.__dict__[k]
|
||||
|
||||
def __len__(self):
|
||||
# -> int
|
||||
return len(self.__dict__.keys())
|
||||
return len(self.__dict__)
|
||||
|
||||
def __iter__(self):
|
||||
# -> Iterator[_T_co]
|
||||
return iter(list(self.__dict__.keys()))
|
||||
|
||||
def __delitem__(self, key):
|
||||
self.__dict__.__delitem__(key)
|
||||
self.__delattr__(key)
|
||||
return True
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name == 'activation':
|
||||
return self._activations[self['activation']]
|
||||
else:
|
||||
return super(ModelParameters, self).__getattribute__(name)
|
||||
|
||||
_activations = dict(
|
||||
leaky_relu=nn.LeakyReLU,
|
||||
relu=nn.ReLU,
|
||||
@ -48,22 +60,8 @@ class ModelParameters(Mapping, Namespace):
|
||||
tanh=nn.Tanh
|
||||
)
|
||||
|
||||
def __init__(self, model_param, train_param, data_param):
|
||||
self.model_param = model_param
|
||||
self.train_param = train_param
|
||||
self.data_param = data_param
|
||||
kwargs = vars(model_param)
|
||||
kwargs.update(vars(train_param))
|
||||
kwargs.update(vars(data_param))
|
||||
super(ModelParameters, self).__init__(**kwargs)
|
||||
|
||||
def __getattribute__(self, item):
|
||||
if item == 'activation':
|
||||
try:
|
||||
return self._activations[item]
|
||||
except KeyError:
|
||||
return nn.ReLU
|
||||
return super(ModelParameters, self).__getattribute__(item)
|
||||
def __init__(self, parameter_mapping):
|
||||
super(ModelParameters, self).__init__(**parameter_mapping)
|
||||
|
||||
|
||||
class SavedLightningModels(object):
|
||||
|
Reference in New Issue
Block a user