SubSpectral and Lightning 0.9 Update
parent 6bc9447ce1
commit 5848b528f0
@@ -7,10 +7,9 @@ import torch

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 
-from modules.utils import LightningBaseModule
-from utils.config import Config
-from utils.logging import Logger
-from utils.model_io import SavedLightningModels
+from ml_lib.modules.util import LightningBaseModule
+from ml_lib.utils.config import Config
+from ml_lib.utils.logging import Logger
 
 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -1,6 +1,6 @@
 import warnings
 
-from _templates.new_project.utils.project_config import Config
+from ml_lib._templates.new_project.utils.project_config import Config
 
 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
@@ -8,7 +8,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
 # Imports
 # =============================================================================
 
-from _templates.new_project.main import run_lightning_loop, args
+from ml_lib._templates.new_project.main import run_lightning_loop, args
 
 
 if __name__ == '__main__':
@@ -11,13 +11,13 @@ from torch.utils.data import DataLoader
 from torchcontrib.optim import SWA
 from torchvision.transforms import Compose
 
-from _templates.new_project.datasets.template_dataset import TemplateDataset
+from ml_lib._templates.new_project.datasets.template_dataset import TemplateDataset
 
-from audio_toolset.audio_io import NormalizeLocal
-from modules.utils import LightningBaseModule
-from utils.transforms import ToTensor
+from ml_lib.audio_toolset.audio_io import NormalizeLocal
+from ml_lib.modules.util import LightningBaseModule
+from ml_lib.utils.transforms import ToTensor
 
-from _templates.new_project.utils.project_config import GlobalVar as GlobalVars
+from ml_lib._templates.new_project.utils.project_config import GlobalVar as GlobalVars
 
 
 class BaseOptimizerMixin:
@@ -1,6 +1,6 @@
 from argparse import Namespace
 
-from utils.config import Config
+from ml_lib.utils.config import Config
 
 
 class GlobalVar(Namespace):
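Note: the common thread in the hunks above is that the shared utilities now resolve under the `ml_lib` package instead of sitting at the repository root. A minimal sketch of the layout these paths imply (the directory names are an assumption; the diff itself shows only the import strings):

    # project/
    #   ml_lib/              <- shared library, e.g. vendored as a git submodule
    #     modules/util.py    <- LightningBaseModule
    #     utils/config.py    <- Config
    #   main.py
    from ml_lib.modules.util import LightningBaseModule
    from ml_lib.utils.config import Config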
@@ -1,488 +0,0 @@
-##########################
-# constants
-import argparse
-import contextlib
-import json
-from datetime import datetime
-from pathlib import Path
-from typing import Dict, Optional, Union, Any
-import numpy as np
-
-import pandas as pd
-
-import os
-
-
-# ToDo: Check this
-import shutil
-from imageio import imwrite
-from natsort import natsorted
-from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning import _logger as log
-from test_tube.log import DDPExperiment
-
-_ROOT = Path(os.path.abspath(__file__))
-
-
-# -----------------------------
-# Experiment object
-# -----------------------------
-class Experiment(object):
-
-    def __init__(self, save_dir=None, name='default', debug=False, version=None, autosave=False, description=None):
-        """
-        A new Experiment object defaults to 'default' unless a specific name is provided
-        If a known name is already provided, then the file version is changed
-        :param name:
-        :param debug:
-        """
-
-        # change where the save dir is if requested
-
-        if save_dir is not None:
-            global _ROOT
-            _ROOT = save_dir
-
-        self.save_dir = save_dir
-        self.no_save_dir = save_dir is None
-        self.metrics = []
-        self.tags = {}
-        self.name = name
-        self.debug = debug
-        self.version = version
-        self.autosave = autosave
-        self.description = description
-        self.exp_hash = '{}_v{}'.format(self.name, version)
-        self.created_at = str(datetime.utcnow())
-        self.process = os.getpid()
-
-        # when debugging don't do anything else
-        if debug:
-            return
-
-        # update version hash if we need to increase version on our own
-        # we will increase the previous version, so do it now so the hash
-        # is accurate
-        if version is None:
-            old_version = self.__get_last_experiment_version()
-            self.exp_hash = '{}_v{}'.format(self.name, old_version + 1)
-            self.version = old_version + 1
-
-        # create a new log file
-        self.__init_cache_file_if_needed()
-
-        # when we have a version, load it
-        if self.version is not None:
-
-            # when no version and no file, create it
-            if not os.path.exists(self.__get_log_name()):
-                self.__create_exp_file(self.version)
-            else:
-                # otherwise load it
-                self.__load()
-
-        else:
-            # if no version given, increase the version to a new exp
-            # create the file if not exists
-            old_version = self.__get_last_experiment_version()
-            self.version = old_version
-            self.__create_exp_file(self.version + 1)
-
-    def get_meta_copy(self):
-        """
-        Gets a meta-version only copy of this module
-        :return:
-        """
-        return DDPExperiment(self)
-
-    def on_exit(self):
-        pass
-
-    def __clean_dir(self):
-        files = os.listdir(self.save_dir)
-
-        for f in files:
-            if str(self.process) in f:
-                os.remove(os.path.join(self.save_dir, f))
-
-    def argparse(self, argparser):
-        parsed = vars(argparser)
-        to_add = {}
-
-        # don't store methods
-        for k, v in parsed.items():
-            if not callable(v):
-                to_add[k] = v
-
-        self.tag(to_add)
-
-    def add_meta_from_hyperopt(self, hypo):
-        """
-        Transfers meta data about all the params from the
-        hyperoptimizer to the log
-        :param hypo:
-        :return:
-        """
-        meta = hypo.get_current_trial_meta()
-        for tag in meta:
-            self.tag(tag)
-
-    # --------------------------------
-    # FILE IO UTILS
-    # --------------------------------
-    def __init_cache_file_if_needed(self):
-        """
-        Inits a file that we log historical experiments
-        :return:
-        """
-        try:
-            exp_cache_file = self.get_data_path(self.name, self.version)
-            if not os.path.isdir(exp_cache_file):
-                os.makedirs(exp_cache_file, exist_ok=True)
-        except FileExistsError:
-            # file already exists (likely written by another exp. In this case disable the experiment
-            self.debug = True
-
-    def __create_exp_file(self, version):
-        """
-        Recreates the old file with this exp and version
-        :param version:
-        :return:
-        """
-
-        try:
-            exp_cache_file = self.get_data_path(self.name, self.version)
-            # if no exp, then make it
-            path = exp_cache_file / 'meta.experiment'
-            path.touch(exist_ok=True)
-
-            self.version = version
-
-            # make the directory for the experiment media assets name
-            self.get_media_path(self.name, self.version).mkdir(parents=True, exist_ok=True)
-
-        except FileExistsError:
-            # file already exists (likely written by another exp. In this case disable the experiment
-            self.debug = True
-
-    def __get_last_experiment_version(self):
-
-        exp_cache_file = self.get_data_path(self.name, self.version).parent
-        last_version = -1
-
-        version = natsorted([x.name for x in exp_cache_file.iterdir() if 'version_' in x.name])[-1]
-        last_version = max(last_version, int(version.split('_')[1]))
-
-        return last_version
-
-    def __get_log_name(self):
-        return self.get_data_path(self.name, self.version) / 'meta.experiment'
-
-    def tag(self, tag_dict):
-        """
-        Adds a tag to the experiment.
-        Tags are metadata for the exp.
-
-        >> e.tag({"model": "Convnet A"})
-
-        :param tag_dict:
-        :type tag_dict: dict
-
-        :return:
-        """
-        if self.debug:
-            return
-
-        # parse tags
-        for k, v in tag_dict.items():
-            self.tags[k] = v
-
-        # save if needed
-        if self.autosave:
-            self.save()
-
-    def log(self, metrics_dict):
-        """
-        Adds a json dict of metrics.
-
-        >> e.log({"loss": 23, "coeff_a": 0.2})
-
-        :param metrics_dict:
-
-        :return:
-        """
-        if self.debug:
-            return
-
-        new_metrics_dict = metrics_dict.copy()
-        for k, v in metrics_dict.items():
-            tmp_metrics_dict = new_metrics_dict.pop(k)
-            new_metrics_dict.update(tmp_metrics_dict)
-
-        metrics_dict = new_metrics_dict
-
-        # timestamp
-        if 'created_at' not in metrics_dict:
-            metrics_dict['created_at'] = str(datetime.utcnow())
-
-        self.__convert_numpy_types(metrics_dict)
-
-        self.metrics.append(metrics_dict)
-
-        if self.autosave:
-            self.save()
-
-    @staticmethod
-    def __convert_numpy_types(metrics_dict):
-        for k, v in metrics_dict.items():
-            if v.__class__.__name__ == 'float32':
-                metrics_dict[k] = float(v)
-
-            if v.__class__.__name__ == 'float64':
-                metrics_dict[k] = float(v)
-
-    def save(self):
-        """
-        Saves current experiment progress
-        :return:
-        """
-        if self.debug:
-            return
-
-        # save images and replace the image array with the
-        # file name
-        self.__save_images(self.metrics)
-        metrics_file_path = self.get_data_path(self.name, self.version) / 'metrics.csv'
-        meta_tags_path = self.get_data_path(self.name, self.version) / 'meta_tags.csv'
-
-        obj = {
-            'name': self.name,
-            'version': self.version,
-            'tags_path': meta_tags_path,
-            'metrics_path': metrics_file_path,
-            'autosave': self.autosave,
-            'description': self.description,
-            'created_at': self.created_at,
-            'exp_hash': self.exp_hash
-        }
-
-        # save the experiment meta file
-        with atomic_write(self.__get_log_name()) as tmp_path:
-            with open(tmp_path, 'w') as file:
-                json.dump(obj, file, ensure_ascii=False)
-
-        # save the metatags file
-        df = pd.DataFrame({'key': list(self.tags.keys()), 'value': list(self.tags.values())})
-        with atomic_write(meta_tags_path) as tmp_path:
-            df.to_csv(tmp_path, index=False)
-
-        # save the metrics data
-        df = pd.DataFrame(self.metrics)
-        with atomic_write(metrics_file_path) as tmp_path:
-            df.to_csv(tmp_path, index=False)
-
-    def __save_images(self, metrics):
-        """
-        Save tags that have a png_ prefix (as images)
-        and replace the meta tag with the file name
-        :param metrics:
-        :return:
-        """
-        # iterate all metrics and find keys with a specific prefix
-        for i, metric in enumerate(metrics):
-            for k, v in metric.items():
-                # if the prefix is a png, save the image and replace the value with the path
-                img_extension = None
-                img_extension = 'png' if 'png_' in k else img_extension
-                img_extension = 'jpg' if 'jpg' in k else img_extension
-                img_extension = 'jpeg' if 'jpeg' in k else img_extension
-
-                if img_extension is not None:
-                    # determine the file name
-                    img_name = '_'.join(k.split('_')[1:])
-                    save_path = self.get_media_path(self.name, self.version)
-                    save_path = '{}/{}_{}.{}'.format(save_path, img_name, i, img_extension)
-
-                    # save image to disk
-                    if type(metric[k]) is not str:
-                        imwrite(save_path, metric[k])
-
-                    # replace the image in the metric with the file path
-                    metric[k] = save_path
-
-    def __load(self):
-        # load .experiment file
-        with open(self.__get_log_name(), 'r') as file:
-            data = json.load(file)
-            self.name = data['name']
-            self.version = data['version']
-            self.autosave = data['autosave']
-            self.created_at = data['created_at']
-            self.description = data['description']
-            self.exp_hash = data['exp_hash']
-
-        # load .tags file
-        meta_tags_path = self.get_data_path(self.name, self.version) / 'meta_tags.csv'
-        df = pd.read_csv(meta_tags_path)
-        self.tags_list = df.to_dict(orient='records')
-        self.tags = {}
-        for d in self.tags_list:
-            k, v = d['key'], d['value']
-            self.tags[k] = v
-
-        # load metrics
-        metrics_file_path = self.get_data_path(self.name, self.version) / 'metrics.csv'
-        try:
-            df = pd.read_csv(metrics_file_path)
-            self.metrics = df.to_dict(orient='records')
-
-            # remove nans and infs
-            for metric in self.metrics:
-                to_delete = []
-                for k, v in metric.items():
-                    if np.isnan(v) or np.isinf(v):
-                        to_delete.append(k)
-                for k in to_delete:
-                    del metric[k]
-
-        except Exception:
-            # metrics was empty...
-            self.metrics = []
-
-    def get_data_path(self, exp_name, exp_version):
-        """
-        Returns the path to the local package cache
-        :param exp_name:
-        :param exp_version:
-        :return:
-        Path
-        """
-        if self.no_save_dir:
-            return _ROOT / 'local_experiment_data' / exp_name, f'version_{exp_version}'
-        else:
-            return _ROOT / exp_name / f'version_{exp_version}'
-
-    def get_media_path(self, exp_name, exp_version):
-        """
-        Returns the path to the local package cache
-        :param exp_version:
-        :param exp_name:
-        :return:
-        """
-
-        return self.get_data_path(exp_name, exp_version) / 'media'
-
-    # ----------------------------
-    # OVERWRITES
-    # ----------------------------
-    def __str__(self):
-        return 'Exp: {}, v: {}'.format(self.name, self.version)
-
-    def __hash__(self):
-        return 'Exp: {}, v: {}'.format(self.name, self.version)
-
-
-@contextlib.contextmanager
-def atomic_write(dst_path):
-    """A context manager to simplify atomic writing.
-
-    Usage:
-    >>> with atomic_write(dst_path) as tmp_path:
-    >>>     # write to tmp_path
-    >>> # Here tmp_path renamed to dst_path, if no exception happened.
-    """
-    tmp_path = dst_path / '.tmp'
-    try:
-        yield tmp_path
-    except:
-        if tmp_path.exists():
-            tmp_path.unlink()
-        raise
-    else:
-        # If everything is fine, move tmp file to the destination.
-        shutil.move(tmp_path, str(dst_path))
-
-
-##########################
-class LocalLogger(LightningLoggerBase):
-
-    @property
-    def name(self) -> str:
-        return self._name
-
-    @property
-    def experiment(self) -> Experiment:
-        r"""
-
-        Actual TestTube object. To use TestTube features in your
-        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
-
-        Example::
-
-            self.logger.experiment.some_test_tube_function()
-
-        """
-        if self._experiment is not None:
-            return self._experiment
-
-        self._experiment = Experiment(
-            save_dir=self.save_dir,
-            name=self._name,
-            debug=self.debug,
-            version=self.version,
-            description=self.description
-        )
-        return self._experiment
-
-    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
-        pass
-
-    def log_hyperparams(self, params: argparse.Namespace):
-        pass
-
-    @property
-    def version(self) -> Union[int, str]:
-        if self._version is None:
-            self._version = self._get_next_version()
-        return self._version
-
-    def _get_next_version(self):
-        root_dir = self.save_dir / self.name
-
-        if not root_dir.is_dir():
-            log.warning(f'Missing logger folder: {root_dir}')
-            return 0
-
-        existing_versions = []
-        for d in os.listdir(root_dir):
-            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
-                existing_versions.append(int(d.split("_")[1]))
-
-        if len(existing_versions) == 0:
-            return 0
-
-        return max(existing_versions) + 1
-
-    def __init__(self, save_dir: str, name: str = "default", description: Optional[str] = None,
-                 debug: bool = False, version: Optional[int] = None, **kwargs):
-        super(LocalLogger, self).__init__(**kwargs)
-        self.save_dir = Path(save_dir)
-        self._name = name
-        self.description = description
-        self.debug = debug
-        self._version = version
-        self._experiment = None
-
-    # Test tube experiments are not pickleable, so we need to override a few
-    # methods to get DDP working. See
-    # https://docs.python.org/3/library/pickle.html#handling-stateful-objects
-    # for more info.
-    def __getstate__(self) -> Dict[Any, Any]:
-        state = self.__dict__.copy()
-        state["_experiment"] = self.experiment.get_meta_copy()
-        return state
-
-    def __setstate__(self, state: Dict[Any, Any]):
-        self._experiment = state["_experiment"].get_non_ddp_exp()
-        del state["_experiment"]
-        self.__dict__.update(state)
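Note: the deleted file shipped its own `atomic_write` helper that yielded `dst_path / '.tmp'` and moved it onto the destination afterwards; building the temp path with `/` treats `dst_path` as a directory, so the helper worked more by accident of `shutil.move` than by design. A more conventional sketch of the same write-then-rename idea (my wording, not part of the commit; needs Python 3.8+ for `missing_ok`):

    import contextlib
    import os
    import tempfile
    from pathlib import Path

    @contextlib.contextmanager
    def atomic_write(dst_path: Path):
        # Write to a sibling temp file, then atomically swap it into place.
        fd, tmp_name = tempfile.mkstemp(dir=dst_path.parent)
        os.close(fd)
        tmp_path = Path(tmp_name)
        try:
            yield tmp_path
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise
        else:
            os.replace(tmp_path, dst_path)  # atomic on POSIX and Windows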
@@ -130,8 +130,9 @@ class DeConvModule(ShapeMixin, nn.Module):
     def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
                  dropout: Union[int, float] = 0, autopad=0,
                  activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
-                 bias=True, norm=False):
+                 bias=True, norm=False, **kwargs):
         super(DeConvModule, self).__init__()
+        warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
         self.padding = conv_padding
         self.conv_kernel = conv_kernel
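Note: accepting `**kwargs` and warning about the leftovers lets callers hand one shared parameter dict to several module types without tripping over unknown keys; the cost is that misspelled arguments surface only as a runtime warning. The pattern in isolation (module name is hypothetical):

    import warnings
    from torch import nn

    class ExampleModule(nn.Module):
        def __init__(self, in_features, out_features, **kwargs):
            super().__init__()
            if kwargs:  # report, but do not fail on, unused arguments
                warnings.warn(f'The following arguments have been ignored: \n {list(kwargs.keys())}')
            self.layer = nn.Linear(in_features, out_features)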
@@ -1,8 +1,10 @@
 import torch
 from torch import nn
 from torch.nn import ReLU
-
-from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
+try:
+    from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
+except ImportError:
+    print('Install torch-geometric to use this package.')
 
 
 class SAModule(torch.nn.Module):
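Note: guarding the `torch_geometric` import makes it an optional dependency, but the `except` branch only prints, so `SAModule` will still fail with a `NameError` once `PointConv` is touched. A variant that fails with an explicit message at the point of use (an alternative sketch, not what the commit does):

    try:
        from torch_geometric.nn import PointConv
    except ImportError:
        PointConv = None

    def require_torch_geometric():
        if PointConv is None:
            raise ImportError('Install torch-geometric to use this package.')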
@@ -1,10 +1,77 @@
 #
 # Full Model Parts
 ###################
-import torch
-from torch import nn
+from argparse import Namespace
+from typing import Union, List, Tuple
 
-from .util import ShapeMixin
+import torch
+from abc import ABC
+from torch import nn
+from torch.utils.data import DataLoader
+
+from .util import ShapeMixin, LightningBaseModule
+
+
+class AEBaseModule(LightningBaseModule, ABC):
+
+    def generate_random_image(self, dataloader: Union[None, str, DataLoader] = None,
+                              lat_min: Union[Tuple, List, None] = None,
+                              lat_max: Union[Tuple, List, None] = None):
+
+        assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
+
+        min_max = self._find_min_max(dataloader) if dataloader else [None, None]
+        # assert not any([x is None for x in min_max])
+        lat_min = torch.as_tensor(lat_min or min_max[0])
+        lat_max = lat_max or min_max[1]
+
+        random_z = torch.rand((1, self.lat_dim))
+        random_z = random_z * (abs(lat_min) + lat_max) - abs(lat_min)
+
+        return self.decoder(random_z).squeeze()
+
+    def encode(self, x):
+        if len(x.shape) == 3:
+            x = x.unsqueeze(0)
+        return self.encoder(x).squeeze()
+
+    def _find_min_max(self, dataloader):
+        encodings = list()
+        for batch in dataloader:
+            encodings.append(self.encode(batch))
+        encodings = torch.cat(encodings, dim=0)
+        min_lat = encodings.min(dim=1)
+        max_lat = encodings.max(dim=1)
+        return min_lat, max_lat
+
+    def decode_lat_evenly(self, n: int,
+                          dataloader: Union[None, str, DataLoader] = None,
+                          lat_min: Union[Tuple, List, None] = None,
+                          lat_max: Union[Tuple, List, None] = None):
+        assert bool(dataloader) ^ bool(lat_min and lat_max), 'Decide wether to give min, max or a dataloader, not both.'
+
+        min_max = self._find_min_max(dataloader) if dataloader else [None, None]
+
+        lat_min = lat_min or min_max[0]
+        lat_max = lat_max or min_max[1]
+
+        random_latent_samples = torch.stack([torch.linspace(lat_min[i].item(), lat_max[i].item(), n)
+                                             for i in range(self.params.lat_dim)], dim=-1).cpu().detach()
+        return self.decode(random_latent_samples).cpu().detach()
+
+    def decode(self, z):
+        if len(z.shape) == 1:
+            z = z.unsqueeze(0)
+        return self.decoder(z).squeeze()
+
+    def encode_and_restore(self, x):
+        x = x.to(self.device)
+        if len(x.shape) == 3:
+            x = x.unsqueeze(0)
+        z = self.encode(x)
+        x_hat = self.decode(z)
+
+        return Namespace(main_out=x_hat.squeeze(), latent_out=z)
 
 
 class Generator(nn.Module):
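Note: `AEBaseModule` brings only the encode/decode plumbing; it assumes a concrete subclass provides `self.encoder`, `self.decoder` and `self.lat_dim`. A hedged usage sketch (the subclass below is invented for illustration; the `LightningBaseModule` constructor arguments are elided):

    class TinyAE(AEBaseModule):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.lat_dim = 16
            self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, self.lat_dim))
            self.decoder = nn.Sequential(nn.Linear(self.lat_dim, 28 * 28))

    # out = TinyAE(...).encode_and_restore(torch.rand(1, 1, 28, 28))
    # out.main_out is the reconstruction, out.latent_out the latent code.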
@@ -16,9 +83,12 @@ class Generator(nn.Module):
 
     # noinspection PyUnresolvedReferences
     def __init__(self, out_channels, re_shape, lat_dim, use_norm=False, use_bias=True, dropout: Union[int, float] = 0,
-                 filters: List[int] = None, activation=nn.ReLU):
+                 filters: List[int] = None, kernels: List[int] = None, activation=nn.ReLU, **kwargs):
         super(Generator, self).__init__()
-        assert filters, '"Filters" has to be a list of int len 3'
+        assert filters, '"Filters" has to be a list of int.'
+        assert kernels, '"Kernels" has to be a list of int.'
+        assert len(filters) == len(kernels), '"Filters" and "Kernels" has to be of same length.'
 
         self.filters = filters
         self.activation = activation
         self.inner_activation = activation()
@@ -29,52 +99,35 @@ class Generator(nn.Module):
         # re_shape = (self.feature_mixed_dim // reduce(mul, re_shape[1:]), ) + tuple(re_shape[1:])
 
         self.flat = Flatten(to=re_shape)
+        self.de_conv_list = nn.ModuleList()
 
-        self.deconv1 = DeConvModule(re_shape, conv_filters=self.filters[0],
-                                    conv_kernel=5,
-                                    conv_padding=2,
-                                    conv_stride=1,
-                                    normalize=use_norm,
-                                    activation=self.activation,
-                                    interpolation_scale=2,
-                                    dropout=self.dropout
-                                    )
-
-        self.deconv2 = DeConvModule(self.deconv1.shape, conv_filters=self.filters[1],
-                                    conv_kernel=3,
-                                    conv_padding=1,
-                                    conv_stride=1,
-                                    normalize=use_norm,
-                                    activation=self.activation,
-                                    interpolation_scale=2,
-                                    dropout=self.dropout
-                                    )
-
-        self.deconv3 = DeConvModule(self.deconv2.shape, conv_filters=self.filters[2],
-                                    conv_kernel=3,
-                                    conv_padding=1,
-                                    conv_stride=1,
-                                    normalize=use_norm,
-                                    activation=self.activation,
-                                    interpolation_scale=2,
-                                    dropout=self.dropout
-                                    )
-
-        self.deconv4 = DeConvModule(self.deconv3.shape, conv_filters=out_channels,
-                                    conv_kernel=3,
-                                    conv_padding=1,
-                                    # normalize=norm,
-                                    activation=self.out_activation
-                                    )
+        last_shape = re_shape
+        for conv_filter, conv_kernel in zip(filters, kernels):
+            self.de_conv_list.append(DeConvModule(last_shape, conv_filters=self.filters[0],
+                                                  conv_kernel=conv_kernel,
+                                                  conv_padding=conv_kernel-2,
+                                                  conv_stride=conv_filter,
+                                                  normalize=use_norm,
+                                                  activation=self.activation,
+                                                  interpolation_scale=2,
+                                                  dropout=self.dropout
+                                                  )
+                                     )
+            last_shape = self.de_conv_list[-1].shape
+
+        self.de_conv_out = DeConvModule(self.de_conv_list[-1].shape, conv_filters=out_channels, conv_kernel=3,
+                                        conv_padding=1, activation=self.out_activation
+                                        )
 
     def forward(self, z):
         tensor = self.l1(z)
         tensor = self.inner_activation(tensor)
         tensor = self.flat(tensor)
-        tensor = self.deconv1(tensor)
-        tensor = self.deconv2(tensor)
-        tensor = self.deconv3(tensor)
-        tensor = self.deconv4(tensor)
+
+        for de_conv in self.de_conv_list:
+            tensor = de_conv(tensor)
+
+        tensor = self.de_conv_out(tensor)
         return tensor
 
     def size(self):
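Note: the refactor swaps four hand-written `deconv1..deconv4` blocks for a `nn.ModuleList` built in a loop, so network depth now follows `len(filters)`. As committed, the loop still passes `conv_filters=self.filters[0]` and feeds `conv_filter` into `conv_stride`, which reads like a leftover from the manual version; a sketch of the loop body as the `zip(filters, kernels)` suggests it was intended (my reading, not the committed code):

    last_shape = re_shape
    for conv_filter, conv_kernel in zip(filters, kernels):
        self.de_conv_list.append(DeConvModule(last_shape,
                                              conv_filters=conv_filter,  # per-layer filter count
                                              conv_kernel=conv_kernel,
                                              conv_padding=conv_kernel - 2,
                                              conv_stride=1,             # stride unchanged from the old blocks
                                              normalize=use_norm,
                                              activation=self.activation,
                                              interpolation_scale=2,
                                              dropout=self.dropout))
        last_shape = self.de_conv_list[-1].shape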
@@ -119,12 +172,14 @@ class BaseEncoder(ShapeMixin, nn.Module):
     # noinspection PyUnresolvedReferences
     def __init__(self, in_shape, lat_dim=256, use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
                  latent_activation: Union[nn.Module, None] = None, activation: nn.Module = nn.ELU,
-                 filters: List[int] = None):
+                 filters: List[int] = None, kernels: List[int] = None, **kwargs):
         super(BaseEncoder, self).__init__()
-        assert filters, '"Filters" has to be a list of int len 3'
+        assert filters, '"Filters" has to be a list of int'
+        assert kernels, '"Kernels" has to be a list of int'
+        assert len(kernels) == len(filters), 'Length of "Filters" and "Kernels" has to be same.'
 
         # Optional Padding for odd image-sizes
-        # Obsolet, already Done by autopadding module on incoming tensors
+        # Obsolet, can be done by autopadding module on incoming tensors
         # in_shape = [x+1 if x % 2 != 0 and idx else x for idx, x in enumerate(in_shape)]
 
         # Parameters
@@ -133,43 +188,29 @@ class BaseEncoder(ShapeMixin, nn.Module):
         self.use_bias = use_bias
         self.latent_activation = latent_activation() if latent_activation else None
 
+        self.conv_list = nn.ModuleList()
+
         # Modules
-        self.conv1 = ConvModule(self.in_shape, conv_filters=filters[0],
-                                conv_kernel=3,
-                                conv_padding=1,
-                                conv_stride=1,
-                                pooling_size=2,
-                                use_norm=use_norm,
-                                dropout=dropout,
-                                activation=activation
-                                )
-
-        self.conv2 = ConvModule(self.conv1.shape, conv_filters=filters[1],
-                                conv_kernel=3,
-                                conv_padding=1,
-                                conv_stride=1,
-                                pooling_size=2,
-                                use_norm=use_norm,
-                                dropout=dropout,
-                                activation=activation
-                                )
-
-        self.conv3 = ConvModule(self.conv2.shape, conv_filters=filters[2],
-                                conv_kernel=5,
-                                conv_padding=2,
-                                conv_stride=1,
-                                pooling_size=2,
-                                use_norm=use_norm,
-                                dropout=dropout,
-                                activation=activation
-                                )
+        last_shape = self.in_shape
+        for conv_filter, conv_kernel in zip(filters, kernels):
+            self.conv_list.append(ConvModule(last_shape, conv_filters=conv_filter,
+                                             conv_kernel=conv_kernel,
+                                             conv_padding=conv_kernel-2,
+                                             conv_stride=1,
+                                             pooling_size=2,
+                                             use_norm=use_norm,
+                                             dropout=dropout,
+                                             activation=activation
+                                             )
+                                  )
+            last_shape = self.conv_list[-1].shape
 
         self.flat = Flatten()
 
     def forward(self, x):
-        tensor = self.conv1(x)
-        tensor = self.conv2(tensor)
-        tensor = self.conv3(tensor)
+        tensor = x
+        for conv in self.conv_list:
+            tensor = conv(tensor)
         tensor = self.flat(tensor)
         return tensor
 
@@ -1,7 +1,10 @@
+from functools import reduce
+
 from abc import ABC
 from pathlib import Path
 
 import torch
+from operator import mul
 from torch import nn
 from torch import functional as F
 
@@ -102,6 +105,14 @@ class ShapeMixin:
         else:
             return -1
 
+    @property
+    def flat_shape(self):
+        shape = self.shape
+        try:
+            return reduce(mul, shape)
+        except TypeError:
+            return shape
+
 
 class F_x(ShapeMixin, nn.Module):
     def __init__(self, in_shape):
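Note: `flat_shape` collapses a shape tuple into its element count with `reduce(mul, ...)` and falls back to the raw value when `shape` is a bare int (such as `-1`), because `reduce` then raises `TypeError`:

    from functools import reduce
    from operator import mul

    print(reduce(mul, (3, 32, 32)))  # 3072, the flattened feature count
    # reduce(mul, -1) raises TypeError, so flat_shape returns -1 unchanged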
@@ -175,7 +186,7 @@ class WeightInit:
             m.bias.data.fill_(0.01)
 
 
-class Filter(nn.Module):
+class Filter(nn.Module, ShapeMixin):
 
     def __init__(self, in_shape, pos, dim=-1):
         super(Filter, self).__init__()
@@ -210,11 +221,15 @@ class AutoPadToShape(object):
     def __call__(self, x):
         if not torch.is_tensor(x):
             x = torch.as_tensor(x)
-        if x.shape[1:] == self.shape:
+        if x.shape[1:] == self.shape or x.shape == self.shape:
             return x
-        embedding = torch.zeros((x.shape[0], *self.shape))
-        embedding[:, :x.shape[1], :x.shape[2], :x.shape[3]] = x
-        return embedding
+
+        for i in range(-1, -len(self.shape), -1):
+            idx = [0] * len(x.shape)
+            idx[i] = self.shape[i] - x.shape[i]
+            idx = tuple(idx)
+            x = torch.nn.functional.pad(x, idx)
+        return x
 
     def __repr__(self):
         return f'AutoPadTransform({self.shape})'
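Note: the rewrite pads dimension by dimension with `torch.nn.functional.pad` instead of allocating a zero tensor and copying into it. `F.pad` reads its pad sequence as (left, right) pairs starting from the last dimension, so the index arithmetic above is sensitive to tensor rank. The semantics in miniature:

    import torch
    import torch.nn.functional as F

    x = torch.ones(2, 3)
    y = F.pad(x, (0, 1, 0, 0))  # (left, right) for dim -1, then (left, right) for dim -2
    print(y.shape)              # torch.Size([2, 4])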
@@ -233,9 +248,9 @@ class Splitter(nn.Module):
     def __init__(self, in_shape, n, dim=-1):
         super(Splitter, self).__init__()
 
-        self.n = n
-        self.dim = dim
         self.in_shape = in_shape
+        self.n = n
+        self.dim = dim if dim > 0 else len(self.in_shape) - abs(dim)
 
         self.new_dim_size = (self.in_shape[self.dim] // self.n) + (1 if self.in_shape[self.dim] % self.n != 0 else 0)
         self._out_shape = tuple([x if self.dim != i else self.new_dim_size for i, x in enumerate(self.in_shape)])
@@ -243,22 +258,23 @@ class Splitter(nn.Module):
         self.autopad = AutoPadToShape(self._out_shape)
 
     def forward(self, x: torch.Tensor):
-        x = x.transpose(0, self.dim)
+        dim = self.dim + 1 if len(self.in_shape) == (x.ndim -1) else self.dim
+        x = x.transpose(0, dim)
         n_blocks = list()
         for block_idx in range(self.n):
             start = block_idx * self.new_dim_size
             end = (block_idx + 1) * self.new_dim_size
-            block = self.autopad(x[:, :, start:end, :])
-            n_blocks.append(block.transpose(0, self.dim))
+            block = x[start:end].transpose(0, dim)
+            block = self.autopad(block)
+            n_blocks.append(block)
         return n_blocks
 
 
-class Merger(nn.Module):
+class Merger(nn.Module, ShapeMixin):
 
     @property
     def shape(self):
-        y = self.forward([torch.randn(self.in_shape)])
+        y = self.forward([torch.randn(self.in_shape) for _ in range(self.n)])
         return y.shape
 
     def __init__(self, in_shape, n, dim=-1):
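Note: `Splitter.forward` now transposes the split dimension to the front, slices `n` chunks of `new_dim_size`, transposes each back, and auto-pads the last chunk when the dimension is not evenly divisible. A usage sketch under an evenly divisible split, so the autopad branch stays idle (the shapes are my example, not from the commit):

    splitter = Splitter(in_shape=(4, 9), n=3, dim=-1)  # last dim 9 -> 3 blocks of 3
    blocks = splitter.forward(torch.randn(8, 4, 9))    # batched input, dim is shifted by one
    print(len(blocks), blocks[0].shape)                # 3 torch.Size([8, 4, 3])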
@@ -3,7 +3,8 @@ from pathlib import Path
 
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.loggers.neptune import NeptuneLogger
-from pytorch_lightning.loggers.test_tube import TestTubeLogger
+# noinspection PyUnresolvedReferences
+from pytorch_lightning.loggers.csv_logs import CSVLogger
 
 from .config import Config
 
@@ -15,13 +16,13 @@ class Logger(LightningLoggerBase, ABC):
     @property
     def experiment(self):
         if self.debug:
-            return self.testtubelogger.experiment
+            return self.csvlogger.experiment
         else:
             return self.neptunelogger.experiment
 
     @property
     def log_dir(self):
-        return Path(self.testtubelogger.experiment.get_logdir()).parent
+        return Path(self.csvlogger.experiment.log_dir)
 
     @property
     def name(self):
@@ -64,55 +65,56 @@ class Logger(LightningLoggerBase, ABC):
             self.config.set('project', 'owner', 'testuser')
             self.config.set('project', 'name', 'test')
             self.config.set('project', 'neptune_key', 'XXX')
-        self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
+        self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
         self._neptune_kwargs = dict(offline_mode=self.debug,
                                     api_key=self.config.project.neptune_key,
                                     experiment_name=self.name,
                                     project_name=self.project_name,
                                     params=self.config.model_paramters)
         self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
-        self.testtubelogger = TestTubeLogger(**self._testtube_kwargs)
+        self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
         self.log_config_as_ini()
 
     def log_hyperparams(self, params):
         self.neptunelogger.log_hyperparams(params)
-        self.testtubelogger.log_hyperparams(params)
+        self.csvlogger.log_hyperparams(params)
         pass
 
     def log_metrics(self, metrics, step=None):
         self.neptunelogger.log_metrics(metrics, step=step)
-        self.testtubelogger.log_metrics(metrics, step=step)
+        self.csvlogger.log_metrics(metrics, step=step)
         pass
 
     def close(self):
-        self.testtubelogger.close()
+        self.csvlogger.close()
         self.neptunelogger.close()
 
     def log_config_as_ini(self):
        self.config.write(self.log_dir / 'config.ini')
 
-    def log_text(self, name, text, step_nb=0, **kwargs):
+    def log_text(self, name, text, step_nb=0, **_):
         # TODO Implement Offline variant.
         self.neptunelogger.log_text(name, text, step_nb)
 
     def log_metric(self, metric_name, metric_value, **kwargs):
-        self.testtubelogger.log_metrics(dict(metric_name=metric_value))
+        self.csvlogger.log_metrics(dict(metric_name=metric_value))
         self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
 
     def log_image(self, name, image, ext='png', **kwargs):
-        self.neptunelogger.log_image(name, image, **kwargs)
         step = kwargs.get('step', None)
-        name = f'{step}_{name}' if step is not None else name
-        name = f'{name}.{ext[1:] if ext.startswith(".") else ext}'
+        image_name = f'{step}_{name}' if step is not None else name
+        image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
         (self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
-        image.savefig(self.log_dir / self.media_dir / name)
+        image.savefig(image_path, bbox_inches='tight', pad_inches=0)
+        self.neptunelogger.log_image(name, str(image_path), **kwargs)
 
     def save(self):
-        self.testtubelogger.save()
+        self.csvlogger.save()
         self.neptunelogger.save()
 
     def finalize(self, status):
-        self.testtubelogger.finalize(status)
+        self.csvlogger.finalize(status)
         self.neptunelogger.finalize(status)
 
     def __enter__(self):
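Note: the local-file half of the combined logger moves from the removed `TestTubeLogger` to Lightning's `CSVLogger` (the Lightning 0.9 change in the commit title); the Neptune half is untouched, and images are now written to disk first with the file path handed to Neptune. `CSVLogger` keeps the same save_dir/name/version constructor shape:

    from pytorch_lightning.loggers.csv_logs import CSVLogger

    csvlogger = CSVLogger(save_dir='output', name='test', version=0)
    csvlogger.log_metrics({'loss': 0.5}, step=1)
    csvlogger.save()  # writes output/test/version_0/metrics.csv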
@@ -20,7 +20,7 @@ class ModelParameters(Namespace, Mapping):
 
         paramter_mapping.update(
             dict(
-                activation=self._activations[self['activation']]
+                activation=self.__getattribute__('activation')
             )
         )
 
@@ -44,7 +44,7 @@ class ModelParameters(Namespace, Mapping):
 
     def __getattribute__(self, name):
         if name == 'activation':
-            return self._activations[self['activation']]
+            return self._activations[self['activation'].lower()]
         else:
             try:
                 return super(ModelParameters, self).__getattribute__(name)
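Note: with `.lower()` added, the stored string is normalized before the `_activations` lookup, so config values like 'ReLU' and 'relu' both resolve to the same class:

    from torch import nn

    _activations = dict(leaky_relu=nn.LeakyReLU, elu=nn.ELU, relu=nn.ReLU,
                        sigmoid=nn.Sigmoid, tanh=nn.Tanh)
    print(_activations['ReLU'.lower()])  # <class 'torch.nn.modules.activation.ReLU'>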
@@ -56,6 +56,7 @@ class ModelParameters(Namespace, Mapping):
 
     _activations = dict(
         leaky_relu=nn.LeakyReLU,
+        elu=nn.ELU,
         relu=nn.ReLU,
         sigmoid=nn.Sigmoid,
         tanh=nn.Tanh
@@ -1,5 +1,5 @@
 try:
-    import matplotlib.pyplot as plt
+    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 except ImportError:  # pragma: no-cover
     raise ImportError('You want to use `matplotlib` plugins which are not installed yet,'  # pragma: no-cover
                       ' install it with `pip install matplotlib`.')
@@ -8,30 +8,23 @@ from pathlib import Path
 
 
 class Plotter(object):
 
     def __init__(self, root_path=''):
         if not root_path:
             self.root_path = Path(root_path)
 
-    def save_current_figure(self, filename: str, extention='.png', naked=False):
-        fig, _ = plt.gcf(), plt.gca()
+    def save_figure(self, figure, title, extention='.png', naked=False):
+        canvas = FigureCanvas(figure)
         # Prepare save location and check img file extention
-        path = self.root_path / Path(filename if filename.endswith(extention) else f'{filename}{extention}')
+        path = self.root_path / f'{title}{extention}'
         path.parent.mkdir(exist_ok=True, parents=True)
         if naked:
-            plt.axis('off')
-            fig.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0)
-            fig.clf()
+            figure.axis('off')
+            figure.savefig(path, bbox_inches='tight', transparent=True, pad_inches=0)
+            canvas.print_figure(path)
         else:
-            fig.savefig(path)
-            fig.clf()
-
-    def show_current_figure(self):
-        fig, _ = plt.gcf(), plt.gca()
-        fig.show()
-        fig.clf()
+            canvas.print_figure(path)
 
 
 if __name__ == '__main__':
-    output_root = Path('..') / 'output'
-    p = Plotter(output_root)
-    p.save_current_figure('test.png')
+    raise PermissionError('Get out of here.')
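Note: the plotter now renders through an Agg canvas bound to an explicit `Figure` rather than the global `pyplot` state, which is the safer choice for headless training jobs. One caveat: `Figure` has no `axis` method, so the `naked` branch presumably means to turn axes off per subplot. The canvas pattern in isolation (a hedged sketch, not the committed code):

    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.add_subplot(111)
    ax.plot([0, 1], [0, 1])
    ax.axis('off')  # per-axes, which is what 'naked' likely intends
    FigureCanvas(fig).print_figure('plot.png', bbox_inches='tight')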