bringing branches up to date

This commit is contained in:
Steffen Illium
2021-02-15 11:39:54 +01:00
parent 010176e80b
commit a966321576
11 changed files with 216 additions and 197 deletions

29
utils/_basedatamodule.py Normal file
View File

@@ -0,0 +1,29 @@
from pytorch_lightning import LightningDataModule
# Dataset Options
from ml_lib.utils.tools import add_argparse_args
DATA_OPTION_test = 'test'
DATA_OPTION_devel = 'devel'
DATA_OPTION_train = 'train'
DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]
class _BaseDataModule(LightningDataModule):
    """Common base for the project's data modules.

    Keeps one dataset per split in ``self.datasets`` (keyed by the
    ``DATA_OPTION_*`` constants) and wires the class into the shared
    argparse helper.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # split name -> dataset; subclasses are expected to fill this in
        self.datasets = dict()

    @classmethod
    def add_argparse_args(cls, parent_parser):
        """Expose this module's ``__init__`` arguments as CLI options."""
        return add_argparse_args(cls, parent_parser)

    @property
    def shape(self):
        """Sample shape as reported by the *training* dataset."""
        return self.datasets[DATA_OPTION_train].sample_shape

    def transfer_batch_to_device(self, batch, device):
        # assumes the batch object exposes a tensor-like .to(device) — TODO confirm
        return batch.to(device)

View File

@@ -1,19 +1,34 @@
from abc import ABC
import inspect
from argparse import ArgumentParser
from pathlib import Path
import os
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.loggers.neptune import NeptuneLogger
from neptune.api_exceptions import ProjectNotFound
# noinspection PyUnresolvedReferences
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.utilities import argparse_utils
from .config import Config
from ml_lib.utils.tools import add_argparse_args
class Logger(LightningLoggerBase, ABC):
class Logger(LightningLoggerBase):
@classmethod
def from_argparse_args(cls, args, **kwargs):
return argparse_utils.from_argparse_args(cls, args, **kwargs)
@property
def name(self) -> str:
return self._name
media_dir = 'media'
@classmethod
def add_argparse_args(cls, parent_parser):
return add_argparse_args(cls, parent_parser)
@property
def experiment(self):
if self.debug:
@@ -25,27 +40,23 @@ class Logger(LightningLoggerBase, ABC):
def log_dir(self):
return Path(self.csvlogger.experiment.log_dir)
@property
def name(self):
return self.config.name
@property
def project_name(self):
return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"
return f"{self.owner}/{self.name.replace('_', '-')}"
@property
def version(self):
return self.config.get('main', 'seed')
return self.seed
@property
def save_dir(self):
return self.log_dir
@property
def outpath(self):
return Path(self.config.train.outpath) / self.config.model.type
return Path(self.root_out) / self.model_name
@property
def exp_path(self):
return Path(self.outpath) / self.name
def __init__(self, config: Config):
def __init__(self, owner, neptune_key, model_name, project_name='', outpath='output', seed=69, debug=False):
"""
params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
Parameters are displayed in the experiments Parameters section and each key-value pair can be
@@ -59,19 +70,19 @@ class Logger(LightningLoggerBase, ABC):
"""
super(Logger, self).__init__()
self.config = config
self.debug = self.config.main.debug
if self.debug:
self.config.add_section('project')
self.config.set('project', 'owner', 'testuser')
self.config.set('project', 'name', 'test')
self.config.set('project', 'neptune_key', 'XXX')
self.debug = debug
self._name = project_name or Path(os.getcwd()).name if not self.debug else 'test'
self.owner = owner if not self.debug else 'testuser'
self.neptune_key = neptune_key if not self.debug else 'XXX'
self.root_out = outpath if not self.debug else 'debug_out'
self.seed = seed
self.model_name = model_name
self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
self._neptune_kwargs = dict(offline_mode=self.debug,
api_key=self.config.project.neptune_key,
api_key=self.neptune_key,
experiment_name=self.name,
project_name=self.project_name,
params=self.config.model_paramters)
project_name=self.project_name)
try:
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
except ProjectNotFound as e:
@@ -79,7 +90,6 @@ class Logger(LightningLoggerBase, ABC):
print(e)
self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
self.log_config_as_ini()
def log_hyperparams(self, params):
self.neptunelogger.log_hyperparams(params)
@@ -95,19 +105,15 @@ class Logger(LightningLoggerBase, ABC):
self.csvlogger.close()
self.neptunelogger.close()
def log_config_as_ini(self):
self.config.write(self.log_dir / 'config.ini')
def log_text(self, name, text, step_nb=0, **_):
def log_text(self, name, text, **_):
# TODO Implement Offline variant.
self.neptunelogger.log_text(name, text, step_nb)
self.neptunelogger.log_text(name, text)
def log_metric(self, metric_name, metric_value, **kwargs):
self.csvlogger.log_metrics(dict(metric_name=metric_value))
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
def log_image(self, name, image, ext='png', **kwargs):
step = kwargs.get('step', None)
image_name = f'{step}_{name}' if step is not None else name
image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'

