2020-05-21 14:42:35 +02:00

59 lines
2.4 KiB
Python

import shutil
import warnings
from util.config import MConfig
# Silence FutureWarning and UserWarning output for the whole run. These filters
# are installed ahead of the remaining imports below, so warnings emitted while
# importing those modules are suppressed as well.
for _category in (FutureWarning, UserWarning):
    warnings.filterwarnings('ignore', category=_category)
del _category
# Imports
# =============================================================================
from main import run_lightning_loop
from _paramters import main_arg_parser
# Experiment grid: every (seed, model type, augmentation config) combination
# below is trained once; runs that already produced weights are skipped.
SEEDS = range(10)
MODEL_TYPES = ('CC', 'BCMC', 'BCC', 'RCC')


def build_ablation_confs():
    """Return the augmentation/training override dicts to sweep over.

    The first entry is the un-augmented baseline (401 epochs, no stretch), the
    second enables every augmentation at once, and each remaining entry starts
    from the baseline and enables exactly one augmentation (speed, mask, noise,
    shift, loudness) with stretching on and 101 epochs.

    Returns:
        A list of seven new, independent dicts.
    """
    raw_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                    data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
                    data_stretch=False, train_epochs=401)
    all_conf = dict(data_speed_factor=0.7, data_speed_ratio=0.4, data_mask_ratio=0.2,
                    data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
                    data_stretch=True, train_epochs=101)
    # NOTE(review): raw/all confs configure speed augmentation through
    # data_speed_factor/data_speed_ratio, while the speed-only ablation sets
    # data_speed_amount/min/max (and keeps factor/ratio at 0.0 from the
    # baseline). Confirm which names main_arg_parser actually defines.
    single_augmentations = (
        dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7),
        dict(data_mask_ratio=0.2),
        dict(data_noise_ratio=0.4),
        dict(data_shift_ratio=0.4),
        dict(data_loudness_ratio=0.4),
    )
    single_confs = [{**raw_conf, **aug, 'data_stretch': True, 'train_epochs': 101}
                    for aug in single_augmentations]
    return [raw_conf, all_conf, *single_confs]


def iter_run_args(seeds=SEEDS, model_types=MODEL_TYPES, confs=None):
    """Yield one fresh override dict per (seed, model type, conf) combination.

    Every yielded dict is newly built, so keys that only one conf sets (e.g.
    the speed-only ``data_speed_*`` keys) cannot leak into other runs.

    Args:
        seeds: Iterable of values for ``main_seed``.
        model_types: Iterable of values for ``model_type``.
        confs: Override dicts combined with each seed/model pair; defaults to
            ``build_ablation_confs()``.
    """
    # Materialize the inner iterables so they can be re-iterated per seed/model.
    model_types = tuple(model_types)
    confs = build_ablation_confs() if confs is None else list(confs)
    for seed in seeds:
        for model in model_types:
            for conf in confs:
                yield {'main_seed': seed, 'model_type': model, **conf}


def prepare_version_dir(version_path):
    """Decide whether a run must be (re)started for ``version_path``.

    A directory that already contains ``weights.ckpt`` counts as finished and
    is left untouched. An existing directory without it is treated as an
    incomplete run and deleted so the run can start from scratch.

    Args:
        version_path: Run directory; presumably a ``pathlib.Path`` (it must
            support ``/`` and ``.exists()``).

    Returns:
        False if the run is already finished, True if it should be started.
    """
    if not version_path.exists():
        return True
    if (version_path / 'weights.ckpt').exists():
        return False
    shutil.rmtree(version_path)
    return True


if __name__ == '__main__':
    args = main_arg_parser.parse_args()
    for run_args in iter_run_args():
        # Rebuild the config from the CLI namespace for every run instead of
        # chaining config.update() calls on one accumulating dict/config:
        # otherwise keys set only by an earlier conf (e.g. data_speed_amount
        # from the speed ablation) would persist into later runs, and later
        # seeds would train with different settings than seed 0.
        # NOTE(review): MConfig.update semantics (in-place vs. copy) are not
        # visible here; starting from a fresh config avoids depending on them.
        config = MConfig().read_namespace(args).update(run_args)
        version_path = config.exp_path / config.version
        if prepare_version_dir(version_path):
            run_lightning_loop(config)