"""Optuna hyperparameter search driver for ``run_lightning_loop``.

Runs a TPE-sampled study that maximizes the score returned by ``optimize``.
The study is pickled to disk after every trial and a final snapshot, named
after the best trial, is written when the search ends.
"""
import pickle
from argparse import Namespace
from pathlib import Path
from typing import Union

import optuna as optuna
from optuna.integration import PyTorchLightningPruningCallback

from main import run_lightning_loop
from ml_lib.utils.config import parse_comandline_args_add_defaults


class ContiniousSavingCallback:
    """Optuna trial callback and context manager that continuously pickles a study.

    As a callback (``study.optimize(..., callbacks=[cb])``), it writes
    ``<root>/TMP_<study_name>_trial<n>.pkl`` after each finished trial.
    On leaving a ``with`` block, it writes
    ``<root>/FINAL_<study_name>_best_<n>_score_<value>.pkl`` and removes the
    TMP snapshots. This only happens if at least one trial completed;
    otherwise the TMP snapshots are kept.
    """

    @property
    def study(self):
        """The study being saved."""
        return self._study

    @property
    def tmp_study_path(self):
        """Temporary snapshot path for the most recently created trial."""
        return self._tmp_study_path_for(self.study.trials[-1].number)

    @property
    def final_study_path(self):
        """Final snapshot path. Raises ``ValueError`` if no trial has completed."""
        return self.root / (f'FINAL_{self.study.study_name}_'
                            f'best_{self.study.best_trial.number}_'
                            f'score_{self.study.best_value}.pkl')

    def __init__(self, root: Union[str, Path], study: optuna.Study):
        """
        :param root: Directory the snapshots are written to (created on demand).
        :param study: The study to pickle.
        """
        self._study = study
        self.root = Path(root)

    def _tmp_study_path_for(self, trial_number: int) -> Path:
        """Temporary snapshot path for the trial with the given number."""
        return self.root / f'TMP_{self.study.study_name}_trial{trial_number}.pkl'

    @staticmethod
    def _write_to_disk(obj, path: Union[str, Path]) -> None:
        """Pickle ``obj`` to ``path``, creating parent directories as needed."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Mode 'wb' truncates an existing file, so no explicit unlink is needed.
        with path.open(mode='wb') as f:
            pickle.dump(obj, f)

    def save_final(self) -> None:
        """Write the final snapshot. Raises ``ValueError`` if no trial has completed."""
        self._write_to_disk(self.study, self.final_study_path)

    def clean_up(self) -> None:
        """Delete this study's TMP snapshots from ``root``."""
        # Match the full TMP naming scheme, so studies whose names share a prefix are left alone.
        # Materialize the glob before unlinking, so the directory isn't modified mid-iteration.
        temp_study_files = list(self.root.glob(f'TMP_{self.study.study_name}_trial*.pkl'))
        for temp_study_file in temp_study_files:
            temp_study_file.unlink(missing_ok=True)

    def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial) -> None:
        """Optuna callback: snapshot ``study`` after ``trial`` finished."""
        # Use the finished trial's own number rather than "the last trial":
        # with n_jobs > 1 these can differ.
        self._write_to_disk(study, self._tmp_study_path_for(trial.number))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Write the final snapshot and remove TMP files; never suppresses exceptions."""
        has_completed_trial = any(t.state == optuna.trial.TrialState.COMPLETE
                                  for t in self.study.trials)
        if not has_completed_trial:
            # best_trial would raise ValueError here and mask any in-flight exception.
            # Keep the TMP snapshots, since no final snapshot can be written.
            return None
        self.save_final()
        self.clean_up()
        return None


def optimize(trial: optuna.Trial):
    """Optuna objective: sample hyperparameters, run one training, return its score.

    Returns ``results.best_model_score`` from ``run_lightning_loop`` (the study
    in ``__main__`` maximizes it), or ``0`` if training raised an exception.
    ``optuna.TrialPruned`` is re-raised so that Optuna records the trial as
    pruned instead of scoring it.
    """
    # Optuna configuration
    folder = Path('study')
    folder.mkdir(parents=False, exist_ok=True)

    scheduler = trial.suggest_categorical('scheduler', [None, 'LambdaLR'])
    # The scheduler parameter is only part of the search space when a scheduler is used.
    if scheduler is not None:
        lr_scheduler_parameter = trial.suggest_float('lr_scheduler_parameter', 0.8, 1, step=0.01)
    else:
        lr_scheduler_parameter = None

    optuna_suggestions = dict(
        model_name='VisualTransformer',
        data_name='CCSLibrosaDatamodule',
        batch_size=trial.suggest_int('batch_size', 5, 50, step=5),
        max_epochs=200,
        target_mel_length_in_seconds=trial.suggest_float('target_mel_length_in_seconds', 0.2, 1.5, step=0.1),
        random_apply_chance=trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
        loudness_ratio=trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
        shift_ratio=trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
        noise_ratio=trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
        mask_ratio=trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
        # suggest_loguniform is deprecated; suggest_float(log=True) samples the same distribution.
        lr=trial.suggest_float('lr', 1e-5, 1e-3, log=True),
        dropout=trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
        lat_dim=2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
        scheduler=scheduler,
        lr_scheduler_parameter=lr_scheduler_parameter,
        loss='ce_loss',
        sampler=trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
        study_name=trial.study.study_name,
    )

    # Model-specific architecture parameters.
    if optuna_suggestions['model_name'] == 'CNNBaseline':
        model_depth = trial.suggest_int('model_depth', 1, 6, step=1)
        filters = [2 ** trial.suggest_int(f'filters_{layer_idx}', 2, 6, step=1)
                   for layer_idx in range(model_depth)]
        optuna_suggestions.update(filters=filters)
    elif optuna_suggestions['model_name'] in ['VisualTransformer', 'VerticalVisualTransformer']:
        transformer_dict = dict(
            mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
            head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
            patch_size=trial.suggest_int('patch_size', 6, 12, step=3),
            attn_depth=trial.suggest_int('attn_depth', 2, 14, step=4),
            heads=trial.suggest_int('heads', 2, 16, step=2),
            embedding_size=trial.suggest_int('embedding_size', 12, 64, step=12),
        )
        optuna_suggestions.update(**transformer_dict)

    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="PL_recall_score")

    # Parse comandline args, read config and get model
    h_params, found_data_class, found_model_class, seed = parse_comandline_args_add_defaults(
        '_parameters.ini', overrides=optuna_suggestions)
    h_params = Namespace(**h_params)
    try:
        results = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
                                     additional_callbacks=pruning_callback, seed=seed)
        best_score = results.best_model_score
    except optuna.TrialPruned:
        # Raised by the pruning callback. It must reach Optuna, otherwise the
        # trial would be recorded as completed with a score of 0.
        raise
    except Exception as e:
        # Deliberate best-effort: a crashing configuration is scored 0 instead of aborting the study.
        print(e)
        best_score = 0
    return best_score


if __name__ == '__main__':
    optuna_study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
    with ContiniousSavingCallback('study', optuna_study) as continious_save_callback:
        optuna_study.optimize(optimize, n_trials=200, show_progress_bar=True,
                              callbacks=[continious_save_callback])

    print(f"Number of finished trials: {len(optuna_study.trials)}")

    print("Best trial:")
    trial = optuna_study.best_trial

    print(f" Value: {trial.value}")

    print(" Params: ")
    for key, value in trial.params.items():
        print(f" {key}: {value}")