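optuna tune: switch study to CNNBaseline on PrimatesLibrosaDatamodule, raise max_epochs to 400, widen patch_size/attn_depth search ranges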

This commit is contained in:
Steffen 2021-04-03 18:39:29 +02:00
parent cec3a07d60
commit 3955f5ccd0
3 changed files with 95 additions and 2530 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -26,7 +26,7 @@ class ContiniousSavingCallback:
f'best_{self.study.best_trial.number}_' \
f'score_{self.study.best_value}.pkl'
def __init__(self, root:Union[str, Path], study: optuna.Study):
def __init__(self, root: Union[str, Path], study: optuna.Study):
self._study = study
self.root = Path(root)
pass
@ -49,7 +49,7 @@ class ContiniousSavingCallback:
temp_study_file.unlink(missing_ok=True)
def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
self._write_to_disk(study, self.tmp_study_path(trial.number))
self._write_to_disk(study, self.tmp_study_path)
def __enter__(self):
return self
@ -70,10 +70,10 @@ def optimize(trial: optuna.Trial):
lr_scheduler_parameter = None
optuna_suggestions = dict(
model_name='VisualTransformer',
data_name='CCSLibrosaDatamodule',
model_name='CNNBaseline',
data_name='PrimatesLibrosaDatamodule',
batch_size=trial.suggest_int('batch_size', 5, 50, step=5),
max_epochs=200,
max_epochs=400,
target_mel_length_in_seconds=trial.suggest_float('target_mel_length_in_seconds', 0.2, 1.5, step=0.1),
random_apply_chance=trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
loudness_ratio=trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
@ -99,8 +99,8 @@ def optimize(trial: optuna.Trial):
transformer_dict = dict(
mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
patch_size=trial.suggest_int('patch_size', 6, 12, step=3),
attn_depth=trial.suggest_int('attn_depth', 2, 14, step=4),
patch_size=trial.suggest_int('patch_size', 6, 20, step=3),
attn_depth=trial.suggest_int('attn_depth', 2, 20, step=4),
heads=trial.suggest_int('heads', 2, 16, step=2),
embedding_size=trial.suggest_int('embedding_size', 12, 64, step=12)
)