diff --git a/_templates/new_project/models/__init__.py b/_templates/new_project/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/config.py b/utils/config.py index 68f6595..b7040f9 100644 --- a/utils/config.py +++ b/utils/config.py @@ -1,4 +1,6 @@ import ast +from pathlib import Path + from copy import deepcopy from abc import ABC @@ -6,7 +8,6 @@ from abc import ABC from argparse import Namespace, ArgumentParser from collections import defaultdict from configparser import ConfigParser -from pathlib import Path import hashlib @@ -38,11 +39,26 @@ class Config(ConfigParser, ABC): def fingerprint(self): h = hashlib.md5() params = deepcopy(self.as_dict) - del params['model']['type'] - del params['data']['worker'] - del params['data']['refresh'] - del params['main'] - del params['project'] + try: + del params['model']['type'] + except KeyError: + pass + try: + del params['data']['worker'] + except KeyError: + pass + try: + del params['data']['refresh'] + except KeyError: + pass + try: + del params['main'] + except KeyError: + pass + try: + del params['project'] + except KeyError: + pass # Flatten the dict of dicts for section in list(params.keys()): params.update({f'{section}_{key}': val for key, val in params[section].items()}) @@ -59,6 +75,7 @@ class Config(ConfigParser, ABC): @property def _model_map(self): + """ This is function is supposed to return a dict, which holds a mapping from string model names to model classes @@ -68,7 +85,6 @@ class Config(ConfigParser, ABC): ) :return: """ - raise NotImplementedError @property diff --git a/utils/logging.py b/utils/logging.py index 8406ebb..876fd4a 100644 --- a/utils/logging.py +++ b/utils/logging.py @@ -62,6 +62,11 @@ class Logger(LightningLoggerBase, ABC): self.config = config self.debug = self.config.main.debug + if self.debug: + self.config.add_section('project') + self.config.set('project', 'owner', 'testuser') + self.config.set('project', 'name', 'test') + self.config.set('project', 'neptune_key', 'XXX') self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name) self._neptune_kwargs = dict(offline_mode=self.debug, api_key=self.config.project.neptune_key, @@ -97,10 +102,12 @@ class Logger(LightningLoggerBase, ABC): self.testtubelogger.log_metrics(dict(metric_name=metric_value)) self.neptunelogger.log_metric(metric_name, metric_value, **kwargs) - def log_image(self, name, image, **kwargs): + def log_image(self, name, image, ext='png', **kwargs): self.neptunelogger.log_image(name, image, **kwargs) step = kwargs.get('step', None) name = f'{step}_{name}' if step is not None else name + name = f'{name}.{ext[1:] if ext.startswith(".") else ext}' + (self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True) image.savefig(self.log_dir / self.media_dir / name) def save(self):