diff --git a/utils/config.py b/utils/config.py index 6745947..d67bccd 100644 --- a/utils/config.py +++ b/utils/config.py @@ -1,14 +1,11 @@ import ast +from abc import ABC from argparse import Namespace from collections import defaultdict from configparser import ConfigParser from pathlib import Path -from ml_lib.models.generators.cnn import CNNRouteGeneratorModel -from ml_lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated - -from ml_lib.models.homotopy_classification.cnn_based import ConvHomDetector from ml_lib.utils.model_io import ModelParameters from ml_lib.utils.transforms import AsArray @@ -22,22 +19,34 @@ def is_jsonable(x): return False -class Config(ConfigParser): +class Config(ConfigParser, ABC): # TODO: Do this programmatically; This did not work: # Initialize Default Sections # for section in self.default_sections: # self.__setattr__(section, property(lambda x :x._get_namespace_for_section(section)) + @property + def model_map(self): + """ + This is function is supposed to return a dict, which holds a mapping from string model names to model classes + + Example: + from models.binary_classifier import BinaryClassifier + return dict(BinaryClassifier=BinaryClassifier, + ) + :return: + """ + + raise NotImplementedError + @property def model_class(self): - model_dict = dict(BinaryClassifier=BinaryClassifier, - ) try: - return model_dict[self.get('model', 'type')] + return self.model_map[self.get('model', 'type')] except KeyError as e: raise KeyError(rf'The model alias you provided ("{self.get("model", "type")}") does not exist! \n' - f'Try one of these:\n{list(model_dict.keys())}') + f'Try one of these:\n{list(self.model_map.keys())}') @property def main(self):