78 lines
4.0 KiB
Python

from argparse import Namespace
from tqdm import tqdm
from main import run_lightning_loop
from ml_lib.utils.config import parse_comandline_args_add_defaults
import itertools
def build_search_space(use_cnn=True):
    """Return the hyper-parameter grid as a dict of ``name -> list of candidate values``.

    Every value is an iterable of candidates; ``expand_grid`` takes the Cartesian
    product over all of them. The commented ``trial.suggest_*`` calls record the
    Optuna search ranges these fixed values were presumably picked from.

    Args:
        use_cnn: ``True`` (default) adds the CNN parameters; ``False`` adds the
            transformer parameters instead. This replaces the former ``if True:``
            toggle, which silently made the transformer branch dead code.

    Returns:
        A new dict on every call. Key order is significant: it fixes the order of
        the generated grid and of the printed parameter dicts.
    """
    search_space = dict(
        seed=range(13, 20),
        # BandwiseConvClassifier, CNNBaseline, VisualTransformer, VerticalVisualTransformer
        model_name=['BandwiseConvClassifier'],
        # CCSLibrosaDatamodule, PrimatesLibrosaDatamodule
        data_name=['PrimatesLibrosaDatamodule'],
        batch_size=[20],
        max_epochs=[200],
        target_mel_length_in_seconds=[0.4],
        outpath=['optuna_found_param_run'],
        dropout=[0.0],  # trial.suggest_float('dropout', 0.0, 0.3, step=0.05)
        scheduler=[None],  # trial.suggest_categorical('scheduler', [None, 'LambdaLR'])
        lr_scheduler_parameter=[None],  # [0.95]
        loss=['ce_loss'],
        sampler=['WeightedRandomSampler'],  # trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler'])
        weight_decay=[0],  # trial.suggest_loguniform('weight_decay', 1e-20, 1e-1)
    )

    # Data augmentation parameters.
    search_space.update(
        random_apply_chance=[0.1],  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1)
        loudness_ratio=[0.2],  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1)
        shift_ratio=[0.3],  # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1)
        noise_ratio=[0.4],  # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1)
        mask_ratio=[0.3],  # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1)
    )

    if use_cnn:
        # CNN parameters.
        search_space.update(
            filters=[[6, 6, 6]],
            # NOTE(review): this lr lies outside the commented range below; the comment may be stale.
            lr=[0.0003414550170649836],  # trial.suggest_uniform('lr', 1e-3, 3e-3)
            variable_length=[False],  # True does not work yet
            lat_dim=[2 ** 3],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1)
        )
    else:
        # Transformer parameters.
        search_space.update(
            lr=[0.0008292481039683588],  # trial.suggest_uniform('lr', 1e-3, 3e-3)
            lat_dim=[2 ** 4],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1)
            mlp_dim=[2 ** 4],
            head_dim=[2 ** 4],  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1)
            patch_size=[6],  # trial.suggest_int('patch_size', 6, 12, step=3)
            attn_depth=[10],  # trial.suggest_int('attn_depth', 2, 14, step=4)
            heads=[16],  # trial.suggest_int('heads', 2, 16, step=2)
            embedding_size=[60],  # trial.suggest_int('embedding_size', 12, 64, step=12)
            variable_length=[False],  # True does not work yet
        )
    return search_space


def expand_grid(search_space):
    """Expand ``name -> candidates`` into one dict per combination (Cartesian product).

    The first key varies slowest, following ``itertools.product`` ordering.
    An empty mapping yields a single empty dict, and any empty candidate list
    yields no combinations at all. (The previous ``zip(*items())`` unpacking
    raised ``ValueError`` on an empty mapping.)

    Args:
        search_space: Mapping whose values are re-iterable collections of candidates.

    Returns:
        A list of dicts, each mapping every key of ``search_space`` to one candidate.
    """
    names = list(search_space)
    return [dict(zip(names, combo)) for combo in itertools.product(*search_space.values())]


def run_grid_search(search_space, config_path='_parameters.ini'):
    """Run ``run_lightning_loop`` once for every combination in ``search_space``.

    For each combination, the defaults returned by
    ``parse_comandline_args_add_defaults(config_path, overrides=...)`` are merged
    with the combination; the combination wins on key clashes. The merged
    result is passed on as an ``argparse.Namespace``.

    Args:
        search_space: Mapping of ``name -> candidates`` (see ``build_search_space``).
        config_path: Config file handed to ``parse_comandline_args_add_defaults``.
    """
    grid = expand_grid(search_space)
    # tqdm takes the bar length from len(grid); no explicit total is needed.
    for overrides in tqdm(grid):
        # Parse command-line args, read the config and resolve data/model classes.
        cmd_args, data_class, model_class, seed = parse_comandline_args_add_defaults(
            config_path, overrides=overrides)
        hparams = dict(**cmd_args)
        hparams.update(overrides)  # grid values take precedence over parsed defaults
        hparams = Namespace(**hparams)

        print(f'Running Loop, parameters are: {overrides}')
        run_lightning_loop(hparams, data_class, model_class, seed=seed)
        print(f'Done, parameters were: {overrides}')


if __name__ == '__main__':
    # Pass use_cnn=False to sweep the transformer configuration instead of the CNN one.
    run_grid_search(build_search_space(use_cnn=True))