"""Seed sweep over ``data_*`` option ablations for the VisualTransformer model.

For every seed in ``SEEDS`` and every model type in ``MODEL_TYPES``, one run is
launched per entry of ``_experiment_configs()``:

* a "raw" baseline with every ``data_*`` option disabled (401 epochs),
* one run with all options enabled (101 epochs),
* one run per single option enabled (101 epochs each).

A run whose version directory already contains ``weights.ckpt`` is skipped; a
version directory without that file is deleted and the run is redone.
"""
import shutil
import warnings

from ml_lib.utils.config import Config

# Installed before the project imports below, so warnings of these categories
# raised while importing them are ignored as well.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

# Imports
# =============================================================================
from main import run_lightning_loop
from _paramters import main_arg_parser


SEEDS = range(0, 10)
MODEL_TYPES = ['VisualTransformer']


def _experiment_configs():
    """Return the ordered list of per-run ``data_*`` / ``train_epochs`` overrides.

    Order: raw, all, speed, mask, noise, shift, loudness. Every dict has exactly
    the same key set (that of ``raw_conf``), so applying any one of them fully
    overwrites the values set by the previous one.
    """
    raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
                    data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0,
                    data_loudness_ratio=0.0, data_stretch=False, train_epochs=401)

    # Settings shared by every non-raw run.
    enabled = dict(data_stretch=True, train_epochs=101)

    # Single source of truth for each option's "on" values. The "all" run is
    # built as their union, so it cannot drift from the single-option runs.
    speed = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7)
    mask = dict(data_mask_ratio=0.2)
    noise = dict(data_noise_ratio=0.4)
    shift = dict(data_shift_ratio=0.4)
    loudness = dict(data_loudness_ratio=0.4)
    single_options = [speed, mask, noise, shift, loudness]

    all_conf = {**raw_conf, **speed, **mask, **noise, **shift, **loudness, **enabled}
    single_confs = [{**raw_conf, **option, **enabled} for option in single_options]
    return [raw_conf, all_conf, *single_confs]


def _prepare_version_dir(version_path):
    """Decide whether a run must be launched for ``version_path``.

    Returns False when ``version_path / 'weights.ckpt'`` exists (run treated as
    finished). If the directory exists without that file, it is removed with
    ``shutil.rmtree`` so the run starts from a clean directory; returns True.
    """
    if not version_path.exists():
        return True
    if (version_path / 'weights.ckpt').exists():
        return False
    shutil.rmtree(version_path)
    return True


def main():
    """Parse CLI arguments and launch every (seed, model, config) run not yet finished."""
    args = main_arg_parser.parse_args()
    # Model Settings
    config = Config().read_namespace(args)

    # Loop-invariant: the experiment grid is identical for every seed/model.
    experiment_configs = _experiment_configs()

    for seed in SEEDS:
        for model in MODEL_TYPES:
            for conf in experiment_configs:
                run_args = dict(main_seed=seed, model_type=model, **conf)
                config = config.update(run_args)
                version_path = config.exp_path / config.version
                if _prepare_version_dir(version_path):
                    run_lightning_loop(config)


if __name__ == '__main__':
    main()