intitial thoughts
This commit is contained in:
parent
c7d17a9898
commit
4b089729b2
0
_templates/new_project/models/__init__.py
Normal file
0
_templates/new_project/models/__init__.py
Normal file
@ -1,4 +1,6 @@
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
from abc import ABC
|
||||
@ -6,7 +8,6 @@ from abc import ABC
|
||||
from argparse import Namespace, ArgumentParser
|
||||
from collections import defaultdict
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
|
||||
|
||||
@ -38,11 +39,26 @@ class Config(ConfigParser, ABC):
|
||||
def fingerprint(self):
|
||||
h = hashlib.md5()
|
||||
params = deepcopy(self.as_dict)
|
||||
del params['model']['type']
|
||||
del params['data']['worker']
|
||||
del params['data']['refresh']
|
||||
del params['main']
|
||||
del params['project']
|
||||
try:
|
||||
del params['model']['type']
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del params['data']['worker']
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del params['data']['refresh']
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del params['main']
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del params['project']
|
||||
except KeyError:
|
||||
pass
|
||||
# Flatten the dict of dicts
|
||||
for section in list(params.keys()):
|
||||
params.update({f'{section}_{key}': val for key, val in params[section].items()})
|
||||
@ -59,6 +75,7 @@ class Config(ConfigParser, ABC):
|
||||
|
||||
@property
|
||||
def _model_map(self):
|
||||
|
||||
"""
|
||||
This is function is supposed to return a dict, which holds a mapping from string model names to model classes
|
||||
|
||||
@ -68,7 +85,6 @@ class Config(ConfigParser, ABC):
|
||||
)
|
||||
:return:
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
|
@ -62,6 +62,11 @@ class Logger(LightningLoggerBase, ABC):
|
||||
|
||||
self.config = config
|
||||
self.debug = self.config.main.debug
|
||||
if self.debug:
|
||||
self.config.add_section('project')
|
||||
self.config.set('project', 'owner', 'testuser')
|
||||
self.config.set('project', 'name', 'test')
|
||||
self.config.set('project', 'neptune_key', 'XXX')
|
||||
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||
self._neptune_kwargs = dict(offline_mode=self.debug,
|
||||
api_key=self.config.project.neptune_key,
|
||||
@ -97,10 +102,12 @@ class Logger(LightningLoggerBase, ABC):
|
||||
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
|
||||
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
|
||||
|
||||
def log_image(self, name, image, **kwargs):
|
||||
def log_image(self, name, image, ext='png', **kwargs):
|
||||
self.neptunelogger.log_image(name, image, **kwargs)
|
||||
step = kwargs.get('step', None)
|
||||
name = f'{step}_{name}' if step is not None else name
|
||||
name = f'{name}.{ext[1:] if ext.startswith(".") else ext}'
|
||||
(self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
|
||||
image.savefig(self.log_dir / self.media_dir / name)
|
||||
|
||||
def save(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user