60 lines
3.0 KiB
Python
60 lines
3.0 KiB
Python
import shutil
|
|
import warnings
|
|
|
|
from util.config import MConfig
|
|
|
|
# Install the filters at this point, ahead of the imports further down, so
# they are already active while those modules are being loaded.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
|
|
|
|
# Imports
|
|
# =============================================================================
|
|
|
|
from main import run_lightning_loop
|
|
from _paramters import main_arg_parser
|
|
|
|
|
|
def _sweep_configs():
    """Build the data settings that are run for every seed/model pair.

    Returns:
        A new list of seven dicts in run order: raw, all, speed, mask, noise,
        shift, loudness. Every dict carries the same eight keys in the same
        order, so merging one into the shared argument dict fully overwrites
        the values left behind by the previous run.
    """
    # Reference run: every ratio/factor at 0.0, stretching off, 101 epochs.
    raw_conf = dict(data_speed_factor=0.0, data_speed_ratio=0.0, data_mask_ratio=0.0,
                    data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
                    data_stretch=False, train_epochs=101)
    # All remaining runs switch stretching on and train for 51 epochs; each
    # one then raises only the ratios it is meant to exercise.
    augmented = dict(raw_conf, data_stretch=True, train_epochs=51)
    # The speed factor and speed ratio are always set together.
    speed = dict(data_speed_factor=0.7, data_speed_ratio=0.2)

    return [
        raw_conf,
        dict(augmented, data_mask_ratio=0.2, data_noise_ratio=0.4,
             data_shift_ratio=0.4, data_loudness_ratio=0.4, **speed),  # all
        dict(augmented, **speed),                                      # speed
        dict(augmented, data_mask_ratio=0.2),                          # mask
        dict(augmented, data_noise_ratio=0.4),                         # noise
        dict(augmented, data_shift_ratio=0.4),                         # shift
        dict(augmented, data_loudness_ratio=0.4),                      # loudness
    ]


if __name__ == '__main__':
    args = main_arg_parser.parse_args()
    # Model Settings
    config = MConfig().read_namespace(args)

    # The settings do not depend on seed or model, so build them only once.
    sweep = _sweep_configs()

    # Overrides accumulate in a single dict that is re-applied before each run.
    arg_dict = {}
    for seed in range(0, 10):
        arg_dict.update(main_seed=seed)
        # Further model types ('BCMC', 'BCC', 'RCC') are currently switched off.
        for model in ['CC']:
            arg_dict.update(model_type=model)

            for conf in sweep:
                arg_dict.update(conf)
                config = config.update(arg_dict)

                version_path = config.exp_path / config.version
                if version_path.exists():
                    if (version_path / 'weights.ckpt').exists():
                        # Weights are already on disk for this combination
                        # (presumably a finished run): do not train it again.
                        continue
                    # Directory without weights: discard it and start over.
                    shutil.rmtree(version_path)

                run_lightning_loop(config)