Model Loading by string. Within Debugging

This commit is contained in:
Si11ium
2020-08-15 12:42:57 +02:00
parent a4b6c698c3
commit 6bc9447ce1
5 changed files with 108 additions and 58 deletions

View File

@@ -7,9 +7,11 @@ from abc import ABC
from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser
from configparser import ConfigParser, DuplicateSectionError
import hashlib
from ml_lib.utils.tools import locate_and_import_class
def is_jsonable(x):
import json
@@ -90,11 +92,13 @@ class Config(ConfigParser, ABC):
@property
def model_class(self):
try:
return self._model_map[self.model.type]
except KeyError:
raise KeyError(f'The model alias you provided ("{self.get("model", "type")}")' +
'does not exist! Try one of these: {list(self._model_map.keys())}')
return locate_and_import_class(self.model.type)
except AttributeError as e:
raise AttributeError(f'The model alias you provided ("{self.get("model", "type")}")' +
f'was not found!\n' +
f'{e}')
# --------------------------------------------------
# TODO: Do this programmatically; This did not work:
# Initialize Default Sections as Property
# for section in self.default_sections:
@@ -223,3 +227,16 @@ class Config(ConfigParser, ABC):
return
else:
super(Config, self)._write_section(fp, section_name, section_items, delimiter)
def add_section(self, section: str) -> None:
try:
super(Config, self).add_section(section)
except DuplicateSectionError:
pass
class DataClass(Namespace):
@property
def __dict__(self):
return [x for x in dir(self) if not x.startswith('_')]

View File

@@ -1,6 +1,3 @@
import argparse
from typing import Union, Dict, Optional, Any
from abc import ABC
from pathlib import Path

View File

@@ -1,6 +1,9 @@
import importlib
import pickle
import shelve
from pathlib import Path
from pathlib import Path, PurePath
from pydoc import safeimport
from typing import Union
import numpy as np
import torch
@@ -37,3 +40,16 @@ def load_from_shelve(file_path, key):
def check_path(file_path):
assert isinstance(file_path, Path)
assert str(file_path).endswith('.pik')
def locate_and_import_class(class_name, models_location: Union[str, PurePath] = 'models', forceload=False):
"""Locate an object by name or dotted path, importing as necessary."""
models_location = Path(models_location)
module_paths = [x for x in models_location.rglob('*.py') if x.is_file() and '__init__' not in x.name]
for module_path in module_paths:
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
try:
model_class = mod.__getattribute__(class_name)
except AttributeError:
continue
return model_class