2021-03-04 12:01:09 +01:00

49 lines
1.8 KiB
Python

from argparse import Namespace
from tqdm import tqdm
from main import run_lightning_loop
from ml_lib.utils.config import parse_comandline_args_add_defaults
import itertools
def expand_hparam_grid(grid):
    """Expand a hyperparameter grid into one override dict per combination.

    Args:
        grid: Mapping of hyperparameter name -> list of candidate values.

    Returns:
        A list of dicts, one per element of the Cartesian product of the
        value lists, in ``itertools.product`` order (the last key varies
        fastest). An empty grid yields ``[{}]`` (a single run with no
        overrides) instead of failing on tuple unpacking; a key with an
        empty value list yields ``[]`` (no runs).
    """
    names = list(grid)
    return [dict(zip(names, combo)) for combo in itertools.product(*grid.values())]


if __name__ == '__main__':
    # Search space: each key maps to the list of values to try. Every
    # combination of these values is run once.
    hparams_dict = dict(model_name=['VisualTransformer'],
                        max_epochs=[150],
                        batch_size=[50],
                        random_apply_chance=[0.5],
                        loudness_ratio=[0],
                        shift_ratio=[0.3],
                        noise_ratio=[0.3],
                        mask_ratio=[0.3],
                        lr=[0.001],
                        dropout=[0.2],
                        lat_dim=[32, 64],
                        patch_size=[8, 12],
                        attn_depth=[12],
                        heads=[6],
                        embedding_size=[16, 32],
                        loss=['ce_loss'],
                        sampler=['WeightedRandomSampler']
                        )

    permutations_dicts = expand_hparam_grid(hparams_dict)
    # tqdm takes the total from len() of the list, so no explicit total= is needed.
    for permutations_dict in tqdm(permutations_dicts):
        # Parse command-line args and the '_parameters.ini' config (this
        # combination is passed as overrides) and resolve the data/model classes.
        cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults(
            '_parameters.ini', overrides=permutations_dict)
        # Grid values take precedence over whatever cmd_args contains.
        hparams = Namespace(**{**cmd_args, **permutations_dict})

        # RUN
        # ---------------------------------------
        # tqdm.write (rather than print) keeps the progress bar intact.
        tqdm.write(f'Running Loop, parameters are: {permutations_dict}')
        run_lightning_loop(hparams, found_data_class, found_model_class)
        tqdm.write(f'Done, parameters were: {permutations_dict}')