186 lines
6.7 KiB
Python
186 lines
6.7 KiB
Python
import inspect
|
||
from argparse import ArgumentParser
|
||
from copy import deepcopy
|
||
|
||
import hashlib
|
||
from pathlib import Path
|
||
|
||
import os
|
||
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||
from neptune.api_exceptions import ProjectNotFound
|
||
|
||
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
||
|
||
from pytorch_lightning.loggers.csv_logs import CSVLogger
|
||
from pytorch_lightning.utilities import argparse_utils
|
||
|
||
from ml_lib.utils.tools import add_argparse_args
|
||
|
||
|
||
class LightningLogger(LightningLoggerBase):
    """Combined logger that forwards logging calls to a NeptuneLogger and a CSVLogger.

    In ``debug`` mode the owner, API key and output root are replaced by placeholders,
    Neptune is created with ``offline_mode=True`` and :attr:`experiment` returns the
    CSV logger's experiment instead of the Neptune one.
    """

    # Sub-directory of ``log_dir`` in which ``log_image`` writes image files.
    media_dir = 'media'

    # Argument names removed from the parsed namespace before the rest is passed on as
    # ``params`` (seed, paths, credentials, trainer/runtime settings, ...).
    # TODO: Find a better way in cleaning this
    _NON_PARAM_ARGS = ('seed', 'num_worker', 'debug', 'eval', 'owner', 'data_root', 'check_val_every_n_epoch',
                       'reset', 'outpath', 'version', 'gpus', 'neptune_key', 'num_sanity_val_steps', 'tpu_cores',
                       'progress_bar_refresh_rate', 'log_save_interval', 'row_log_interval')

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        """Build a logger from a parsed ``argparse.Namespace``.

        Every attribute of ``args`` except those listed in ``_NON_PARAM_ARGS`` is passed as
        ``params``; construction is then delegated to
        ``pytorch_lightning.utilities.argparse_utils.from_argparse_args`` (which presumably
        fills the remaining constructor arguments from ``args`` — confirm against the PL version).

        Args:
            args: Parsed command-line namespace.
            **kwargs: Extra constructor arguments; a ``params`` entry is always overwritten.

        Returns:
            The new logger instance.
        """
        # deepcopy so the logged params are decoupled from the caller's namespace.
        cleaned_args = deepcopy(vars(args))
        for attr in cls._NON_PARAM_ARGS:
            cleaned_args.pop(attr, None)

        kwargs.update(params=cleaned_args)
        return argparse_utils.from_argparse_args(cls, args, **kwargs)

    @property
    def fingerprint(self):
        """MD5 hex digest of the fingerprint string; used as the suffix of :attr:`name`."""
        return hashlib.md5(self._finger_print_string.encode()).hexdigest()

    @property
    def name(self):
        """Experiment name: the uppercase letters of ``model_name``, ``_``, then the fingerprint."""
        short_name = ''.join(c for c in self.model_name if c.isupper())
        return f'{short_name}_{self.fingerprint}'

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Delegate to ``ml_lib.utils.tools.add_argparse_args`` for this class."""
        return add_argparse_args(cls, parent_parser)

    @property
    def experiment(self):
        """The CSV logger's experiment in debug mode, otherwise the Neptune experiment."""
        return self.csvlogger.experiment if self.debug else self.neptunelogger.experiment

    @property
    def log_dir(self):
        """Log directory of the CSV logger's experiment, as a ``Path``."""
        return Path(self.csvlogger.experiment.log_dir)

    @property
    def project_name(self):
        """Neptune project id ``'<owner>/<project_root>'``, underscores in the root replaced by dashes."""
        return f"{self.owner}/{self.project_root.replace('_', '-')}"

    @property
    def project_root(self):
        """Name of the current working directory, or ``'test'`` in debug mode."""
        return 'test' if self.debug else Path(os.getcwd()).name

    # Misspelled original name, kept as an alias for backward compatibility.
    projeect_root = project_root

    @property
    def version(self):
        """Experiment version; this logger uses the seed."""
        return self.seed

    @property
    def save_dir(self):
        """Same as :attr:`log_dir`."""
        return self.log_dir

    @property
    def outpath(self):
        """``<root_out>/<model_name>``; the CSV logger's ``save_dir``."""
        return Path(self.root_out) / self.model_name

    def __init__(self, owner, neptune_key, model_name, outpath='output', seed=69, debug=False, params=None):
        """Create the Neptune and CSV loggers for one experiment.

        Args:
            owner (str): Prefix of the Neptune project name (``'testuser'`` in debug mode).
            neptune_key (str): Neptune API key (``'XXX'`` in debug mode).
            model_name (str): Its uppercase letters prefix the experiment name; it is also the
                sub-directory of ``outpath`` used by the CSV logger.
            outpath (str | Path): Output root directory (``'debug_out'`` in debug mode).
            seed (int): Seed value; also returned as :attr:`version`.
            debug (bool): Use placeholder credentials and run Neptune in offline mode.
            params (dict | None): Optional experiment parameters. The logger keeps its own copy,
                adds a ``fingerprint`` entry and logs the result as hyperparameters. In Neptune,
                parameters are read-only after experiment creation and are shown in the
                experiment's Parameters section.

        Raises:
            ProjectNotFound: If the Neptune project :attr:`project_name` does not exist.
        """
        super(LightningLogger, self).__init__()

        self.debug = debug
        self.owner = owner if not self.debug else 'testuser'
        self.neptune_key = neptune_key if not self.debug else 'XXX'
        self.root_out = outpath if not self.debug else 'debug_out'

        # Own copy: adding the fingerprint below must not mutate the caller's dict, and an
        # empty dict (instead of None) keeps the ``update`` call below from crashing.
        self.params = dict(params) if params else {}

        self.seed = seed
        self.model_name = model_name

        if self.params:
            # NOTE: only the values (ordered by key) enter the fingerprint, not the keys.
            _, fingerprint_tuple = zip(*sorted(self.params.items(), key=lambda tup: tup[0]))
            self._finger_print_string = str(fingerprint_tuple)
        else:
            self._finger_print_string = str((self.owner, self.root_out, self.seed, self.model_name, self.debug))
        self.params.update(fingerprint=self.fingerprint)

        self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
        self._neptune_kwargs = dict(offline_mode=self.debug,
                                    params=self.params,
                                    api_key=self.neptune_key,
                                    experiment_name=self.name,
                                    # TODO: pass Neptune tags once they are configurable.
                                    project_name=self.project_name)
        try:
            self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
        except ProjectNotFound:
            # Every log/save/close method uses ``self.neptunelogger``; continuing without it
            # would only surface later as an opaque AttributeError, so fail fast here.
            print(f'The project "{self.project_name}" does not exist! Create it or check your spelling.')
            raise

        self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
        # Never empty at this point: it always contains the fingerprint.
        self.log_hyperparams(self.params)

    def close(self):
        """Close the CSV logger, then the Neptune logger."""
        self.csvlogger.close()
        self.neptunelogger.close()

    def set_fingerprint_string(self, fingerprint_str):
        """Replace the string hashed by :attr:`fingerprint`.

        The sub-loggers created in ``__init__`` keep the name computed at construction time.
        """
        self._finger_print_string = fingerprint_str

    def log_text(self, name, text, **_):
        """Log ``text`` under ``name`` to Neptune only; extra keyword arguments are ignored."""
        # TODO Implement Offline variant.
        self.neptunelogger.log_text(name, text)

    def log_hyperparams(self, params):
        """Log ``params`` to both the Neptune and the CSV logger."""
        self.neptunelogger.log_hyperparams(params)
        self.csvlogger.log_hyperparams(params)

    def log_metric(self, metric_name, metric_value, step=None, **kwargs):
        """Log a single metric to both loggers.

        Args:
            metric_name (str): Metric name; used as the key in the CSV log.
            metric_value: Value to log.
            step (int | None): Optional step index.
            **kwargs: Forwarded to ``NeptuneLogger.log_metric`` only.
        """
        # Key by the metric's name value (not the literal string 'metric_name'); the CSV
        # logger's ``log_metrics`` only accepts ``metrics`` and ``step``.
        self.csvlogger.log_metrics({metric_name: metric_value}, step=step)
        self.neptunelogger.log_metric(metric_name, metric_value, step=step, **kwargs)

    def log_metrics(self, metrics, step=None):
        """Log a dict of metrics to both loggers."""
        self.neptunelogger.log_metrics(metrics, step=step)
        self.csvlogger.log_metrics(metrics, step=step)

    def log_image(self, name, image, ext='png', step=None, **kwargs):
        """Save ``image`` to ``<log_dir>/<media_dir>`` and upload the file to Neptune.

        Args:
            name (str): Image name; prefixed with ``step`` (left-padded with zeros to 4 chars) when given.
            image: Object with a ``savefig`` method (presumably a matplotlib Figure — confirm).
            ext (str): File extension, with or without a leading dot.
            step (int | None): Optional step index.
            **kwargs: Forwarded to ``NeptuneLogger.log_image``.
        """
        media_path = self.log_dir / self.media_dir
        media_path.mkdir(parents=True, exist_ok=True)
        image_name = f'{str(step).rjust(4, "0")}_{name}' if step is not None else name
        suffix = ext[1:] if ext.startswith('.') else ext
        image_path = media_path / f'{image_name}.{suffix}'
        image.savefig(image_path, bbox_inches='tight', pad_inches=0)
        self.neptunelogger.log_image(name, str(image_path), **kwargs)

    def save(self):
        """Save both loggers."""
        self.csvlogger.save()
        self.neptunelogger.save()

    def finalize(self, status):
        """Finalize both loggers with ``status``."""
        self.csvlogger.finalize(status)
        self.neptunelogger.finalize(status)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finalize with ``'success'``, or ``'failed'`` if the block raised; exceptions propagate."""
        self.finalize('success' if exc_type is None else 'failed')
|