Parameter Adjustments and Ensemble Model Implementation
@@ -7,8 +7,7 @@ from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path

from ml_lib.utils.model_io import ModelParameters
import hashlib


def is_jsonable(x):
@@ -22,6 +21,30 @@ def is_jsonable(x):

class Config(ConfigParser, ABC):

    @property
    def name(self):
        short_name = "".join(c for c in self.model.type if c.isupper())
        return f'{short_name}_{self.fingerprint}'

    @property
    def version(self):
        return f'version_{self.main.seed}'

    @property
    def exp_path(self):
        return Path(self.train.outpath) / self.model.type / self.name

    @property
    def fingerprint(self):
        h = hashlib.md5()
        params = deepcopy(self.as_dict)
        del params['model']['type']
        del params['data']['worker']
        del params['main']
        h.update(str(params).encode())
        fingerprint = h.hexdigest()
        return fingerprint

    @property
    def _model_weight_init(self):
        mod = __import__('torch.nn.init', fromlist=[self.model.weight_init])
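The fingerprint property above hashes a copy of the parsed config with the run-specific keys stripped out, so two runs with identical hyperparameters end up under the same name and experiment path. A minimal standalone sketch of the same idea, using an invented plain-dict stand-in for `as_dict`:

import hashlib
from copy import deepcopy

# Hypothetical config dict mirroring the sections referenced above (values are placeholders).
cfg = dict(model=dict(type='ConvClassifier', activation='relu'),
           data=dict(worker=4, batch_size=128),
           main=dict(seed=42),
           train=dict(outpath='output', lr=1e-3))

params = deepcopy(cfg)
del params['model']['type']   # encoded separately via the capital-letter short name
del params['data']['worker']  # machine-dependent, should not change the fingerprint
del params['main']            # the seed becomes the version, not part of the hash

h = hashlib.md5()
h.update(str(params).encode())

short_name = ''.join(c for c in cfg['model']['type'] if c.isupper())  # 'CC'
print(f'{short_name}_{h.hexdigest()}')  # e.g. CC_5f3a...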
@@ -33,8 +56,8 @@ class Config(ConfigParser, ABC):
        This function is supposed to return a dict, which holds a mapping from string model names to model classes

        Example:
-            from models.binary_classifier import BinaryClassifier
-            return dict(BinaryClassifier=BinaryClassifier,
+            from models.binary_classifier import ConvClassifier
+            return dict(ConvClassifier=ConvClassifier,
                         )
        :return:
        """
@@ -46,8 +69,7 @@ class Config(ConfigParser, ABC):
        try:
            return self._model_map[self.model.type]
        except KeyError:
-            raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
-                           f'Try one of these:\n{list(self._model_map.keys())}')
+            raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! Try one of these: {list(self._model_map.keys())}')

    # TODO: Do this programmatically; This did not work:
    # Initialize Default Sections as Property
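The docstring above marks `_model_map` as the extension point: a concrete Config subclass returns the alias-to-class mapping, and `model_class` resolves `model.type` against it or raises the KeyError shown here. A hedged sketch of such a subclass; the module paths and the `ProjectConfig`/`EnsembleClassifier` names are placeholders, not taken from the repository:

# Hypothetical concrete config living next to the project's model definitions.
from models.conv_classifier import ConvClassifier      # assumed module path
from models.ensemble import EnsembleClassifier         # assumed module path


class ProjectConfig(Config):

    @property
    def _model_map(self):
        # Keys are the aliases accepted as "type" in the [model] section.
        return dict(ConvClassifier=ConvClassifier,
                    EnsembleClassifier=EnsembleClassifier)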
@@ -83,6 +105,7 @@ class Config(ConfigParser, ABC):
        params.update(self.train.__dict__)
        assert all(key not in list(params.keys()) for key in self.data.__dict__)
        params.update(self.data.__dict__)
        params.update(exp_path=str(self.exp_path), exp_fingerprint=str(self.fingerprint))
        return params

    @property
@@ -134,7 +157,6 @@ class Config(ConfigParser, ABC):
        new_config.read_dict(sorted_dict)
        return new_config


    def build_model(self):
        return self.model_class(self.model_paramters)

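With a subclass like the sketch above, `build_model` is all a training script needs: it looks the alias up in `_model_map` and instantiates the class with the aggregated parameters. A hedged usage sketch, assuming the plain ConfigParser read path and an invented ini file:

# Hypothetical training-script snippet.
config = ProjectConfig()
config.read('configs/ensemble.ini')   # assumed file with [main], [model], [data], [train] sections

model = config.build_model()          # resolves config.model.type via _model_map and instantiates it
print(config.name, config.version)    # e.g. 'EC_<md5 fingerprint>' 'version_<seed>'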
@@ -25,7 +25,7 @@ class Logger(LightningLoggerBase, ABC):

    @property
    def name(self):
-        return self.config.model.type
+        return self.config.name

    @property
    def project_name(self):
@@ -37,7 +37,11 @@ class Logger(LightningLoggerBase, ABC):

    @property
    def outpath(self):
-        raise NotImplementedError
+        return Path(self.config.train.outpath) / self.config.model.type
+
+    @property
+    def exp_path(self):
+        return Path(self.outpath) / self.name

    def __init__(self, config: Config):
        """
@@ -58,10 +62,12 @@ class Logger(LightningLoggerBase, ABC):
        self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
        self._neptune_kwargs = dict(offline_mode=self.debug,
                                    api_key=self.config.project.neptune_key,
                                    experiment_name=self.name,
                                    project_name=self.project_name,
                                    upload_source_files=list())
        self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
        self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
        self.log_config_as_ini()

    def log_hyperparams(self, params):
        self.neptunelogger.log_hyperparams(params)
@@ -80,6 +86,10 @@ class Logger(LightningLoggerBase, ABC):
    def log_config_as_ini(self):
        self.config.write(self.log_dir / 'config.ini')

    def log_text(self, name, text, step_nb=0, **kwargs):
        # TODO Implement Offline variant.
        self.neptunelogger.log_text(name, text, step_nb)

    def log_metric(self, metric_name, metric_value, **kwargs):
        self.testtubelogger.log_metrics(dict(metric_name=metric_value))
        self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
@@ -97,7 +107,6 @@ class Logger(LightningLoggerBase, ABC):
    def finalize(self, status):
        self.testtubelogger.finalize(status)
        self.neptunelogger.finalize(status)
        self.log_config_as_ini()

    def __enter__(self):
        return self
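The Logger above fans every call out to both a TestTubeLogger (local tt-logs) and a NeptuneLogger (remote), and `__enter__` suggests it is meant as a context manager (a matching `__exit__` is presumably defined outside this hunk). Note that `log_metric` as written forwards the literal key 'metric_name' to test-tube; a dict literal such as {metric_name: metric_value} would forward the actual name. A hedged usage sketch:

# Hypothetical usage; `config` is a concrete Config instance as sketched earlier.
with Logger(config) as logger:
    logger.log_hyperparams(dict(lr=1e-3, batch_size=128))   # placeholder hyperparameters
    logger.log_metric('train_loss', 0.173)
    logger.log_text('notes', 'ensemble run with five members')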
utils/transforms.py (new file, +8)
@@ -0,0 +1,8 @@
from torchvision.transforms import ToTensor as TorchvisionToTensor


class ToTensor(TorchvisionToTensor):

    def __call__(self, pic):
        tensor = super(ToTensor, self).__call__(pic).float()
        return tensor
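The override only changes the dtype: torchvision's ToTensor scales uint8 inputs to float32 but passes other integer dtypes through unchanged, so the explicit `.float()` guarantees a float tensor either way. A hedged usage sketch:

import numpy as np
from utils.transforms import ToTensor

# An int32 image-like array would stay integer-typed after the stock ToTensor;
# the .float() in the override forces float32 regardless of the input dtype.
img = np.random.randint(0, 4096, size=(32, 32, 1), dtype=np.int32)
tensor = ToTensor()(img)
print(tensor.dtype, tensor.shape)  # torch.float32 torch.Size([1, 32, 32])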