intitial thoughts

This commit is contained in:
Si11ium 2020-08-04 09:04:04 +02:00
parent c7d17a9898
commit 4b089729b2
3 changed files with 31 additions and 8 deletions

View File

@ -1,4 +1,6 @@
import ast import ast
from pathlib import Path
from copy import deepcopy from copy import deepcopy
from abc import ABC from abc import ABC
@ -6,7 +8,6 @@ from abc import ABC
from argparse import Namespace, ArgumentParser from argparse import Namespace, ArgumentParser
from collections import defaultdict from collections import defaultdict
from configparser import ConfigParser from configparser import ConfigParser
from pathlib import Path
import hashlib import hashlib
@ -38,11 +39,26 @@ class Config(ConfigParser, ABC):
def fingerprint(self): def fingerprint(self):
h = hashlib.md5() h = hashlib.md5()
params = deepcopy(self.as_dict) params = deepcopy(self.as_dict)
del params['model']['type'] try:
del params['data']['worker'] del params['model']['type']
del params['data']['refresh'] except KeyError:
del params['main'] pass
del params['project'] try:
del params['data']['worker']
except KeyError:
pass
try:
del params['data']['refresh']
except KeyError:
pass
try:
del params['main']
except KeyError:
pass
try:
del params['project']
except KeyError:
pass
# Flatten the dict of dicts # Flatten the dict of dicts
for section in list(params.keys()): for section in list(params.keys()):
params.update({f'{section}_{key}': val for key, val in params[section].items()}) params.update({f'{section}_{key}': val for key, val in params[section].items()})
@ -59,6 +75,7 @@ class Config(ConfigParser, ABC):
@property @property
def _model_map(self): def _model_map(self):
""" """
This is function is supposed to return a dict, which holds a mapping from string model names to model classes This is function is supposed to return a dict, which holds a mapping from string model names to model classes
@ -68,7 +85,6 @@ class Config(ConfigParser, ABC):
) )
:return: :return:
""" """
raise NotImplementedError raise NotImplementedError
@property @property

View File

@ -62,6 +62,11 @@ class Logger(LightningLoggerBase, ABC):
self.config = config self.config = config
self.debug = self.config.main.debug self.debug = self.config.main.debug
if self.debug:
self.config.add_section('project')
self.config.set('project', 'owner', 'testuser')
self.config.set('project', 'name', 'test')
self.config.set('project', 'neptune_key', 'XXX')
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name) self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._neptune_kwargs = dict(offline_mode=self.debug, self._neptune_kwargs = dict(offline_mode=self.debug,
api_key=self.config.project.neptune_key, api_key=self.config.project.neptune_key,
@ -97,10 +102,12 @@ class Logger(LightningLoggerBase, ABC):
self.testtubelogger.log_metrics(dict(metric_name=metric_value)) self.testtubelogger.log_metrics(dict(metric_name=metric_value))
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs) self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
def log_image(self, name, image, **kwargs): def log_image(self, name, image, ext='png', **kwargs):
self.neptunelogger.log_image(name, image, **kwargs) self.neptunelogger.log_image(name, image, **kwargs)
step = kwargs.get('step', None) step = kwargs.get('step', None)
name = f'{step}_{name}' if step is not None else name name = f'{step}_{name}' if step is not None else name
name = f'{name}.{ext[1:] if ext.startswith(".") else ext}'
(self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
image.savefig(self.log_dir / self.media_dir / name) image.savefig(self.log_dir / self.media_dir / name)
def save(self): def save(self):