Transformer running
This commit is contained in:
48
multi_run.py
Normal file
48
multi_run.py
Normal file
@ -0,0 +1,48 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from main import run_lightning_loop
|
||||
from ml_lib.utils.config import parse_comandline_args_add_defaults
|
||||
|
||||
import itertools
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
# Set new values
|
||||
hparams_dict = dict(model_name=['VisualTransformer'],
|
||||
max_epochs=[150],
|
||||
batch_size=[50],
|
||||
random_apply_chance=[0.5],
|
||||
loudness_ratio=[0],
|
||||
shift_ratio=[0.3],
|
||||
noise_ratio=[0.3],
|
||||
mask_ratio=[0.3],
|
||||
lr=[0.001],
|
||||
dropout=[0.2],
|
||||
lat_dim=[32, 64],
|
||||
patch_size=[8, 12],
|
||||
attn_depth=[12],
|
||||
heads=[6],
|
||||
embedding_size=[16, 32],
|
||||
loss=['ce_loss'],
|
||||
sampler=['WeightedRandomSampler']
|
||||
)
|
||||
|
||||
keys, values = zip(*hparams_dict.items())
|
||||
permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
|
||||
for permutations_dict in tqdm(permutations_dicts, total=len(permutations_dicts)):
|
||||
# Parse comandline args, read config and get model
|
||||
cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults(
|
||||
'_parameters.ini', overrides=permutations_dict)
|
||||
|
||||
hparams = dict(**cmd_args)
|
||||
hparams.update(permutations_dict)
|
||||
hparams = Namespace(**hparams)
|
||||
|
||||
# RUN
|
||||
# ---------------------------------------
|
||||
print(f'Running Loop, parameters are: {permutations_dict}')
|
||||
run_lightning_loop(hparams, found_data_class, found_model_class)
|
||||
print(f'Done, parameters were: {permutations_dict}')
|
||||
pass
|
Reference in New Issue
Block a user