intitial thoughts

This commit is contained in:
Si11ium 2020-08-04 09:04:04 +02:00
parent c7d17a9898
commit 4b089729b2
3 changed files with 31 additions and 8 deletions

View File

@ -1,4 +1,6 @@
import ast
from pathlib import Path
from copy import deepcopy
from abc import ABC
@ -6,7 +8,6 @@ from abc import ABC
from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
import hashlib
@ -38,11 +39,26 @@ class Config(ConfigParser, ABC):
def fingerprint(self):
h = hashlib.md5()
params = deepcopy(self.as_dict)
del params['model']['type']
del params['data']['worker']
del params['data']['refresh']
del params['main']
del params['project']
try:
del params['model']['type']
except KeyError:
pass
try:
del params['data']['worker']
except KeyError:
pass
try:
del params['data']['refresh']
except KeyError:
pass
try:
del params['main']
except KeyError:
pass
try:
del params['project']
except KeyError:
pass
# Flatten the dict of dicts
for section in list(params.keys()):
params.update({f'{section}_{key}': val for key, val in params[section].items()})
@ -59,6 +75,7 @@ class Config(ConfigParser, ABC):
@property
def _model_map(self):
"""
This is function is supposed to return a dict, which holds a mapping from string model names to model classes
@ -68,7 +85,6 @@ class Config(ConfigParser, ABC):
)
:return:
"""
raise NotImplementedError
@property

View File

@ -62,6 +62,11 @@ class Logger(LightningLoggerBase, ABC):
self.config = config
self.debug = self.config.main.debug
if self.debug:
self.config.add_section('project')
self.config.set('project', 'owner', 'testuser')
self.config.set('project', 'name', 'test')
self.config.set('project', 'neptune_key', 'XXX')
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._neptune_kwargs = dict(offline_mode=self.debug,
api_key=self.config.project.neptune_key,
@ -97,10 +102,12 @@ class Logger(LightningLoggerBase, ABC):
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
def log_image(self, name, image, **kwargs):
def log_image(self, name, image, ext='png', **kwargs):
self.neptunelogger.log_image(name, image, **kwargs)
step = kwargs.get('step', None)
name = f'{step}_{name}' if step is not None else name
name = f'{name}.{ext[1:] if ext.startswith(".") else ext}'
(self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
image.savefig(self.log_dir / self.media_dir / name)
def save(self):