"""Train every combination of seed, model type and data-augmentation setting.

For each combination the run configuration is updated and
``run_lightning_loop`` is invoked. A run is skipped when its version
directory already contains ``weights.ckpt``. A version directory without
that file is treated as an incomplete run and is deleted before retraining.
"""
import shutil
import warnings

from util.config import MConfig

# Filters are installed before `main` is imported, so warnings emitted while
# importing it (and whatever it imports) are suppressed as well.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

# Imports
# =============================================================================
from main import run_lightning_loop
from _paramters import main_arg_parser


SEEDS = range(0, 10)
MODEL_TYPES = ('CC', 'BCMC', 'BCC', 'RCC')

# Baseline with every augmentation ratio/factor disabled. Each experiment below
# starts from this and overrides only the knobs it enables, so all experiments
# always carry the full, identically ordered set of keys.
_NO_AUGMENTATION = dict(
    data_speed_factor=0.0,
    data_speed_ratio=0.0,
    data_mask_ratio=0.0,
    data_noise_ratio=0.0,
    data_shift_ratio=0.0,
    data_loudness_ratio=0.0,
    data_stretch=True,
    train_epochs=101,
)

# Built once; these never change between iterations.
AUGMENTATION_CONFIGS = (
    # raw: no augmentation, no stretching, trained for more epochs
    {**_NO_AUGMENTATION, 'data_stretch': False, 'train_epochs': 401},
    # all augmentations enabled together
    {**_NO_AUGMENTATION,
     'data_speed_factor': 0.7, 'data_speed_ratio': 0.4,
     'data_mask_ratio': 0.2, 'data_noise_ratio': 0.4,
     'data_shift_ratio': 0.4, 'data_loudness_ratio': 0.4},
    # single-augmentation ablations
    {**_NO_AUGMENTATION, 'data_speed_factor': 0.7, 'data_speed_ratio': 0.4},
    {**_NO_AUGMENTATION, 'data_mask_ratio': 0.2},
    {**_NO_AUGMENTATION, 'data_noise_ratio': 0.4},
    {**_NO_AUGMENTATION, 'data_shift_ratio': 0.4},
    {**_NO_AUGMENTATION, 'data_loudness_ratio': 0.4},
)


def _should_run(version_path):
    """Decide whether the run stored under *version_path* must be (re)trained.

    Returns False when the directory already holds ``weights.ckpt`` (finished
    run). If the directory exists without that file, it is removed (stale or
    interrupted run) and True is returned. A missing directory also yields True.
    """
    if not version_path.exists():
        return True
    if (version_path / 'weights.ckpt').exists():
        return False
    shutil.rmtree(version_path)
    return True


if __name__ == '__main__':
    args = main_arg_parser.parse_args()

    # Model Settings
    config = MConfig().read_namespace(args)

    # Accumulates overrides; every key is rewritten on each inner iteration.
    arg_dict = dict()
    for seed in SEEDS:
        arg_dict.update(main_seed=seed)
        for model in MODEL_TYPES:
            arg_dict.update(model_type=model)
            for aug_conf in AUGMENTATION_CONFIGS:
                arg_dict.update(aug_conf)
                # NOTE(review): relies on MConfig.update returning the updated
                # config (not None) — kept exactly as in the original.
                config = config.update(arg_dict)
                version_path = config.exp_path / config.version
                if not _should_run(version_path):
                    continue
                run_lightning_loop(config)