Transformer running
This commit is contained in:
@ -1,6 +1,9 @@
|
||||
import ast
|
||||
import configparser
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Dict
|
||||
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
|
||||
from abc import ABC
|
||||
@ -9,8 +12,67 @@ from argparse import Namespace, ArgumentParser
|
||||
from collections import defaultdict
|
||||
from configparser import ConfigParser, DuplicateSectionError
|
||||
import hashlib
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
from ml_lib.utils.tools import locate_and_import_class
|
||||
from ml_lib.utils.loggers import Logger
|
||||
from ml_lib.utils.tools import locate_and_import_class, auto_cast
|
||||
|
||||
|
||||
# Argument Parser and default Values
|
||||
# =============================================================================
|
||||
def parse_comandline_args_add_defaults(filepath, overrides=None):
|
||||
|
||||
# Parse Command Line
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--model_name', type=str)
|
||||
parser.add_argument('--data_name', type=str)
|
||||
|
||||
# Load Defaults from _parameters.ini file
|
||||
config = configparser.ConfigParser()
|
||||
config.read(str(filepath))
|
||||
|
||||
new_defaults = dict()
|
||||
for key in ['project', 'train', 'data']:
|
||||
defaults = config[key]
|
||||
new_defaults.update({key: auto_cast(val) for key, val in defaults.items()})
|
||||
|
||||
if new_defaults['debug']:
|
||||
new_defaults.update(
|
||||
max_epochs=2,
|
||||
max_steps=2 # The seems to be the new "fast_dev_run"
|
||||
)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
overrides = overrides or dict()
|
||||
default_data = overrides.get('data_name', None) or new_defaults['data_name']
|
||||
default_model = overrides.get('model_name', None) or new_defaults['model_name']
|
||||
|
||||
data_name = args.__dict__.get('data_name', None) or default_data
|
||||
model_name = args.__dict__.get('model_name', None) or default_model
|
||||
|
||||
new_defaults.update({key: auto_cast(val) for key, val in config[model_name].items()})
|
||||
|
||||
found_data_class = locate_and_import_class(data_name, 'datasets')
|
||||
found_model_class = locate_and_import_class(model_name, 'models')
|
||||
|
||||
for module in [Logger, Trainer, found_data_class, found_model_class]:
|
||||
parser = module.add_argparse_args(parser)
|
||||
|
||||
args, _ = parser.parse_known_args(namespace=Namespace(**new_defaults))
|
||||
|
||||
args = vars(args)
|
||||
args.update({key: auto_cast(val) for key, val in args.items()})
|
||||
args.update(gpus=[0] if torch.cuda.is_available() and not args['debug'] else None,
|
||||
row_log_interval=1000, # TODO: Better Value / Setting
|
||||
log_save_interval=10000, # TODO: Better Value / Setting
|
||||
auto_lr_find=not args['debug'],
|
||||
weights_summary='top',
|
||||
check_val_every_n_epoch=1 if args['debug'] else args.get('check_val_every_n_epoch', 1)
|
||||
)
|
||||
|
||||
if overrides is not None and isinstance(overrides, (Mapping, Dict)):
|
||||
args.update(**overrides)
|
||||
return args, found_data_class, found_model_class
|
||||
|
||||
|
||||
def is_jsonable(x):
|
||||
|
Reference in New Issue
Block a user