initial commit
This commit is contained in:
135
utils/config.py
Normal file
135
utils/config.py
Normal file
@ -0,0 +1,135 @@
|
||||
import ast
|
||||
|
||||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
|
||||
from ml_lib.models.generators.cnn import CNNRouteGeneratorModel
|
||||
from ml_lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated
|
||||
|
||||
from ml_lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
||||
from ml_lib.utils.model_io import ModelParameters
|
||||
from ml_lib.utils.transforms import AsArray
|
||||
|
||||
|
||||
def is_jsonable(x):
|
||||
import json
|
||||
try:
|
||||
json.dumps(x)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
class Config(ConfigParser):
|
||||
|
||||
# TODO: Do this programmatically; This did not work:
|
||||
# Initialize Default Sections
|
||||
# for section in self.default_sections:
|
||||
# self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section))
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
model_dict = dict(ConvHomDetector=ConvHomDetector,
|
||||
CNNRouteGenerator=CNNRouteGeneratorModel,
|
||||
CNNRouteGeneratorDiscriminated=CNNRouteGeneratorDiscriminated
|
||||
)
|
||||
try:
|
||||
return model_dict[self.get('model', 'type')]
|
||||
except KeyError as e:
|
||||
raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n'
|
||||
f'Try one of these:\n{list(model_dict.keys())}')
|
||||
|
||||
@property
|
||||
def main(self):
|
||||
return self._get_namespace_for_section('main')
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self._get_namespace_for_section('model')
|
||||
|
||||
@property
|
||||
def train(self):
|
||||
return self._get_namespace_for_section('train')
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._get_namespace_for_section('data')
|
||||
|
||||
@property
|
||||
def project(self):
|
||||
return self._get_namespace_for_section('project')
|
||||
###################################################
|
||||
|
||||
@property
|
||||
def model_paramters(self):
|
||||
return ModelParameters(self.model, self.train, self.data)
|
||||
|
||||
@property
|
||||
def tags(self, ):
|
||||
return [f'{key}: {val}' for key, val in self.serializable.items()]
|
||||
|
||||
@property
|
||||
def serializable(self):
|
||||
return {f'{section}_{key}': val for section, params in self._sections.items()
|
||||
for key, val in params.items() if is_jsonable(val)}
|
||||
|
||||
@property
|
||||
def as_dict(self):
|
||||
return self._sections
|
||||
|
||||
def _get_namespace_for_section(self, item):
|
||||
return Namespace(**{key: self.get(item, key) for key in self[item]})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(Config, self).__init__(**kwargs)
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _sort_combined_section_key_mapping(dict_obj):
|
||||
sorted_dict = defaultdict(dict)
|
||||
for key in dict_obj:
|
||||
section, *attr_name = key.split('_')
|
||||
attr_name = '_'.join(attr_name)
|
||||
value = str(dict_obj[key])
|
||||
|
||||
sorted_dict[section][attr_name] = value
|
||||
# noinspection PyTypeChecker
|
||||
return dict(sorted_dict)
|
||||
|
||||
@classmethod
|
||||
def read_namespace(cls, namespace: Namespace):
|
||||
|
||||
sorted_dict = cls._sort_combined_section_key_mapping(namespace.__dict__)
|
||||
new_config = cls()
|
||||
new_config.read_dict(sorted_dict)
|
||||
return new_config
|
||||
|
||||
def update(self, mapping):
|
||||
sorted_dict = self._sort_combined_section_key_mapping(mapping)
|
||||
for section in sorted_dict:
|
||||
if self.has_section(section):
|
||||
pass
|
||||
else:
|
||||
self.add_section(section)
|
||||
for option, value in sorted_dict[section].items():
|
||||
self.set(section, option, value)
|
||||
return self
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
item = super(Config, self).get(*args, **kwargs)
|
||||
try:
|
||||
return ast.literal_eval(item)
|
||||
except SyntaxError:
|
||||
return item
|
||||
except ValueError:
|
||||
return item
|
||||
|
||||
def write(self, filepath, **kwargs):
|
||||
path = Path(filepath, exist_ok=True)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with path.open('w') as configfile:
|
||||
super().write(configfile)
|
||||
return True
|
Reference in New Issue
Block a user