Model Training
This commit is contained in:
@ -1,7 +1,9 @@
|
||||
import ast
|
||||
from copy import deepcopy
|
||||
|
||||
from abc import ABC
|
||||
|
||||
from argparse import Namespace
|
||||
from argparse import Namespace, ArgumentParser
|
||||
from collections import defaultdict
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
@ -20,13 +22,13 @@ def is_jsonable(x):
|
||||
|
||||
class Config(ConfigParser, ABC):
|
||||
|
||||
# TODO: Do this programmatically; This did not work:
|
||||
# Initialize Default Sections
|
||||
# for section in self.default_sections:
|
||||
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
|
||||
@property
|
||||
def _model_weight_init(self):
|
||||
mod = __import__('torch.nn.init', fromlist=[self.model.weight_init])
|
||||
return getattr(mod, self.model.weight_init)
|
||||
|
||||
@property
|
||||
def model_map(self):
|
||||
def _model_map(self):
|
||||
"""
|
||||
This is function is supposed to return a dict, which holds a mapping from string model names to model classes
|
||||
|
||||
@ -42,10 +44,15 @@ class Config(ConfigParser, ABC):
|
||||
@property
|
||||
def model_class(self):
|
||||
try:
|
||||
return self.model_map[self.get('model', 'type')]
|
||||
except KeyError as e:
|
||||
return self._model_map[self.model.type]
|
||||
except KeyError:
|
||||
raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
|
||||
f'Try one of these:\n{list(self.model_map.keys())}')
|
||||
f'Try one of these:\n{list(self._model_map.keys())}')
|
||||
|
||||
# TODO: Do this programmatically; This did not work:
|
||||
# Initialize Default Sections as Property
|
||||
# for section in self.default_sections:
|
||||
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
|
||||
|
||||
@property
|
||||
def main(self):
|
||||
@ -71,7 +78,12 @@ class Config(ConfigParser, ABC):
|
||||
|
||||
@property
|
||||
def model_paramters(self):
|
||||
return ModelParameters(self.model, self.train, self.data)
|
||||
params = deepcopy(self.model.__dict__)
|
||||
assert all(key not in list(params.keys()) for key in self.train.__dict__)
|
||||
params.update(self.train.__dict__)
|
||||
assert all(key not in list(params.keys()) for key in self.data.__dict__)
|
||||
params.update(self.data.__dict__)
|
||||
return params
|
||||
|
||||
@property
|
||||
def tags(self, ):
|
||||
@ -113,12 +125,22 @@ class Config(ConfigParser, ABC):
|
||||
new_config.read_dict(sorted_dict)
|
||||
return new_config
|
||||
|
||||
@classmethod
|
||||
def read_argparser(cls, argparser: ArgumentParser):
|
||||
# Parse it
|
||||
args = argparser.parse_args()
|
||||
sorted_dict = cls._sort_combined_section_key_mapping(args.__dict__)
|
||||
new_config = cls()
|
||||
new_config.read_dict(sorted_dict)
|
||||
return new_config
|
||||
|
||||
|
||||
def build_model(self):
|
||||
return self.model_class(self.model_paramters)
|
||||
|
||||
def build_and_init_model(self, weight_init_function):
|
||||
def build_and_init_model(self):
|
||||
model = self.build_model()
|
||||
model.init_weights(weight_init_function)
|
||||
model.init_weights(self._model_weight_init)
|
||||
return model
|
||||
|
||||
def update(self, mapping):
|
||||
|
Reference in New Issue
Block a user