Experiment parameters
This commit is contained in:
parent
a079a196af
commit
98ccecac04
30
multi_run.py
30
multi_run.py
@ -1,4 +1,8 @@
|
||||
import shutil
|
||||
# Imports
|
||||
# =============================================================================
|
||||
from _paramters import main_arg_parser
|
||||
from main import run_lightning_loop
|
||||
|
||||
import warnings
|
||||
|
||||
from ml_lib.utils.config import Config
|
||||
@ -6,12 +10,6 @@ from ml_lib.utils.config import Config
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
# Imports
|
||||
# =============================================================================
|
||||
|
||||
from main import run_lightning_loop
|
||||
from _paramters import main_arg_parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -20,8 +18,9 @@ if __name__ == '__main__':
|
||||
config = Config().read_namespace(args)
|
||||
|
||||
arg_dict = dict()
|
||||
for seed in range(0, 3):
|
||||
for seed in range(1):
|
||||
arg_dict.update(main_seed=seed)
|
||||
if False:
|
||||
for patch_size in [3, 5 , 9]:
|
||||
for model in ['SequentialVisualTransformer']:
|
||||
arg_dict.update(model_type=model, model_patch_size=patch_size)
|
||||
@ -52,6 +51,21 @@ if __name__ == '__main__':
|
||||
for dicts in [raw_conf, all_conf, speed_conf, mask_conf, noise_conf, shift_conf, loudness_conf]:
|
||||
|
||||
arg_dict.update(dicts)
|
||||
if True:
|
||||
for patch_size in [3, 7]:
|
||||
for lat_dim in [4, 32]:
|
||||
for heads in [2, 4]:
|
||||
for embedding_size in [32, 64]:
|
||||
for attn_depth in [1, 3]:
|
||||
for model in ['SequentialVisualTransformer', 'VisualTransformer']:
|
||||
arg_dict.update(
|
||||
model_type=model,
|
||||
model_patch_size=patch_size,
|
||||
model_lat_dim=lat_dim,
|
||||
model_heads=heads,
|
||||
model_embedding_size=embedding_size,
|
||||
model_attn_depth=attn_depth
|
||||
)
|
||||
config = config.update(arg_dict)
|
||||
version_path = config.exp_path / config.version
|
||||
if version_path.exists():
|
||||
|
Loading…
x
Reference in New Issue
Block a user