ml_lib/utils/loggers.py
2021-05-11 10:31:34 +02:00

186 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import inspect
from argparse import ArgumentParser
from copy import deepcopy
import hashlib
from pathlib import Path
import os
from pytorch_lightning.loggers.base import LightningLoggerBase
from neptune.api_exceptions import ProjectNotFound
from pytorch_lightning.loggers.neptune import NeptuneLogger
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.utilities import argparse_utils
from ml_lib.utils.tools import add_argparse_args
class LightningLogger(LightningLoggerBase):
@classmethod
def from_argparse_args(cls, args, **kwargs):
cleaned_args = deepcopy(args.__dict__)
# Clean Seed and other attributes
# TODO: Find a better way in cleaning this
for attr in ['seed', 'num_worker', 'debug', 'eval', 'owner', 'data_root', 'check_val_every_n_epoch',
'reset', 'outpath', 'version', 'gpus', 'neptune_key', 'num_sanity_val_steps', 'tpu_cores',
'progress_bar_refresh_rate', 'log_save_interval', 'row_log_interval']:
try:
del cleaned_args[attr]
except KeyError:
pass
kwargs.update(params=cleaned_args)
new_logger = argparse_utils.from_argparse_args(cls, args, **kwargs)
return new_logger
@property
def fingerprint(self):
h = hashlib.md5()
h.update(self._finger_print_string.encode())
fingerprint = h.hexdigest()
return fingerprint
@property
def name(self):
short_name = "".join(c for c in self.model_name if c.isupper())
return f'{short_name}_{self.fingerprint}'
media_dir = 'media'
@classmethod
def add_argparse_args(cls, parent_parser):
return add_argparse_args(cls, parent_parser)
@property
def experiment(self):
if self.debug:
return self.csvlogger.experiment
else:
return self.neptunelogger.experiment
@property
def log_dir(self):
return Path(self.csvlogger.experiment.log_dir)
@property
def project_name(self):
return f"{self.owner}/{self.projeect_root.replace('_', '-')}"
@property
def projeect_root(self):
root_path = Path(os.getcwd()).name if not self.debug else 'test'
return root_path
@property
def version(self):
return self.seed
@property
def save_dir(self):
return self.log_dir
@property
def outpath(self):
return Path(self.root_out) / self.model_name
def __init__(self, owner, neptune_key, model_name, outpath='output', seed=69, debug=False, params=None):
"""
params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
Parameters are displayed in the experiments Parameters section and each key-value pair can be
viewed in experiments view as a column.
properties (dict|None): Optional default is {}. Properties of the experiment.
They are editable after experiment is created. Properties are displayed in the experiments Details and
each key-value pair can be viewed in experiments view as a column.
tags (list|None): Optional default []. Must be list of str. Tags of the experiment.
They are editable after experiment is created (see: append_tag() and remove_tag()).
Tags are displayed in the experiments Details and can be viewed in experiments view as a column.
"""
super(LightningLogger, self).__init__()
self.debug = debug
self.owner = owner if not self.debug else 'testuser'
self.neptune_key = neptune_key if not self.debug else 'XXX'
self.root_out = outpath if not self.debug else 'debug_out'
self.params = params
self.seed = seed
self.model_name = model_name
if self.params:
_, fingerprint_tuple = zip(*sorted(self.params.items(), key=lambda tup: tup[0]))
self._finger_print_string = str(fingerprint_tuple)
else:
self._finger_print_string = str((self.owner, self.root_out, self.seed, self.model_name, self.debug))
self.params.update(fingerprint=self.fingerprint)
self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._neptune_kwargs = dict(offline_mode=self.debug,
params=self.params,
api_key=self.neptune_key,
experiment_name=self.name,
# tags=?,
project_name=self.project_name)
try:
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
except ProjectNotFound as e:
print(f'The project "{self.project_name}" does not exist! Create it or check your spelling.')
print(e)
self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
if self.params:
self.log_hyperparams(self.params)
def close(self):
self.csvlogger.close()
self.neptunelogger.close()
def set_fingerprint_string(self, fingerprint_str):
self._finger_print_string = fingerprint_str
def log_text(self, name, text, **_):
# TODO Implement Offline variant.
self.neptunelogger.log_text(name, text)
def log_hyperparams(self, params):
self.neptunelogger.log_hyperparams(params)
self.csvlogger.log_hyperparams(params)
pass
def log_metric(self, metric_name, metric_value, step=None, **kwargs):
self.csvlogger.log_metrics(dict(metric_name=metric_value, **kwargs), step=step, **kwargs)
self.neptunelogger.log_metric(metric_name, metric_value, step=step, **kwargs)
pass
def log_metrics(self, metrics, step=None):
self.neptunelogger.log_metrics(metrics, step=step)
self.csvlogger.log_metrics(metrics, step=step)
pass
def log_image(self, name, image, ext='png', step=None, **kwargs):
image_name = f'{"0" * (4 - len(str(step)))}{step}_{name}' if step is not None else name
image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
(self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
image.savefig(image_path, bbox_inches='tight', pad_inches=0)
self.neptunelogger.log_image(name, str(image_path), **kwargs)
def save(self):
self.csvlogger.save()
self.neptunelogger.save()
def finalize(self, status):
self.csvlogger.finalize(status)
self.neptunelogger.finalize(status)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.finalize('success')
pass