Transformer running

This commit is contained in:
Steffen Illium
2021-03-04 12:01:08 +01:00
parent b5e3e5aec1
commit f89f0f8528
14 changed files with 349 additions and 80 deletions

View File

@@ -25,5 +25,12 @@ class _BaseDataModule(LightningDataModule):
        self.datasets = dict()

    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, list):
            # Move each item individually; skip entries that have no .to()
            # method or cannot be placed on the target device.
            for idx, item in enumerate(batch):
                try:
                    batch[idx] = item.to(device)
                except (AttributeError, RuntimeError):
                    continue
            return batch
        else:
            return batch.to(device)
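A minimal sketch of what this hook does with a mixed batch; the datamodule subclass and the sample values here are hypothetical, not part of the commit:

import torch

# A list batch mixing a tensor and a metadata string: the tensor is moved,
# while the string raises AttributeError on .to() and is skipped.
batch = [torch.zeros(2, 3), 'sample_001.wav']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# datamodule = AnyBaseDataModuleSubclass(...)  # any _BaseDataModule subclass
# batch = datamodule.transfer_batch_to_device(batch, device)
# -> batch[0] now lives on `device`, batch[1] is still the plain string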

View File

@@ -1,6 +1,9 @@
import ast
import configparser
from pathlib import Path
from typing import Mapping, Dict
import torch
from copy import deepcopy
from abc import ABC
@@ -9,8 +12,67 @@ from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser, DuplicateSectionError
import hashlib
from pytorch_lightning import Trainer
from ml_lib.utils.loggers import Logger
from ml_lib.utils.tools import locate_and_import_class, auto_cast
# Argument Parser and default Values
# =============================================================================
def parse_comandline_args_add_defaults(filepath, overrides=None):
    # Parse Command Line
    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str)
    parser.add_argument('--data_name', type=str)

    # Load Defaults from _parameters.ini file
    config = configparser.ConfigParser()
    config.read(str(filepath))

    new_defaults = dict()
    for section in ['project', 'train', 'data']:
        defaults = config[section]
        new_defaults.update({key: auto_cast(val) for key, val in defaults.items()})

    if new_defaults['debug']:
        new_defaults.update(
            max_epochs=2,
            max_steps=2  # This seems to be the new "fast_dev_run"
        )

    args, _ = parser.parse_known_args()
    overrides = overrides or dict()
    default_data = overrides.get('data_name', None) or new_defaults['data_name']
    default_model = overrides.get('model_name', None) or new_defaults['model_name']

    data_name = args.__dict__.get('data_name', None) or default_data
    model_name = args.__dict__.get('model_name', None) or default_model

    new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})

    found_data_class = locate_and_import_class(data_name, 'datasets')
    found_model_class = locate_and_import_class(model_name, 'models')

    for module in [Logger, Trainer, found_data_class, found_model_class]:
        parser = module.add_argparse_args(parser)

    args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))

    args = vars(args)
    args.update({key: auto_cast(val) for key, val in args.items()})
    args.update(gpus=[0] if torch.cuda.is_available() and not args['debug'] else None,
                row_log_interval=1000,  # TODO: Better Value / Setting
                log_save_interval=10000,  # TODO: Better Value / Setting
                auto_lr_find=not args['debug'],
                weights_summary='top',
                check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1)
                )

    if isinstance(overrides, Mapping):
        args.update(**overrides)

    return args, found_data_class, found_model_class
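A hypothetical call site for this helper; the ini filename, the override key, and the keyword-argument construction are assumptions for illustration:

# Sketch only: assumes a '_parameters.ini' next to the script and that the
# located classes accept the merged arguments as keyword arguments.
args, data_class, model_class = parse_comandline_args_add_defaults(
    '_parameters.ini', overrides=dict(debug=True))
datamodule = data_class(**args)
model = model_class(**args)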
def is_jsonable(x):

utils/equal_sampler.py Normal file
View File

@@ -0,0 +1,30 @@
import random
from typing import Iterator, Sequence

from torch.utils.data import Sampler
from torch.utils.data.sampler import T_co


# noinspection PyMissingConstructor
class EqualSampler(Sampler):

    def __init__(self, idxs_per_class: Sequence[Sequence[int]], replacement: bool = True) -> None:
        self.replacement = replacement  # stored, but not (yet) used by __iter__
        self.idxs_per_class = idxs_per_class
        self.len_largest_class = max(len(x) for x in self.idxs_per_class)

    def __iter__(self) -> Iterator[T_co]:
        # Pick a class uniformly at random, then a random index from that
        # class, so every class is sampled at the same rate.
        return iter(random.choice(self.idxs_per_class[random.randint(0, len(self.idxs_per_class) - 1)])
                    for _ in range(len(self)))

    def __len__(self):
        return self.len_largest_class * len(self.idxs_per_class)


if __name__ == '__main__':
    es = EqualSampler([list(range(5)), list(range(5, 10)), list(range(10, 12))])
    for i in es:
        print(i)
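Wiring the sampler into a DataLoader could look like the following; the dataset and index lists are illustrative assumptions, not part of the commit:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Three classes with 5, 5 and 2 samples; EqualSampler draws each class
# with equal probability, oversampling the small third class.
dataset = TensorDataset(torch.randn(12, 3), torch.tensor([0] * 5 + [1] * 5 + [2] * 2))
sampler = EqualSampler([list(range(5)), list(range(5, 10)), list(range(10, 12))])
loader = DataLoader(dataset, batch_size=4, sampler=sampler)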

View File

@@ -1,5 +1,6 @@
import inspect
from argparse import ArgumentParser
from copy import deepcopy
import hashlib
from pathlib import Path
import os
@@ -17,11 +18,34 @@ class Logger(LightningLoggerBase):
    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        cleaned_args = deepcopy(args.__dict__)

        # Clean seed and other attributes
        # TODO: Find a better way of cleaning this
        for attr in ['seed', 'num_worker', 'debug', 'eval', 'owner', 'data_root', 'check_val_every_n_epoch',
                     'reset', 'outpath', 'version', 'gpus', 'neptune_key', 'num_sanity_val_steps', 'tpu_cores',
                     'progress_bar_refresh_rate', 'log_save_interval', 'row_log_interval']:
            try:
                del cleaned_args[attr]
            except KeyError:
                pass

        kwargs.update(params=cleaned_args)
        new_logger = argparse_utils.from_argparse_args(cls, args, **kwargs)
        return new_logger

    @property
    def fingerprint(self):
        # Stable MD5 digest over the (sorted) experiment parameters.
        h = hashlib.md5()
        h.update(self._finger_print_string.encode())
        fingerprint = h.hexdigest()
        return fingerprint

    @property
    def name(self):
        short_name = "".join(c for c in self.model_name if c.isupper())
        return f'{short_name}_{self.fingerprint}'

    media_dir = 'media'
@@ -42,7 +66,12 @@ class Logger(LightningLoggerBase):
    @property
    def project_name(self):
        return f"{self.owner}/{self.project_root.replace('_', '-')}"

    @property
    def project_root(self):
        root_path = Path(os.getcwd()).name if not self.debug else 'test'
        return root_path

    @property
    def version(self):
@@ -56,7 +85,7 @@ class Logger(LightningLoggerBase):
    def outpath(self):
        return Path(self.root_out) / self.model_name
    def __init__(self, owner, neptune_key, model_name, outpath='output', seed=69, debug=False, params=None):
"""
params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
Parameters are displayed in the experiments Parameters section and each key-value pair can be
@@ -71,51 +100,67 @@ class Logger(LightningLoggerBase):
        super(Logger, self).__init__()

        self.debug = debug
        self.owner = owner if not self.debug else 'testuser'
        self.neptune_key = neptune_key if not self.debug else 'XXX'
        self.root_out = outpath if not self.debug else 'debug_out'

        self.params = params or dict()  # guard against None, so the fingerprint update below cannot fail

        self.seed = seed
        self.model_name = model_name

        if self.params:
            _, fingerprint_tuple = zip(*sorted(self.params.items(), key=lambda tup: tup[0]))
            self._finger_print_string = str(fingerprint_tuple)
        else:
            self._finger_print_string = str((self.owner, self.root_out, self.seed, self.model_name, self.debug))
        self.params.update(fingerprint=self.fingerprint)

        self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
        self._neptune_kwargs = dict(offline_mode=self.debug,
                                    params=self.params,
                                    api_key=self.neptune_key,
                                    experiment_name=self.name,
                                    # tags=?,
                                    project_name=self.project_name)
        try:
            self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
        except ProjectNotFound as e:
            print(f'The project "{self.project_name}" does not exist! Create it or check your spelling.')
            print(e)

        self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
        if self.params:
            self.log_hyperparams(self.params)

    def close(self):
        self.csvlogger.close()
        self.neptunelogger.close()

    def set_fingerprint_string(self, fingerprint_str):
        self._finger_print_string = fingerprint_str

    def log_text(self, name, text, **_):
        # TODO: Implement an offline variant.
        self.neptunelogger.log_text(name, text)

    def log_hyperparams(self, params):
        self.neptunelogger.log_hyperparams(params)
        self.csvlogger.log_hyperparams(params)

    def log_metric(self, metric_name, metric_value, step=None, **kwargs):
        # Log under the metric's own name; extra kwargs go into the CSV row
        # and are forwarded once to Neptune.
        self.csvlogger.log_metrics({metric_name: metric_value, **kwargs}, step=step)
        self.neptunelogger.log_metric(metric_name, metric_value, step=step, **kwargs)

    def log_metrics(self, metrics, step=None):
        self.neptunelogger.log_metrics(metrics, step=step)
        self.csvlogger.log_metrics(metrics, step=step)

    def log_image(self, name, image, ext='png', step=None, **kwargs):
        # Zero-pad the step so file names sort chronologically.
        image_name = f'{str(step).zfill(4)}_{name}' if step is not None else name
        image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
        (self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
        image.savefig(image_path, bbox_inches='tight', pad_inches=0)
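A hypothetical instantiation of the logger; every argument value below is a placeholder, not taken from the commit:

# Sketch under assumed values; debug=True keeps Neptune in offline mode,
# so no real API key or project is needed.
logger = Logger(owner='my-team', neptune_key='XXX', model_name='VisualTransformer',
                outpath='output', seed=69, debug=True,
                params=dict(lr=1e-3, batch_size=32))
print(logger.name)  # capital letters of the model name + parameter fingerprint
logger.log_metric('loss', 0.42, step=1)
logger.close()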

View File

@@ -2,7 +2,7 @@ import importlib
import inspect
import pickle
import shelve
from argparse import ArgumentParser, ArgumentError
from ast import literal_eval
from pathlib import Path, PurePath
from typing import Union
@@ -70,14 +70,17 @@ def add_argparse_args(cls, parent_parser):
    full_arg_spec = inspect.getfullargspec(cls.__init__)
    n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
    for idx, argument in enumerate(full_arg_spec.args):
        try:
            if argument == 'self':
                continue
            if idx < n_non_defaults:
                parser.add_argument(f'--{argument}', type=int)
            else:
                # Infer the argparse type from the default value itself;
                # `argument` is only the parameter's name (a str).
                default = full_arg_spec.defaults[idx - n_non_defaults]
                parser.add_argument(f'--{argument}',
                                    type=type(default),
                                    default=default
                                    )
        except ArgumentError:
            continue
    return parser
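A sketch of driving this helper with a made-up class, assuming the unshown top of the function builds `parser` from `parent_parser`:

# Hypothetical class for illustration; add_argparse_args inspects __init__
# and registers one option per parameter, typed after its default value.
class Dummy:
    def __init__(self, hidden_size=128, dropout=0.1):
        self.hidden_size, self.dropout = hidden_size, dropout

parent = ArgumentParser()
parser = add_argparse_args(Dummy, parent)
args, _ = parser.parse_known_args(['--hidden_size', '256'])
print(args.hidden_size, args.dropout)  # expected: 256 0.1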