diff --git a/utils/config.py b/utils/config.py index 18ccc04..b80c68f 100644 --- a/utils/config.py +++ b/utils/config.py @@ -61,6 +61,9 @@ def parse_comandline_args_add_defaults(filepath, overrides=None): for module in [Logger, Trainer, found_data_class, found_model_class]: parser = module.add_argparse_args(parser) + # This is obsolete + # new_defaults.update(data_name=data_name, model_name=model_name) + args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults)) args = vars(args) diff --git a/utils/tools.py b/utils/tools.py index 08d8f3f..012a496 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -54,8 +54,12 @@ def check_path(file_path): def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''): """Locate an object by name or dotted path, importing as necessary.""" + import sys + sys.path.append("..") folder_path = Path(folder_path) module_paths = [x for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name] + # possible_package_path = folder_path / '__init__.py' + # package = str(possible_package_path) if possible_package_path.exists() else None for module_path in module_paths: mod = importlib.import_module('.'.join([x.replace('.py', '') for x in module_path.parts])) try: diff --git a/visualization/tools.py b/visualization/tools.py index 87079ba..8db6bbd 100644 --- a/visualization/tools.py +++ b/visualization/tools.py @@ -1,12 +1,38 @@ try: from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas + from matplotlib import pyplot as plt except ImportError: # pragma: no-cover - raise ImportError('You want to use `matplotlib` plugins which are not installed yet,' # pragma: no-cover + raise ImportError('You want to use `matplotlib` which is not installed yet,' # pragma: no-cover ' install it with `pip install matplotlib`.') from pathlib import Path +def prettyfy_sns(): + plt.style.use('default') + try: + import seaborn as sns + except ImportError: + raise ImportError('You want to use `seaborn` which is not installed yet,' # pragma: no-cover + ' install it with `pip install seaborn`.') + sns.set_palette('Dark2') + + tex_fonts = { + # Use LaTeX to write all text + "text.usetex": True, + "font.family": "serif", + # Use 10pt font in plots, to match 10pt font in document + "axes.labelsize": 10, + "font.size": 10, + # Make the legend/label fonts a little smaller + "legend.fontsize": 8, + "xtick.labelsize": 8, + "ytick.labelsize": 8 + } + + plt.rcParams.update(tex_fonts) + + class Plotter(object): def __init__(self, root_path=''):