SubSpectral and Lightning 0.9 Update

This commit is contained in:
Si11ium
2020-09-25 15:35:15 +02:00
parent 6bc9447ce1
commit 5848b528f0
13 changed files with 197 additions and 630 deletions

View File

@@ -3,7 +3,8 @@ from pathlib import Path
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.test_tube import TestTubeLogger
# noinspection PyUnresolvedReferences
from pytorch_lightning.loggers.csv_logs import CSVLogger
from .config import Config
@@ -15,13 +16,13 @@ class Logger(LightningLoggerBase, ABC):
@property
def experiment(self):
if self.debug:
return self.testtubelogger.experiment
return self.csvlogger.experiment
else:
return self.neptunelogger.experiment
@property
def log_dir(self):
return Path(self.testtubelogger.experiment.get_logdir()).parent
return Path(self.csvlogger.experiment.log_dir)
@property
def name(self):
@@ -64,55 +65,56 @@ class Logger(LightningLoggerBase, ABC):
self.config.set('project', 'owner', 'testuser')
self.config.set('project', 'name', 'test')
self.config.set('project', 'neptune_key', 'XXX')
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._neptune_kwargs = dict(offline_mode=self.debug,
api_key=self.config.project.neptune_key,
experiment_name=self.name,
project_name=self.project_name,
params=self.config.model_paramters)
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
self.log_config_as_ini()
def log_hyperparams(self, params):
self.neptunelogger.log_hyperparams(params)
self.testtubelogger.log_hyperparams(params)
self.csvlogger.log_hyperparams(params)
pass
def log_metrics(self, metrics, step=None):
self.neptunelogger.log_metrics(metrics, step=step)
self.testtubelogger.log_metrics(metrics, step=step)
self.csvlogger.log_metrics(metrics, step=step)
pass
def close(self):
self.testtubelogger.close()
self.csvlogger.close()
self.neptunelogger.close()
def log_config_as_ini(self):
self.config.write(self.log_dir / 'config.ini')
def log_text(self, name, text, step_nb=0, **kwargs):
def log_text(self, name, text, step_nb=0, **_):
# TODO Implement Offline variant.
self.neptunelogger.log_text(name, text, step_nb)
def log_metric(self, metric_name, metric_value, **kwargs):
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
self.csvlogger.log_metrics(dict(metric_name=metric_value))
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
def log_image(self, name, image, ext='png', **kwargs):
self.neptunelogger.log_image(name, image, **kwargs)
step = kwargs.get('step', None)
name = f'{step}_{name}' if step is not None else name
name = f'{name}.{ext[1:] if ext.startswith(".") else ext}'
image_name = f'{step}_{name}' if step is not None else name
image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
(self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
image.savefig(self.log_dir / self.media_dir / name)
image.savefig(image_path, bbox_inches='tight', pad_inches=0)
self.neptunelogger.log_image(name, str(image_path), **kwargs)
def save(self):
self.testtubelogger.save()
self.csvlogger.save()
self.neptunelogger.save()
def finalize(self, status):
self.testtubelogger.finalize(status)
self.csvlogger.finalize(status)
self.neptunelogger.finalize(status)
def __enter__(self):

View File

@@ -20,7 +20,7 @@ class ModelParameters(Namespace, Mapping):
paramter_mapping.update(
dict(
activation=self._activations[self['activation']]
activation=self.__getattribute__('activation')
)
)
@@ -44,7 +44,7 @@ class ModelParameters(Namespace, Mapping):
def __getattribute__(self, name):
if name == 'activation':
return self._activations[self['activation']]
return self._activations[self['activation'].lower()]
else:
try:
return super(ModelParameters, self).__getattribute__(name)
@@ -56,6 +56,7 @@ class ModelParameters(Namespace, Mapping):
_activations = dict(
leaky_relu=nn.LeakyReLU,
elu=nn.ELU,
relu=nn.ReLU,
sigmoid=nn.Sigmoid,
tanh=nn.Tanh