Parameter Adjustmens and Ensemble Model Implementation

This commit is contained in:
Si11ium
2020-05-08 16:30:54 +02:00
parent 3c776f13c5
commit d2e74ff33a
6 changed files with 126 additions and 56 deletions

@ -7,8 +7,7 @@ from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
from ml_lib.utils.model_io import ModelParameters
import hashlib
def is_jsonable(x):
@ -22,6 +21,30 @@ def is_jsonable(x):
class Config(ConfigParser, ABC):
@property
def name(self):
short_name = "".join(c for c in self.model.type if c.isupper())
return f'{short_name}_{self.fingerprint}'
@property
def version(self):
return f'version_{self.main.seed}'
@property
def exp_path(self):
return Path(self.train.outpath) / self.model.type / self.name
@property
def fingerprint(self):
h = hashlib.md5()
params = deepcopy(self.as_dict)
del params['model']['type']
del params['data']['worker']
del params['main']
h.update(str(params).encode())
fingerprint = h.hexdigest()
return fingerprint
@property
def _model_weight_init(self):
mod = __import__('torch.nn.init', fromlist=[self.model.weight_init])
@ -33,8 +56,8 @@ class Config(ConfigParser, ABC):
This is function is supposed to return a dict, which holds a mapping from string model names to model classes
Example:
from models.binary_classifier import BinaryClassifier
return dict(BinaryClassifier=BinaryClassifier,
from models.binary_classifier import ConvClassifier
return dict(ConvClassifier=ConvClassifier,
)
:return:
"""
@ -46,8 +69,7 @@ class Config(ConfigParser, ABC):
try:
return self._model_map[self.model.type]
except KeyError:
raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
f'Try one of these:\n{list(self._model_map.keys())}')
raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! Try one of these: {list(self._model_map.keys())}')
# TODO: Do this programmatically; This did not work:
# Initialize Default Sections as Property
@ -83,6 +105,7 @@ class Config(ConfigParser, ABC):
params.update(self.train.__dict__)
assert all(key not in list(params.keys()) for key in self.data.__dict__)
params.update(self.data.__dict__)
params.update(exp_path=str(self.exp_path), exp_fingerprint=str(self.fingerprint))
return params
@property
@ -134,7 +157,6 @@ class Config(ConfigParser, ABC):
new_config.read_dict(sorted_dict)
return new_config
def build_model(self):
return self.model_class(self.model_paramters)