# Imports
# =============================================================================
import itertools
import shutil
import warnings

from _paramters import main_arg_parser
from main import run_lightning_loop
from ml_lib.utils.config import Config

warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


# Sweep settings
# =============================================================================
# Seeds to run; every enabled sweep below is repeated once per seed.
SEEDS = range(1)

# Augmentation ablation (VerticalVisualTransformer). Disabled, as the original
# ``if False:`` branch was.
RUN_AUGMENTATION_ABLATION = False
ABLATION_PATCH_SIZES = [3, 5, 9]
ABLATION_MODELS = ['VerticalVisualTransformer']

# Architecture sweep (HorizontalVisualTransformer). Enabled, as the original
# ``if True:`` branch was. Grid points are visited in the original nested-loop
# order: patch size outermost, model innermost.
RUN_ARCHITECTURE_SWEEP = True
SWEEP_PATCH_SIZES = [7]
SWEEP_LAT_DIMS = [32]
SWEEP_HEADS = [8]
SWEEP_EMBEDDING_SIZES = [7 ** 2]
SWEEP_ATTN_DEPTHS = [1, 3, 5, 7]
SWEEP_MODELS = ['HorizontalVisualTransformer']


def _augmentation_presets():
    """Return the data-augmentation presets for the ablation, in run order.

    The order is ``raw``, ``all``, then the speed, mask, noise, shift and
    loudness presets. ``raw`` sets every ``data_*`` amount/ratio/bound to 0.0,
    ``data_stretch=False`` and ``train_epochs=401``. ``all`` sets every
    ``data_*`` value at once. Each remaining preset starts from ``raw``,
    overrides one group of ``data_*`` values, and sets ``data_stretch=True``
    and ``train_epochs=101``.

    Returns:
        list[dict]: Freshly built dicts, safe for the caller to mutate.
    """
    raw = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
               data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0,
               data_loudness_ratio=0.0, data_stretch=False, train_epochs=401)
    everything = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
                      data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4,
                      data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
    # Overrides shared by every single-group preset.
    single = dict(data_stretch=True, train_epochs=101)
    return [
        raw,
        everything,
        {**raw, **single,
         'data_speed_amount': 0.4, 'data_speed_min': 0.7, 'data_speed_max': 1.7},
        {**raw, **single, 'data_mask_ratio': 0.2},
        {**raw, **single, 'data_noise_ratio': 0.4},
        {**raw, **single, 'data_shift_ratio': 0.4},
        {**raw, **single, 'data_loudness_ratio': 0.4},
    ]


def _run_unless_finished(config):
    """Train ``config`` unless its version directory already contains weights.

    If ``config.exp_path / config.version`` exists and holds ``weights.ckpt``,
    the run is skipped. If it exists without ``weights.ckpt``, it is deleted
    with ``shutil.rmtree`` and the run starts again.

    Args:
        config: A ``Config`` exposing ``exp_path`` (a path object supporting
            ``/``) and ``version``.
    """
    version_path = config.exp_path / config.version
    if version_path.exists():
        if (version_path / 'weights.ckpt').exists():
            return
        # NOTE(review): a version dir without weights is presumably a crashed or
        # interrupted run — confirm nothing else worth keeping is stored there.
        shutil.rmtree(version_path)
    run_lightning_loop(config)


def run_sweeps():
    """Parse the CLI arguments and launch one training run per enabled grid point."""
    args = main_arg_parser.parse_args()

    # Model Settings
    base_config = Config().read_namespace(args)

    # NOTE(review): the original chained ``config = config.update(arg_dict)``.
    # Every run of the enabled sweep passes the same key set, so applying each
    # run's args to ``base_config`` yields the same settings. It also avoids
    # carrying one run's extra keys into the next, assuming ``Config.update``
    # returns an updated copy rather than mutating in place — confirm.
    for seed in SEEDS:
        if RUN_AUGMENTATION_ABLATION:
            # NOTE(review): in the original, nothing in this branch launched a
            # VerticalVisualTransformer run (model_type was overwritten before the
            # only run_lightning_loop call). Enabling it now runs every
            # (patch size, model, preset) combination — confirm that is intended.
            for patch_size, model, preset in itertools.product(
                    ABLATION_PATCH_SIZES, ABLATION_MODELS, _augmentation_presets()):
                run_args = dict(main_seed=seed, model_type=model,
                                model_patch_size=patch_size)
                run_args.update(preset)
                _run_unless_finished(base_config.update(run_args))

        if RUN_ARCHITECTURE_SWEEP:
            grid = itertools.product(SWEEP_PATCH_SIZES, SWEEP_LAT_DIMS, SWEEP_HEADS,
                                     SWEEP_EMBEDDING_SIZES, SWEEP_ATTN_DEPTHS,
                                     SWEEP_MODELS)
            for patch_size, lat_dim, heads, embedding_size, attn_depth, model in grid:
                run_args = dict(main_seed=seed, model_type=model,
                                model_patch_size=patch_size, model_lat_dim=lat_dim,
                                model_heads=heads, model_embedding_size=embedding_size,
                                model_attn_depth=attn_depth)
                _run_unless_finished(base_config.update(run_args))


if __name__ == '__main__':
    run_sweeps()