Transformer running

This commit is contained in:
Steffen Illium
2021-03-04 12:01:08 +01:00
parent b5e3e5aec1
commit f89f0f8528
14 changed files with 349 additions and 80 deletions

View File

@ -1,6 +1,9 @@
import ast
import configparser
from pathlib import Path
from typing import Mapping, Dict
import torch
from copy import deepcopy
from abc import ABC
@ -9,8 +12,67 @@ from argparse import Namespace, ArgumentParser
from collections import defaultdict
from configparser import ConfigParser, DuplicateSectionError
import hashlib
from pytorch_lightning import Trainer
from ml_lib.utils.tools import locate_and_import_class
from ml_lib.utils.loggers import Logger
from ml_lib.utils.tools import locate_and_import_class, auto_cast
# Argument Parser and default Values
# =============================================================================
def parse_comandline_args_add_defaults(filepath, overrides=None):
# Parse Command Line
parser = ArgumentParser()
parser.add_argument('--model_name', type=str)
parser.add_argument('--data_name', type=str)
# Load Defaults from _parameters.ini file
config = configparser.ConfigParser()
config.read(str(filepath))
new_defaults = dict()
for key in ['project', 'train', 'data']:
defaults = config[key]
new_defaults.update({key: auto_cast(val) for key, val in defaults.items()})
if new_defaults['debug']:
new_defaults.update(
max_epochs=2,
max_steps=2 # The seems to be the new "fast_dev_run"
)
args, _ = parser.parse_known_args()
overrides = overrides or dict()
default_data = overrides.get('data_name', None) or new_defaults['data_name']
default_model = overrides.get('model_name', None) or new_defaults['model_name']
data_name = args.__dict__.get('data_name', None) or default_data
model_name = args.__dict__.get('model_name', None) or default_model
new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})
found_data_class = locate_and_import_class(data_name, 'datasets')
found_model_class = locate_and_import_class(model_name, 'models')
for module in [Logger, Trainer, found_data_class, found_model_class]:
parser = module.add_argparse_args(parser)
args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
args = vars(args)
args.update({key: auto_cast(val) for key, val in args.items()})
args.update(gpus=[0] if torch.cuda.is_available() and not args['debug'] else None,
row_log_interval=1000, # TODO: Better Value / Setting
log_save_interval=10000, # TODO: Better Value / Setting
auto_lr_find=not args['debug'],
weights_summary='top',
check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1)
)
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
args.update(**overrides)
return args, found_data_class, found_model_class
def is_jsonable(x):