Model Training

This commit is contained in:
Si11ium
2020-05-03 18:00:49 +02:00
parent 3e75d73a6b
commit 6d8fbd7184
4 changed files with 80 additions and 52 deletions

View File

@ -1,7 +1,9 @@
import ast
from copy import deepcopy
from abc import ABC
from argparse import Namespace
from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
@ -20,13 +22,13 @@ def is_jsonable(x):
class Config(ConfigParser, ABC):
# TODO: Do this programmatically; This did not work:
# Initialize Default Sections
# for section in self.default_sections:
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
@property
def _model_weight_init(self):
mod = __import__('torch.nn.init', fromlist=[self.model.weight_init])
return getattr(mod, self.model.weight_init)
@property
def model_map(self):
def _model_map(self):
"""
This is function is supposed to return a dict, which holds a mapping from string model names to model classes
@ -42,10 +44,15 @@ class Config(ConfigParser, ABC):
@property
def model_class(self):
try:
return self.model_map[self.get('model', 'type')]
except KeyError as e:
return self._model_map[self.model.type]
except KeyError:
raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
f'Try one of these:\n{list(self.model_map.keys())}')
f'Try one of these:\n{list(self._model_map.keys())}')
# TODO: Do this programmatically; This did not work:
# Initialize Default Sections as Property
# for section in self.default_sections:
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
@property
def main(self):
@ -71,7 +78,12 @@ class Config(ConfigParser, ABC):
@property
def model_paramters(self):
return ModelParameters(self.model, self.train, self.data)
params = deepcopy(self.model.__dict__)
assert all(key not in list(params.keys()) for key in self.train.__dict__)
params.update(self.train.__dict__)
assert all(key not in list(params.keys()) for key in self.data.__dict__)
params.update(self.data.__dict__)
return params
@property
def tags(self, ):
@ -113,12 +125,22 @@ class Config(ConfigParser, ABC):
new_config.read_dict(sorted_dict)
return new_config
@classmethod
def read_argparser(cls, argparser: ArgumentParser):
# Parse it
args = argparser.parse_args()
sorted_dict = cls._sort_combined_section_key_mapping(args.__dict__)
new_config = cls()
new_config.read_dict(sorted_dict)
return new_config
def build_model(self):
return self.model_class(self.model_paramters)
def build_and_init_model(self, weight_init_function):
def build_and_init_model(self):
model = self.build_model()
model.init_weights(weight_init_function)
model.init_weights(self._model_weight_init)
return model
def update(self, mapping):