bringing brances up to date
This commit is contained in:
29
utils/_basedatamodule.py
Normal file
29
utils/_basedatamodule.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from pytorch_lightning import LightningDataModule
|
||||
|
||||
|
||||
# Dataset Options
|
||||
from ml_lib.utils.tools import add_argparse_args
|
||||
|
||||
DATA_OPTION_test = 'test'
|
||||
DATA_OPTION_devel = 'devel'
|
||||
DATA_OPTION_train = 'train'
|
||||
DATA_OPTIONS = [DATA_OPTION_train, DATA_OPTION_devel, DATA_OPTION_test]
|
||||
|
||||
|
||||
class _BaseDataModule(LightningDataModule):
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.datasets[DATA_OPTION_train].sample_shape
|
||||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser):
|
||||
return add_argparse_args(cls, parent_parser)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.datasets = dict()
|
||||
|
||||
def transfer_batch_to_device(self, batch, device):
|
||||
return batch.to(device)
|
||||
|
@@ -1,19 +1,34 @@
|
||||
from abc import ABC
|
||||
import inspect
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||
from pytorch_lightning.loggers.neptune import NeptuneLogger
|
||||
from neptune.api_exceptions import ProjectNotFound
|
||||
# noinspection PyUnresolvedReferences
|
||||
|
||||
from pytorch_lightning.loggers.csv_logs import CSVLogger
|
||||
from pytorch_lightning.utilities import argparse_utils
|
||||
|
||||
from .config import Config
|
||||
from ml_lib.utils.tools import add_argparse_args
|
||||
|
||||
|
||||
class Logger(LightningLoggerBase, ABC):
|
||||
class Logger(LightningLoggerBase):
|
||||
|
||||
@classmethod
|
||||
def from_argparse_args(cls, args, **kwargs):
|
||||
return argparse_utils.from_argparse_args(cls, args, **kwargs)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
media_dir = 'media'
|
||||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser):
|
||||
return add_argparse_args(cls, parent_parser)
|
||||
|
||||
@property
|
||||
def experiment(self):
|
||||
if self.debug:
|
||||
@@ -25,27 +40,23 @@ class Logger(LightningLoggerBase, ABC):
|
||||
def log_dir(self):
|
||||
return Path(self.csvlogger.experiment.log_dir)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.config.name
|
||||
|
||||
@property
|
||||
def project_name(self):
|
||||
return f"{self.config.project.owner}/{self.config.project.name.replace('_', '-')}"
|
||||
return f"{self.owner}/{self.name.replace('_', '-')}"
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
return self.config.get('main', 'seed')
|
||||
return self.seed
|
||||
|
||||
@property
|
||||
def save_dir(self):
|
||||
return self.log_dir
|
||||
|
||||
@property
|
||||
def outpath(self):
|
||||
return Path(self.config.train.outpath) / self.config.model.type
|
||||
return Path(self.root_out) / self.model_name
|
||||
|
||||
@property
|
||||
def exp_path(self):
|
||||
return Path(self.outpath) / self.name
|
||||
|
||||
def __init__(self, config: Config):
|
||||
def __init__(self, owner, neptune_key, model_name, project_name='', outpath='output', seed=69, debug=False):
|
||||
"""
|
||||
params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
|
||||
Parameters are displayed in the experiment’s Parameters section and each key-value pair can be
|
||||
@@ -59,19 +70,19 @@ class Logger(LightningLoggerBase, ABC):
|
||||
"""
|
||||
super(Logger, self).__init__()
|
||||
|
||||
self.config = config
|
||||
self.debug = self.config.main.debug
|
||||
if self.debug:
|
||||
self.config.add_section('project')
|
||||
self.config.set('project', 'owner', 'testuser')
|
||||
self.config.set('project', 'name', 'test')
|
||||
self.config.set('project', 'neptune_key', 'XXX')
|
||||
self.debug = debug
|
||||
self._name = project_name or Path(os.getcwd()).name if not self.debug else 'test'
|
||||
self.owner = owner if not self.debug else 'testuser'
|
||||
self.neptune_key = neptune_key if not self.debug else 'XXX'
|
||||
self.root_out = outpath if not self.debug else 'debug_out'
|
||||
self.seed = seed
|
||||
self.model_name = model_name
|
||||
|
||||
self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||
self._neptune_kwargs = dict(offline_mode=self.debug,
|
||||
api_key=self.config.project.neptune_key,
|
||||
api_key=self.neptune_key,
|
||||
experiment_name=self.name,
|
||||
project_name=self.project_name,
|
||||
params=self.config.model_paramters)
|
||||
project_name=self.project_name)
|
||||
try:
|
||||
self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
|
||||
except ProjectNotFound as e:
|
||||
@@ -79,7 +90,6 @@ class Logger(LightningLoggerBase, ABC):
|
||||
print(e)
|
||||
|
||||
self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
|
||||
self.log_config_as_ini()
|
||||
|
||||
def log_hyperparams(self, params):
|
||||
self.neptunelogger.log_hyperparams(params)
|
||||
@@ -95,19 +105,15 @@ class Logger(LightningLoggerBase, ABC):
|
||||
self.csvlogger.close()
|
||||
self.neptunelogger.close()
|
||||
|
||||
def log_config_as_ini(self):
|
||||
self.config.write(self.log_dir / 'config.ini')
|
||||
|
||||
def log_text(self, name, text, step_nb=0, **_):
|
||||
def log_text(self, name, text, **_):
|
||||
# TODO Implement Offline variant.
|
||||
self.neptunelogger.log_text(name, text, step_nb)
|
||||
self.neptunelogger.log_text(name, text)
|
||||
|
||||
def log_metric(self, metric_name, metric_value, **kwargs):
|
||||
self.csvlogger.log_metrics(dict(metric_name=metric_value))
|
||||
self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
|
||||
|
||||
def log_image(self, name, image, ext='png', **kwargs):
|
||||
|
||||
step = kwargs.get('step', None)
|
||||
image_name = f'{step}_{name}' if step is not None else name
|
||||
image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
|
||||
|
@@ -13,6 +13,10 @@ from torch import nn
|
||||
# Hyperparamter Object
|
||||
class ModelParameters(Namespace, Mapping):
|
||||
|
||||
@property
|
||||
def as_dict(self):
|
||||
return {key: self.get(key) if key != 'activation' else self.activation_as_string for key in self.keys()}
|
||||
|
||||
@property
|
||||
def activation_as_string(self):
|
||||
return self['activation'].lower()
|
||||
@@ -50,13 +54,7 @@ class ModelParameters(Namespace, Mapping):
|
||||
if name == 'activation':
|
||||
return self._activations[self['activation'].lower()]
|
||||
else:
|
||||
try:
|
||||
return super(ModelParameters, self).__getattribute__(name)
|
||||
except AttributeError as e:
|
||||
if name == 'stretch':
|
||||
return False
|
||||
else:
|
||||
return None
|
||||
return super(ModelParameters, self).__getattribute__(name)
|
||||
|
||||
_activations = dict(
|
||||
leaky_relu=nn.LeakyReLU,
|
||||
@@ -88,16 +86,20 @@ class SavedLightningModels(object):
|
||||
model = torch.load(models_root_path / 'model_class.obj')
|
||||
assert model is not None
|
||||
|
||||
return cls(weights=str(checkpoint_path), model=model)
|
||||
return cls(weights=checkpoint_path, model=model)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.weights: str = kwargs.get('weights', '')
|
||||
self.weights: Path = Path(kwargs.get('weights', ''))
|
||||
self.hparams: Path = self.weights.parent / 'hparams.yaml'
|
||||
|
||||
self.model = kwargs.get('model', None)
|
||||
assert self.model is not None
|
||||
|
||||
def restore(self):
|
||||
pretrained_model = self.model.load_from_checkpoint(self.weights)
|
||||
|
||||
pretrained_model = self.model.load_from_checkpoint(self.weights.__str__())
|
||||
# , hparams_file=self.hparams.__str__())
|
||||
pretrained_model.eval()
|
||||
pretrained_model.freeze()
|
||||
return pretrained_model
|
||||
return pretrained_model
|
||||
|
||||
|
@@ -1,6 +1,9 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import pickle
|
||||
import shelve
|
||||
from argparse import ArgumentParser
|
||||
from ast import literal_eval
|
||||
from pathlib import Path, PurePath
|
||||
from typing import Union
|
||||
|
||||
@@ -9,6 +12,13 @@ import torch
|
||||
import random
|
||||
|
||||
|
||||
def auto_cast(a):
|
||||
try:
|
||||
return literal_eval(a)
|
||||
except:
|
||||
return a
|
||||
|
||||
|
||||
def to_one_hot(idx_array, max_classes):
|
||||
one_hot = np.zeros((idx_array.size, max_classes))
|
||||
one_hot[np.arange(idx_array.size), idx_array] = 1
|
||||
@@ -54,3 +64,20 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
continue
|
||||
raise AttributeError(f'Check the Model name. Possible model files are:\n{[x.name for x in module_paths]}')
|
||||
|
||||
|
||||
def add_argparse_args(cls, parent_parser):
|
||||
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
||||
full_arg_spec = inspect.getfullargspec(cls.__init__)
|
||||
n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
|
||||
for idx, argument in enumerate(full_arg_spec.args):
|
||||
if argument == 'self':
|
||||
continue
|
||||
if idx < n_non_defaults:
|
||||
parser.add_argument(f'--{argument}', type=int)
|
||||
else:
|
||||
argument_type = type(argument)
|
||||
parser.add_argument(f'--{argument}',
|
||||
type=argument_type,
|
||||
default=full_arg_spec.defaults[idx - n_non_defaults]
|
||||
)
|
||||
return parser
|
||||
|
Reference in New Issue
Block a user