diff --git a/multi_run.py b/multi_run.py
index bce16fd..5e996cf 100644
--- a/multi_run.py
+++ b/multi_run.py
@@ -1,4 +1,9 @@
+# Imports
+# =============================================================================
+from _paramters import main_arg_parser
+from main import run_lightning_loop
+
 import shutil
 import warnings
 
 from ml_lib.utils.config import Config
@@ -6,12 +11,6 @@ from ml_lib.utils.config import Config
 warnings.filterwarnings('ignore', category=FutureWarning)
 warnings.filterwarnings('ignore', category=UserWarning)
 
-# Imports
-# =============================================================================
-
-from main import run_lightning_loop
-from _paramters import main_arg_parser
-
 
 
 if __name__ == '__main__':
@@ -20,43 +19,59 @@ if __name__ == '__main__':
 
     config = Config().read_namespace(args)
     arg_dict = dict()
-    for seed in range(0, 3):
+    for seed in range(1):
         arg_dict.update(main_seed=seed)
-        for patch_size in [3, 5 , 9]:
-            for model in ['SequentialVisualTransformer']:
-                arg_dict.update(model_type=model, model_patch_size=patch_size)
-                raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
-                                data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
-                                data_stretch=False, train_epochs=401)
+        if False:
+            for patch_size in [3, 5 , 9]:
+                for model in ['SequentialVisualTransformer']:
+                    arg_dict.update(model_type=model, model_patch_size=patch_size)
+                    raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
+                                    data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
+                                    data_stretch=False, train_epochs=401)
 
-                all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
-                                data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
-                                data_stretch=True, train_epochs=101)
+                    all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
+                                    data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
+                                    data_stretch=True, train_epochs=101)
 
-                speed_conf = raw_conf.copy()
-                speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
-                                  data_stretch=True, train_epochs=101)
+                    speed_conf = raw_conf.copy()
+                    speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
+                                      data_stretch=True, train_epochs=101)
 
-                mask_conf = raw_conf.copy()
-                mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
+                    mask_conf = raw_conf.copy()
+                    mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
 
-                noise_conf = raw_conf.copy()
-                noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
+                    noise_conf = raw_conf.copy()
+                    noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
 
-                shift_conf = raw_conf.copy()
-                shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
+                    shift_conf = raw_conf.copy()
+                    shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
 
-                loudness_conf = raw_conf.copy()
-                loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
+                    loudness_conf = raw_conf.copy()
+                    loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
 
-                for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
+                    for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
 
-                    arg_dict.update(dicts)
-                    config = config.update(arg_dict)
-                    version_path = config.exp_path / config.version
-                    if version_path.exists():
-                        if not (version_path / 'weights.ckpt').exists():
-                            shutil.rmtree(version_path)
-                        else:
-                            continue
-                    run_lightning_loop(config)
+                        arg_dict.update(dicts)
+        if True:
+            for patch_size in [3, 7]:
+                for lat_dim in [4, 32]:
+                    for heads in [2, 4]:
+                        for embedding_size in [32, 64]:
+                            for attn_depth in [1, 3]:
+                                for model in ['SequentialVisualTransformer', 'VisualTransformer']:
+                                    arg_dict.update(
+                                        model_type=model,
+                                        model_patch_size=patch_size,
+                                        model_lat_dim=lat_dim,
+                                        model_heads=heads,
+                                        model_embedding_size=embedding_size,
+                                        model_attn_depth=attn_depth
+                                    )
+                                    config = config.update(arg_dict)
+                                    version_path = config.exp_path / config.version
+                                    if version_path.exists():
+                                        if not (version_path / 'weights.ckpt').exists():
+                                            shutil.rmtree(version_path)
+                                        else:
+                                            continue
+                                    run_lightning_loop(config)