SubSpectral and Lightning 0.9 Update
This commit is contained in:
@@ -3,7 +3,8 @@ from pathlib import Path
|
||||
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
||||
from pytorch_lightning.loggers.test_tube import TestTubeLogger
|
||||
# noinspection PyUnresolvedReferences
|
||||
from pytorch_lightning.loggers.csv_logs import CSVLogger
|
||||
|
||||
from .config import Config
|
||||
|
||||
@@ -15,13 +16,13 @@ class Logger(LightningLoggerBase, ABC):
|
||||
@property
|
||||
def experiment(self):
|
||||
if self.debug:
|
||||
return self.testtubelogger.experiment
|
||||
return self.csvlogger.experiment
|
||||
else:
|
||||
return self.neptunelogger.experiment
|
||||
|
||||
@property
|
||||
def log_dir(self):
|
||||
return Path(self.testtubelogger.experiment.get_logdir()).parent
|
||||
return Path(self.csvlogger.experiment.log_dir)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
@@ -64,55 +65,56 @@ class Logger(LightningLoggerBase, ABC):
|
||||
self.config.set('project', 'owner', 'testuser')
|
||||
self.config.set('project', 'name', 'test')
|
||||
self.config.set('project', 'neptune_key', 'XXX')
|
||||
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||
self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||
self._neptune_kwargs = dict(offline_mode=self.debug,
|
||||
api_key=self.config.project.neptune_key,
|
||||
experiment_name=self.name,
|
||||
project_name=self.project_name,
|
||||
params=self.config.model_paramters)
|
||||
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
|
||||
self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
|
||||
self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
|
||||
self.log_config_as_ini()
|
||||
|
||||
def log_hyperparams(self, params):
|
||||
self.neptunelogger.log_hyperparams(params)
|
||||
self.testtubelogger.log_hyperparams(params)
|
||||
self.csvlogger.log_hyperparams(params)
|
||||
pass
|
||||
|
||||
def log_metrics(self, metrics, step=None):
|
||||
self.neptunelogger.log_metrics(metrics, step=step)
|
||||
self.testtubelogger.log_metrics(metrics, step=step)
|
||||
self.csvlogger.log_metrics(metrics, step=step)
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self.testtubelogger.close()
|
||||
self.csvlogger.close()
|
||||
self.neptunelogger.close()
|
||||
|
||||
def log_config_as_ini(self):
|
||||
self.config.write(self.log_dir / 'config.ini')
|
||||
|
||||
def log_text(self, name, text, step_nb=0, **kwargs):
|
||||
def log_text(self, name, text, step_nb=0, **_):
|
||||
# TODO Implement Offline variant.
|
||||
self.neptunelogger.log_text(name, text, step_nb)
|
||||
|
||||
def log_metric(self, metric_name, metric_value, **kwargs):
|
||||
self.testtubelogger.log_metrics(dict(metric_name=metric_value))
|
||||
self.csvlogger.log_metrics(dict(metric_name=metric_value))
|
||||
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
|
||||
|
||||
def log_image(self, name, image, ext='png', **kwargs):
|
||||
self.neptunelogger.log_image(name, image, **kwargs)
|
||||
|
||||
step = kwargs.get('step', None)
|
||||
name = f'{step}_{name}' if step is not None else name
|
||||
name = f'{name}.{ext[1:] if ext.startswith(".") else ext}'
|
||||
image_name = f'{step}_{name}' if step is not None else name
|
||||
image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
|
||||
(self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
|
||||
image.savefig(self.log_dir / self.media_dir / name)
|
||||
image.savefig(image_path, bbox_inches='tight', pad_inches=0)
|
||||
self.neptunelogger.log_image(name, str(image_path), **kwargs)
|
||||
|
||||
def save(self):
|
||||
self.testtubelogger.save()
|
||||
self.csvlogger.save()
|
||||
self.neptunelogger.save()
|
||||
|
||||
def finalize(self, status):
|
||||
self.testtubelogger.finalize(status)
|
||||
self.csvlogger.finalize(status)
|
||||
self.neptunelogger.finalize(status)
|
||||
|
||||
def __enter__(self):
|
||||
|
@@ -20,7 +20,7 @@ class ModelParameters(Namespace, Mapping):
|
||||
|
||||
paramter_mapping.update(
|
||||
dict(
|
||||
activation=self._activations[self['activation']]
|
||||
activation=self.__getattribute__('activation')
|
||||
)
|
||||
)
|
||||
|
||||
@@ -44,7 +44,7 @@ class ModelParameters(Namespace, Mapping):
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name == 'activation':
|
||||
return self._activations[self['activation']]
|
||||
return self._activations[self['activation'].lower()]
|
||||
else:
|
||||
try:
|
||||
return super(ModelParameters, self).__getattribute__(name)
|
||||
@@ -56,6 +56,7 @@ class ModelParameters(Namespace, Mapping):
|
||||
|
||||
_activations = dict(
|
||||
leaky_relu=nn.LeakyReLU,
|
||||
elu=nn.ELU,
|
||||
relu=nn.ReLU,
|
||||
sigmoid=nn.Sigmoid,
|
||||
tanh=nn.Tanh
|
||||
|
Reference in New Issue
Block a user