View File

@@ -13,6 +13,10 @@ from torch import nn
# Hyperparamter Object
class ModelParameters(Namespace, Mapping):
@property
def as_dict(self):
return {key: self.get(key) if key != 'activation' else self.activation_as_string for key in self.keys()}
@property
def activation_as_string(self):
return self['activation'].lower()
@@ -50,13 +54,7 @@ class ModelParameters(Namespace, Mapping):
if name == 'activation':
return self._activations[self['activation'].lower()]
else:
try:
return super(ModelParameters, self).__getattribute__(name)
except AttributeError as e:
if name == 'stretch':
return False
else:
return None
return super(ModelParameters, self).__getattribute__(name)
_activations = dict(
leaky_relu=nn.LeakyReLU,
@@ -88,16 +86,20 @@ class SavedLightningModels(object):
model = torch.load(models_root_path / 'model_class.obj')
assert model is not None
return cls(weights=str(checkpoint_path), model=model)
return cls(weights=checkpoint_path, model=model)
def __init__(self, **kwargs):
self.weights: str = kwargs.get('weights', '')
self.weights: Path = Path(kwargs.get('weights', ''))
self.hparams: Path = self.weights.parent / 'hparams.yaml'
self.model = kwargs.get('model', None)
assert self.model is not None
def restore(self):
pretrained_model = self.model.load_from_checkpoint(self.weights)
pretrained_model = self.model.load_from_checkpoint(self.weights.__str__())
# , hparams_file=self.hparams.__str__())
pretrained_model.eval()
pretrained_model.freeze()
return pretrained_model
return pretrained_model

View File

@@ -1,6 +1,9 @@
import importlib
import inspect
import pickle
import shelve
from argparse import ArgumentParser
from ast import literal_eval
from pathlib import Path, PurePath
from typing import Union
@@ -9,6 +12,13 @@ import torch
import random
def auto_cast(a):
    """Best-effort cast of a string to a Python literal.

    Returns ``literal_eval(a)`` when *a* parses as a Python literal
    (int, float, bool, None, tuple/list/dict/set of literals);
    otherwise returns *a* unchanged.
    """
    try:
        return literal_eval(a)
    # literal_eval raises ValueError/SyntaxError on non-literal strings and
    # TypeError on non-string input; a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit, so catch only these.
    except (ValueError, SyntaxError, TypeError):
        return a
def to_one_hot(idx_array, max_classes):
one_hot = np.zeros((idx_array.size, max_classes))
one_hot[np.arange(idx_array.size), idx_array] = 1
@@ -54,3 +64,20 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
continue
raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
def add_argparse_args(cls, parent_parser):
    """Build an ArgumentParser exposing ``cls.__init__``'s arguments.

    Required (non-default) arguments are added as ``--name`` options of
    type int; defaulted arguments are added with their default value and
    a ``type`` inferred from that default.

    :param cls: class whose ``__init__`` signature is inspected.
    :param parent_parser: parser to inherit options from.
    :return: a new ArgumentParser including the parent's options.
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    full_arg_spec = inspect.getfullargspec(cls.__init__)
    # args without a default come first in full_arg_spec.args
    n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
    for idx, argument in enumerate(full_arg_spec.args):
        if argument == 'self':
            continue
        if idx < n_non_defaults:
            parser.add_argument(f'--{argument}', type=int)
        else:
            default = full_arg_spec.defaults[idx - n_non_defaults]
            # BUGFIX: the original used type(argument) — the type of the
            # parameter-NAME string, i.e. always str. Infer from the default
            # VALUE instead; fall back to str for None defaults (NoneType is
            # not callable as a converter).
            # NOTE(review): bool defaults still get type=bool, where any
            # non-empty CLI string (incl. 'False') is truthy — pre-existing
            # argparse pitfall, left as-is.
            argument_type = str if default is None else type(default)
            parser.add_argument(f'--{argument}',
                                type=argument_type,
                                default=default
                                )
    return parser