88 lines
2.7 KiB
Python
88 lines
2.7 KiB
Python
import importlib
|
|
import inspect
|
|
import pickle
|
|
import shelve
|
|
from argparse import ArgumentParser, ArgumentError
|
|
from ast import literal_eval
|
|
from pathlib import Path, PurePath
|
|
from typing import Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import random
|
|
|
|
|
|
def auto_cast(a):
|
|
try:
|
|
return literal_eval(a)
|
|
except:
|
|
return a
|
|
|
|
|
|
def to_one_hot(idx_array, max_classes):
|
|
one_hot = np.zeros((idx_array.size, max_classes))
|
|
one_hot[np.arange(idx_array.size), idx_array] = 1
|
|
return one_hot
|
|
|
|
|
|
def fix_all_random_seeds(seed):
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
random.seed(seed)
|
|
print(f'Seed is now fixed: "{seed}".')
|
|
|
|
|
|
def write_to_shelve(file_path, value):
|
|
check_path(file_path)
|
|
file_path.parent.mkdir(exist_ok=True, parents=True)
|
|
with shelve.open(str(file_path), protocol=pickle.HIGHEST_PROTOCOL) as f:
|
|
new_key = str(len(f))
|
|
f[new_key] = value
|
|
f.close()
|
|
|
|
|
|
def load_from_shelve(file_path, key):
|
|
check_path(file_path)
|
|
with shelve.open(str(file_path)) as d:
|
|
return d[key]
|
|
|
|
|
|
def check_path(file_path):
|
|
assert isinstance(file_path, Path)
|
|
assert str(file_path).endswith('.pik')
|
|
|
|
|
|
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
|
"""Locate an object by name or dotted path, importing as necessary."""
|
|
folder_path = Path(folder_path)
|
|
module_paths = [x for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
|
for module_path in module_paths:
|
|
mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts]))
|
|
try:
|
|
model_class = mod.__getattribute__(class_name)
|
|
return model_class
|
|
except AttributeError:
|
|
continue
|
|
raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
|
|
|
|
|
|
def add_argparse_args(cls, parent_parser):
|
|
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
|
full_arg_spec = inspect.getfullargspec(cls.__init__)
|
|
n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
|
|
for idx, argument in enumerate(full_arg_spec.args):
|
|
try:
|
|
if argument == 'self':
|
|
continue
|
|
if idx < n_non_defaults:
|
|
parser.add_argument(f'--{argument}', type=int)
|
|
else:
|
|
argument_type = type(argument)
|
|
parser.add_argument(f'--{argument}',
|
|
type=argument_type,
|
|
default=full_arg_spec.defaults[idx - n_non_defaults]
|
|
)
|
|
except ArgumentError:
|
|
continue
|
|
return parser
|