Transformer running
@@ -25,5 +25,12 @@ class _BaseDataModule(LightningDataModule):
         self.datasets = dict()

     def transfer_batch_to_device(self, batch, device):
-        return batch.to(device)
+        if isinstance(batch, list):
+            for idx, item in enumerate(batch):
+                try:
+                    batch[idx] = item.to(device)
+                except (AttributeError, RuntimeError):
+                    continue
+            return batch
+        else:
+            return batch.to(device)
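The new list branch lets a batch carry non-tensor entries (e.g. filename strings) alongside tensors; anything without a usable .to() is skipped rather than aborting the transfer. A minimal standalone sketch of the same logic (function and sample names are illustrative, not part of the commit):

import torch

def move_batch(batch, device):
    # Mirrors the hook above: move what can be moved, keep the rest as-is.
    if isinstance(batch, list):
        for idx, item in enumerate(batch):
            try:
                batch[idx] = item.to(device)
            except (AttributeError, RuntimeError):
                continue  # e.g. a plain string has no .to()
        return batch
    return batch.to(device)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mixed = [torch.zeros(2, 3), 'sample_042.wav', torch.ones(4)]
moved = move_batch(mixed, device)  # tensors moved, the string left untouched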
@@ -1,6 +1,9 @@
 import ast
 import configparser
 from pathlib import Path
 from typing import Mapping, Dict

+import torch
+from copy import deepcopy
+
 from abc import ABC
@@ -9,8 +12,67 @@ from argparse import Namespace, ArgumentParser
 from collections import defaultdict
 from configparser import ConfigParser, DuplicateSectionError
+import hashlib
+from pytorch_lightning import Trainer

-from ml_lib.utils.tools import locate_and_import_class
+from ml_lib.utils.loggers import Logger
+from ml_lib.utils.tools import locate_and_import_class, auto_cast


+# Argument Parser and default Values
+# =============================================================================
+def parse_comandline_args_add_defaults(filepath, overrides=None):
+
+    # Parse Command Line
+    parser = ArgumentParser()
+    parser.add_argument('--model_name', type=str)
+    parser.add_argument('--data_name', type=str)
+
+    # Load Defaults from _parameters.ini file
+    config = configparser.ConfigParser()
+    config.read(str(filepath))
+
+    new_defaults = dict()
+    for key in ['project', 'train', 'data']:
+        defaults = config[key]
+        new_defaults.update({key: auto_cast(val) for key, val in defaults.items()})
+
+    if new_defaults['debug']:
+        new_defaults.update(
+            max_epochs=2,
+            max_steps=2  # This seems to be the new "fast_dev_run"
+        )
+
+    args, _ = parser.parse_known_args()
+    overrides = overrides or dict()
+    default_data = overrides.get('data_name', None) or new_defaults['data_name']
+    default_model = overrides.get('model_name', None) or new_defaults['model_name']
+
+    data_name = args.__dict__.get('data_name', None) or default_data
+    model_name = args.__dict__.get('model_name', None) or default_model
+
+    new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})
+
+    found_data_class = locate_and_import_class(data_name, 'datasets')
+    found_model_class = locate_and_import_class(model_name, 'models')
+
+    for module in [Logger, Trainer, found_data_class, found_model_class]:
+        parser = module.add_argparse_args(parser)
+
+    args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
+
+    args = vars(args)
+    args.update({key: auto_cast(val) for key, val in args.items()})
+    args.update(gpus=[0] if torch.cuda.is_available() and not args['debug'] else None,
+                row_log_interval=1000,  # TODO: Better Value / Setting
+                log_save_interval=10000,  # TODO: Better Value / Setting
+                auto_lr_find=not args['debug'],
+                weights_summary='top',
+                check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1)
+                )
+
+    if overrides is not None and isinstance(overrides, (Mapping, Dict)):
+        args.update(**overrides)
+    return args, found_data_class, found_model_class


 def is_jsonable(x):
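For orientation, parse_comandline_args_add_defaults expects an INI file with [project], [train] and [data] sections plus one section per model name. A sketch of such a file, built in-memory; only the section names and the debug/data_name/model_name keys are taken from the code above, the remaining keys are illustrative:

import configparser

SAMPLE = """
[project]
debug = True
model_name = Transformer

[train]
max_epochs = 100

[data]
data_name = MNISTDatamodule

[Transformer]
lr = 0.001
heads = 8
"""

config = configparser.ConfigParser()
config.read_string(SAMPLE)
print(dict(config['Transformer']))  # {'lr': '0.001', 'heads': '8'} - strings until auto_cast is applied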
utils/equal_sampler.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import random
+from typing import Iterator, Sequence
+
+from torch.utils.data import Sampler
+from torch.utils.data.sampler import T_co
+
+
+# noinspection PyMissingConstructor
+class EqualSampler(Sampler):
+
+    def __init__(self, idxs_per_class: Sequence[Sequence[float]], replacement: bool = True) -> None:
+
+        self.replacement = replacement
+        self.idxs_per_class = idxs_per_class
+        self.len_largest_class = max([len(x) for x in self.idxs_per_class])
+
+    def __iter__(self) -> Iterator[T_co]:
+        return iter(random.choice(self.idxs_per_class[random.randint(0, len(self.idxs_per_class) - 1)])
+                    for _ in range(len(self)))
+
+    def __len__(self):
+        return self.len_largest_class * len(self.idxs_per_class)
+
+
+if __name__ == '__main__':
+    es = EqualSampler([list(range(5)), list(range(5, 10)), list(range(10, 12))])
+    for i in es:
+        print(i)
+    pass
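A usage sketch for EqualSampler (dataset and index lists illustrative): since __len__ is len_largest_class * n_classes, one epoch below draws 15 indices, so the 2-item class is oversampled to match the 5-item classes.

import torch
from torch.utils.data import DataLoader, TensorDataset

labels = torch.tensor([0] * 5 + [1] * 5 + [2] * 2)
data = TensorDataset(torch.randn(12, 3), labels)
idxs_per_class = [list(range(5)), list(range(5, 10)), list(range(10, 12))]

sampler = EqualSampler(idxs_per_class)
loader = DataLoader(data, batch_size=5, sampler=sampler)  # 15 draws per epoch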
@@ -1,5 +1,6 @@
 import inspect
 from argparse import ArgumentParser
+from copy import deepcopy

+import hashlib
 from pathlib import Path

 import os
@@ -17,11 +18,34 @@ class Logger(LightningLoggerBase):

     @classmethod
     def from_argparse_args(cls, args, **kwargs):
-        return argparse_utils.from_argparse_args(cls, args, **kwargs)
+        cleaned_args = deepcopy(args.__dict__)
+
+        # Clean Seed and other attributes
+        # TODO: Find a better way of cleaning this
+        for attr in ['seed', 'num_worker', 'debug', 'eval', 'owner', 'data_root', 'check_val_every_n_epoch',
+                     'reset', 'outpath', 'version', 'gpus', 'neptune_key', 'num_sanity_val_steps', 'tpu_cores',
+                     'progress_bar_refresh_rate', 'log_save_interval', 'row_log_interval']:
+            try:
+                del cleaned_args[attr]
+            except KeyError:
+                pass
+
+        kwargs.update(params=cleaned_args)
+        new_logger = argparse_utils.from_argparse_args(cls, args, **kwargs)
+        return new_logger

     @property
-    def name(self) -> str:
-        return self._name
+    def fingerprint(self):
+        h = hashlib.md5()
+        h.update(self._finger_print_string.encode())
+        fingerprint = h.hexdigest()
+        return fingerprint
+
+    @property
+    def name(self):
+        short_name = "".join(c for c in self.model_name if c.isupper())
+        return f'{short_name}_{self.fingerprint}'

     media_dir = 'media'
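The new fingerprint property hashes the sorted hyperparameter values, and name combines the model's capital letters with that hash (e.g. a Transformer model becomes 'T_<hash>'), so identical configurations collapse onto the same experiment name. A standalone sketch of the hashing step (params illustrative):

import hashlib

params = {'lr': 0.001, 'batch_size': 32}
_, fingerprint_tuple = zip(*sorted(params.items(), key=lambda tup: tup[0]))
h = hashlib.md5()
h.update(str(fingerprint_tuple).encode())
print(h.hexdigest())  # identical for every run with the same params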
@@ -42,7 +66,12 @@ class Logger(LightningLoggerBase):

     @property
     def project_name(self):
-        return f"{self.owner}/{self.name.replace('_', '-')}"
+        return f"{self.owner}/{self.project_root.replace('_', '-')}"
+
+    @property
+    def project_root(self):
+        root_path = Path(os.getcwd()).name if not self.debug else 'test'
+        return root_path

     @property
     def version(self):
@@ -56,7 +85,7 @@ class Logger(LightningLoggerBase):
     def outpath(self):
         return Path(self.root_out) / self.model_name

-    def __init__(self, owner, neptune_key, model_name, project_name='', outpath='output', seed=69, debug=False):
+    def __init__(self, owner, neptune_key, model_name, outpath='output', seed=69, debug=False, params=None):
         """
         params (dict|None): Optional. Parameters of the experiment. After experiment creation params are read-only.
             Parameters are displayed in the experiment's Parameters section and each key-value pair can be
@@ -71,51 +100,67 @@ class Logger(LightningLoggerBase):
         super(Logger, self).__init__()

         self.debug = debug
-        self._name = project_name or Path(os.getcwd()).name if not self.debug else 'test'
         self.owner = owner if not self.debug else 'testuser'
         self.neptune_key = neptune_key if not self.debug else 'XXX'
         self.root_out = outpath if not self.debug else 'debug_out'

+        self.params = params
+
         self.seed = seed
         self.model_name = model_name

+        if self.params:
+            _, fingerprint_tuple = zip(*sorted(self.params.items(), key=lambda tup: tup[0]))
+            self._finger_print_string = str(fingerprint_tuple)
+            # only when params were given; the fingerprint joins the logged hyperparameters
+            self.params.update(fingerprint=self.fingerprint)
+        else:
+            self._finger_print_string = str((self.owner, self.root_out, self.seed, self.model_name, self.debug))

         self._csvlogger_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
         self._neptune_kwargs = dict(offline_mode=self.debug,
+                                    params=self.params,
                                     api_key=self.neptune_key,
                                     experiment_name=self.name,
                                     # tags=?,
                                     project_name=self.project_name)
         try:
             self.neptunelogger = NeptuneLogger(**self._neptune_kwargs)
         except ProjectNotFound as e:
-            print(f'The project "{self.project_name}"')
+            print(f'The project "{self.project_name}" does not exist! Create it or check your spelling.')
             print(e)

         self.csvlogger = CSVLogger(**self._csvlogger_kwargs)
+        if self.params:
+            self.log_hyperparams(self.params)

+    def close(self):
+        self.csvlogger.close()
+        self.neptunelogger.close()
+
+    def set_fingerprint_string(self, fingerprint_str):
+        self._finger_print_string = fingerprint_str
+
+    def log_text(self, name, text, **_):
+        # TODO Implement Offline variant.
+        self.neptunelogger.log_text(name, text)
+
+    def log_hyperparams(self, params):
+        self.neptunelogger.log_hyperparams(params)
+        self.csvlogger.log_hyperparams(params)
+        pass
+
+    def log_metric(self, metric_name, metric_value, step=None, **kwargs):
+        # log under the metric's own name, not the literal key 'metric_name'
+        self.csvlogger.log_metrics({metric_name: metric_value}, step=step)
+        self.neptunelogger.log_metric(metric_name, metric_value, step=step, **kwargs)
+        pass
+
+    def log_metrics(self, metrics, step=None):
+        self.neptunelogger.log_metrics(metrics, step=step)
+        self.csvlogger.log_metrics(metrics, step=step)
+        pass

-    def close(self):
-        self.csvlogger.close()
-        self.neptunelogger.close()
-
-    def log_text(self, name, text, **_):
-        # TODO Implement Offline variant.
-        self.neptunelogger.log_text(name, text)
-
-    def log_metric(self, metric_name, metric_value, **kwargs):
-        self.csvlogger.log_metrics(dict(metric_name=metric_value))
-        self.neptunelogger.log_metric(metric_name, metric_value, **kwargs)
-
-    def log_image(self, name, image, ext='png', **kwargs):
-        step = kwargs.get('step', None)
-        image_name = f'{step}_{name}' if step is not None else name
+    def log_image(self, name, image, ext='png', step=None, **kwargs):
+        image_name = f'{"0" * (4 - len(str(step)))}{step}_{name}' if step is not None else name
         image_path = self.log_dir / self.media_dir / f'{image_name}.{ext[1:] if ext.startswith(".") else ext}'
         (self.log_dir / self.media_dir).mkdir(parents=True, exist_ok=True)
         image.savefig(image_path, bbox_inches='tight', pad_inches=0)
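A hypothetical construction showing how the new params argument flows into both the fingerprint and the Neptune experiment (values illustrative; debug=True keeps Neptune in offline mode and swaps in the test user):

logger = Logger(owner='myuser', neptune_key='XXX', model_name='Transformer',
                debug=True, params={'lr': 0.001, 'batch_size': 32})
print(logger.name)                      # e.g. 'T_<md5-of-sorted-param-values>'
logger.log_metric('loss', 0.5, step=1)  # goes to both the CSV and Neptune loggers
logger.close()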
@@ -2,7 +2,7 @@ import importlib
 import inspect
 import pickle
 import shelve
-from argparse import ArgumentParser
+from argparse import ArgumentParser, ArgumentError
 from ast import literal_eval
 from pathlib import Path, PurePath
 from typing import Union
@@ -70,14 +70,17 @@ def add_argparse_args(cls, parent_parser):
     full_arg_spec = inspect.getfullargspec(cls.__init__)
     n_non_defaults = len(full_arg_spec.args) - (len(full_arg_spec.defaults) if full_arg_spec.defaults else 0)
     for idx, argument in enumerate(full_arg_spec.args):
-        if argument == 'self':
-            continue
-        if idx < n_non_defaults:
-            parser.add_argument(f'--{argument}', type=int)
-        else:
-            argument_type = type(argument)
-            parser.add_argument(f'--{argument}',
-                                type=argument_type,
-                                default=full_arg_spec.defaults[idx - n_non_defaults]
-                                )
+        try:
+            if argument == 'self':
+                continue
+            if idx < n_non_defaults:
+                parser.add_argument(f'--{argument}', type=int)
+            else:
+                argument_type = type(argument)
+                parser.add_argument(f'--{argument}',
+                                    type=argument_type,
+                                    default=full_arg_spec.defaults[idx - n_non_defaults]
+                                    )
+        except ArgumentError:
+            continue
     return parser
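A sketch of the wrapped loop's behaviour (class illustrative, and assuming the elided top of the function builds parser from parent_parser, as the return suggests): flags are generated from __init__'s signature, and the new try/except skips an argument whose flag already exists instead of raising. Note that argument_type = type(argument) is always str, since argument is the parameter's name; defaults keep their original type, but values passed on the command line arrive as strings.

from argparse import ArgumentParser

class DummyModel:
    def __init__(self, hidden_size, lr=0.001, name='net'):
        pass

parser = add_argparse_args(DummyModel, ArgumentParser())
args, _ = parser.parse_known_args(['--hidden_size', '128', '--lr', '0.01'])
# hidden_size has no default, so its flag is typed int; lr arrives as the string '0.01'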