Experiment parameters
This commit is contained in:
parent
a079a196af
commit
98ccecac04
90
multi_run.py
90
multi_run.py
@ -1,4 +1,8 @@
|
|||||||
import shutil
|
# Imports
|
||||||
|
# =============================================================================
|
||||||
|
from _paramters import main_arg_parser
|
||||||
|
from main import run_lightning_loop
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from ml_lib.utils.config import Config
|
from ml_lib.utils.config import Config
|
||||||
@ -6,12 +10,6 @@ from ml_lib.utils.config import Config
|
|||||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
# Imports
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
from main import run_lightning_loop
|
|
||||||
from _paramters import main_arg_parser
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
@ -20,43 +18,59 @@ if __name__ == '__main__':
|
|||||||
config = Config().read_namespace(args)
|
config = Config().read_namespace(args)
|
||||||
|
|
||||||
arg_dict = dict()
|
arg_dict = dict()
|
||||||
for seed in range(0, 3):
|
for seed in range(1):
|
||||||
arg_dict.update(main_seed=seed)
|
arg_dict.update(main_seed=seed)
|
||||||
for patch_size in [3, 5 , 9]:
|
if False:
|
||||||
for model in ['SequentialVisualTransformer']:
|
for patch_size in [3, 5 , 9]:
|
||||||
arg_dict.update(model_type=model, model_patch_size=patch_size)
|
for model in ['SequentialVisualTransformer']:
|
||||||
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
|
arg_dict.update(model_type=model, model_patch_size=patch_size)
|
||||||
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
|
raw_conf = dict(data_speed_amount=0.0, data_speed_min=0.0, data_speed_max=0.0,
|
||||||
data_stretch=False, train_epochs=401)
|
data_mask_ratio=0.0, data_noise_ratio=0.0, data_shift_ratio=0.0, data_loudness_ratio=0.0,
|
||||||
|
data_stretch=False, train_epochs=401)
|
||||||
|
|
||||||
all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
|
all_conf = dict(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
|
||||||
data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
|
data_mask_ratio=0.2, data_noise_ratio=0.4, data_shift_ratio=0.4, data_loudness_ratio=0.4,
|
||||||
data_stretch=True, train_epochs=101)
|
data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
speed_conf = raw_conf.copy()
|
speed_conf = raw_conf.copy()
|
||||||
speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
|
speed_conf.update(data_speed_amount=0.4, data_speed_min=0.7, data_speed_max=1.7,
|
||||||
data_stretch=True, train_epochs=101)
|
data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
mask_conf = raw_conf.copy()
|
mask_conf = raw_conf.copy()
|
||||||
mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
|
mask_conf.update(data_mask_ratio=0.2, data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
noise_conf = raw_conf.copy()
|
noise_conf = raw_conf.copy()
|
||||||
noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
|
noise_conf.update(data_noise_ratio=0.4, data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
shift_conf = raw_conf.copy()
|
shift_conf = raw_conf.copy()
|
||||||
shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
|
shift_conf.update(data_shift_ratio=0.4, data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
loudness_conf = raw_conf.copy()
|
loudness_conf = raw_conf.copy()
|
||||||
loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
|
loudness_conf.update(data_loudness_ratio=0.4, data_stretch=True, train_epochs=101)
|
||||||
|
|
||||||
for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
|
for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
|
||||||
|
|
||||||
arg_dict.update(dicts)
|
arg_dict.update(dicts)
|
||||||
config = config.update(arg_dict)
|
if True:
|
||||||
version_path = config.exp_path / config.version
|
for patch_size in [3, 7]:
|
||||||
if version_path.exists():
|
for lat_dim in [4, 32]:
|
||||||
if not (version_path / 'weights.ckpt').exists():
|
for heads in [2, 4]:
|
||||||
shutil.rmtree(version_path)
|
for embedding_size in [32, 64]:
|
||||||
else:
|
for attn_depth in [1, 3]:
|
||||||
continue
|
for model in ['SequentialVisualTransformer', 'VisualTransformer']:
|
||||||
run_lightning_loop(config)
|
arg_dict.update(
|
||||||
|
model_type=model,
|
||||||
|
model_patch_size=patch_size,
|
||||||
|
model_lat_dim=lat_dim,
|
||||||
|
model_heads=heads,
|
||||||
|
model_embedding_size=embedding_size,
|
||||||
|
model_attn_depth=attn_depth
|
||||||
|
)
|
||||||
|
config = config.update(arg_dict)
|
||||||
|
version_path = config.exp_path / config.version
|
||||||
|
if version_path.exists():
|
||||||
|
if not (version_path / 'weights.ckpt').exists():
|
||||||
|
shutil.rmtree(version_path)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
run_lightning_loop(config)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user