import ast
import configparser
import hashlib
from abc import ABC
from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser, DuplicateSectionError
from copy import deepcopy
from distutils.util import strtobool
from pathlib import Path
from typing import Mapping, Dict

import torch
from pytorch_lightning import Trainer

from ml_lib.utils.loggers import LightningLogger
from ml_lib.utils.tools import locate_and_import_class, auto_cast


# Argument Parser and Default Values
# =============================================================================
def parse_comandline_args_add_defaults(filepath, overrides=None):

    # Parse the command line
    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str)
    parser.add_argument('--data_name', type=str)
    parser.add_argument('--seed', type=str)
    parser.add_argument('--debug', type=strtobool)

    # Load defaults from the _parameters.ini file
    config = configparser.ConfigParser()
    config.read(str(filepath))

    new_defaults = dict()
    for section in ['project', 'train', 'data']:
        defaults = config[section]
        new_defaults.update({key: auto_cast(val) for key, val in defaults.items()})

    args, _ = parser.parse_known_args()
    overrides = overrides or dict()
    default_data = overrides.get('data_name', None) or new_defaults['data_name']
    default_model = overrides.get('model_name', None) or new_defaults['model_name']
    default_seed = overrides.get('seed', None) or new_defaults['seed']

    data_name = args.__dict__.get('data_name', None) or default_data
    model_name = args.__dict__.get('model_name', None) or default_model
    found_seed = args.__dict__.get('seed', None) or default_seed

    # Merge the model-specific section of the ini file into the defaults
    new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})

    found_data_class = locate_and_import_class(data_name, 'datasets')
    found_model_class = locate_and_import_class(model_name, 'models')

    # Let the logger, trainer, dataset and model register their own arguments
    for module in [LightningLogger, Trainer, found_data_class, found_model_class]:
        parser = module.add_argparse_args(parser)

    args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))

    args = vars(args)
    args.update({key: auto_cast(val) for key, val in args.items()})
    args.update(gpus=[0] if torch.cuda.is_available() and not args['debug'] else None,
                row_log_interval=1000,    # TODO: Better Value / Setting
                log_save_interval=10000,  # TODO: Better Value / Setting
                weights_summary='top',
                )

    if overrides is not None and isinstance(overrides, (Mapping, Dict)):
        args.update(**overrides)

    if args['debug']:
        args.update(
            # This seems to be the new "fast_dev_run"
            val_check_interval=1,
            max_epochs=2,
            max_steps=2,
            auto_lr_find=False,
            check_val_every_n_epoch=1
        )
    return args, found_data_class, found_model_class, found_seed


def is_jsonable(x):
    import json
    try:
        json.dumps(x)
        return True
    except TypeError:
        return False


class Config(ConfigParser, ABC):

    @property
    def name(self):
        short_name = "".join(c for c in self.model.type if c.isupper())
        return f'{short_name}_{self.fingerprint}'

    @property
    def version(self):
        return f'version_{self.main.seed}'

    @property
    def exp_path(self):
        return Path(self.train.outpath) / self.model.type / self.name

    @property
    def fingerprint(self):
        h = hashlib.md5()
        params = deepcopy(self.as_dict)
        # Drop keys that should not influence the fingerprint
        try:
            del params['model']['type']
        except KeyError:
            pass
        try:
            del params['data']['worker']
        except KeyError:
            pass
        try:
            del params['data']['refresh']
        except KeyError:
            pass
        try:
            del params['main']
        except KeyError:
            pass
        try:
            del params['project']
        except KeyError:
            pass
        # Flatten the dict of dicts
        for section in list(params.keys()):
            params.update({f'{section}_{key}': val for key, val in params[section].items()})
            del params[section]
        _, vals = zip(*sorted(params.items(), key=lambda tup: tup[0]))
        h.update(str(vals).encode())
        fingerprint = h.hexdigest()
        return fingerprint
    @property
    def _model_weight_init(self):
        mod = __import__('torch.nn.init', fromlist=[self.model.weight_init])
        return getattr(mod, self.model.weight_init)

    @property
    def _model_map(self):
        """
        This property is supposed to return a dict that maps string model names to model classes.

        Example:
            from models.binary_classifier import ConvClassifier
            return dict(ConvClassifier=ConvClassifier,
                        )
        :return:
        """
        raise NotImplementedError

    @property
    def model_class(self):
        try:
            return locate_and_import_class(self.model.type, folder_path='models')
        except AttributeError as e:
            raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}") '
                                 f'was not found!\n'
                                 f'{e}')

    @property
    def data_class(self):
        try:
            return locate_and_import_class(self.data.class_name, folder_path='datasets')
        except AttributeError as e:
            raise AttributeError(f'The dataset alias you provided ("{self.get("data", "class_name")}") '
                                 f'was not found!\n'
                                 f'{e}')

    # --------------------------------------------------
    # TODO: Do this programmatically; this did not work:
    #   Initialize the default sections as properties:
    #   for section in self.default_sections:
    #       self.__setattr__(section, property(lambda self_: self_._get_namespace_for_section(section)))

    @property
    def main(self):
        return self._get_namespace_for_section('main')

    @property
    def model(self):
        return self._get_namespace_for_section('model')

    @property
    def train(self):
        return self._get_namespace_for_section('train')

    @property
    def data(self):
        return self._get_namespace_for_section('data')

    @property
    def project(self):
        return self._get_namespace_for_section('project')
    ###################################################

    @property
    def model_paramters(self):
        params = deepcopy(self.model.__dict__)
        assert all(key not in list(params.keys()) for key in self.train.__dict__)
        params.update(self.train.__dict__)
        assert all(key not in list(params.keys()) for key in self.data.__dict__)
        params.update(self.data.__dict__)
        params.update(version=self.version)
        params.update(exp_path=str(self.exp_path), exp_fingerprint=str(self.fingerprint))
        return params

    @property
    def tags(self):
        return [f'{key}: {val}' for key, val in self.serializable.items()]

    @property
    def serializable(self):
        return {f'{section}_{key}': val for section, params in self._sections.items()
                for key, val in params.items() if is_jsonable(val)}

    @property
    def as_dict(self):
        return self._sections

    def _get_namespace_for_section(self, item):
        return Namespace(**{key: self.get(item, key) for key in self[item]})

    def __init__(self, **kwargs):
        super(Config, self).__init__(**kwargs)

    @staticmethod
    def _sort_combined_section_key_mapping(dict_obj):
        # Split flat keys like 'model_lr' into section 'model' and option 'lr'
        sorted_dict = defaultdict(dict)
        for key in dict_obj:
            section, *attr_name = key.split('_')
            attr_name = '_'.join(attr_name)
            value = str(dict_obj[key])
            sorted_dict[section][attr_name] = value
        # noinspection PyTypeChecker
        return dict(sorted_dict)

    @classmethod
    def read_namespace(cls, namespace: Namespace):
        sorted_dict = cls._sort_combined_section_key_mapping(namespace.__dict__)
        new_config = cls()
        new_config.read_dict(sorted_dict)
        return new_config

    @classmethod
    def read_argparser(cls, argparser: ArgumentParser):
        # Parse it
        args = argparser.parse_args()
        sorted_dict = cls._sort_combined_section_key_mapping(args.__dict__)
        new_config = cls()
        new_config.read_dict(sorted_dict)
        return new_config
    def build_model(self):
        return self.model_class(self.model_paramters)

    def build_and_init_model(self):
        model = self.build_model()
        model.init_weights(self._model_weight_init)
        return model

    def update(self, mapping):
        sorted_dict = self._sort_combined_section_key_mapping(mapping)
        for section in sorted_dict:
            if not self.has_section(section):
                self.add_section(section)
            for option, value in sorted_dict[section].items():
                self.set(section, option, value)
        return self

    def get(self, *args, **kwargs):
        item = super(Config, self).get(*args, **kwargs)
        try:
            return ast.literal_eval(item)
        except (SyntaxError, ValueError):
            return item

    def write(self, filepath, **kwargs):
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open('w') as configfile:
            super().write(configfile)
        return True

    def _write_section(self, fp, section_name, section_items, delimiter):
        # The 'project' section is not persisted to disk
        if section_name == 'project':
            return
        else:
            super(Config, self)._write_section(fp, section_name, section_items, delimiter)

    def add_section(self, section: str) -> None:
        try:
            super(Config, self).add_section(section)
        except DuplicateSectionError:
            pass


class DataClass(Namespace):

    @property
    def __dict__(self):
        return [x for x in dir(self) if not x.startswith('_')]
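

# -----------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module):
# the '_parameters.ini' path below is a placeholder, and how the returned
# classes and hyperparameters are wired into a Trainer depends on the
# surrounding project.
if __name__ == '__main__':
    hparams, data_cls, model_cls, seed = parse_comandline_args_add_defaults('_parameters.ini')
    print(f'Defaults resolved for model={model_cls.__name__}, '
          f'data={data_cls.__name__}, seed={seed}')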