2021-04-02 08:45:11 +02:00

142 lines
5.6 KiB
Python

import pickle
from argparse import Namespace
from pathlib import Path
from typing import Union
import optuna as optuna
from optuna.integration import PyTorchLightningPruningCallback
from main import run_lightning_loop
from ml_lib.utils.config import parse_comandline_args_add_defaults
class ContiniousSavingCallback:
    """Optuna callback and context manager that pickles a study to disk.

    Passed in ``study.optimize(..., callbacks=[...])``, it writes a snapshot of
    the study to ``<root>/TMP_<study_name>_trial<number>.pkl`` after every trial.

    Leaving the ``with`` block does two things:

    * writes ``<root>/FINAL_<study_name>_best_<n>_score_<v>.pkl``;
    * deletes that study's temporary snapshots.

    If no trial has completed, there is no best trial to name the final file
    after. In that case nothing is written and the temporary snapshots are kept
    as the only record of the run. Exceptions raised inside the ``with`` block
    are not suppressed.

    NOTE: the files are plain ``pickle`` dumps; only unpickle them from trusted
    locations.
    """

    def __init__(self, root: Union[str, Path], study: 'optuna.Study'):
        """Store the study to checkpoint and the directory to write into."""
        self._study = study
        self.root = Path(root)

    @property
    def study(self):
        """The study being checkpointed."""
        return self._study

    def _tmp_path_for(self, trial_number: int) -> Path:
        """Return the temporary snapshot path for the trial with ``trial_number``."""
        return self.root / f'TMP_{self.study.study_name}_trial{trial_number}.pkl'

    @property
    def tmp_study_path(self) -> Path:
        """Temporary snapshot path for the most recent trial of the study.

        Raises IndexError if the study has no trials yet.
        """
        return self._tmp_path_for(self.study.trials[-1].number)

    @property
    def final_study_path(self) -> Path:
        """Final pickle path, encoding the best trial's number and score.

        Accessing ``study.best_trial`` / ``study.best_value`` raises ValueError
        in Optuna when no trial has completed; that error propagates from here.
        """
        return self.root / (f'FINAL_{self.study.study_name}_'
                            f'best_{self.study.best_trial.number}_'
                            f'score_{self.study.best_value}.pkl')

    @staticmethod
    def _write_to_disk(obj, path):
        """Pickle ``obj`` to ``path``, creating parent directories as needed.

        An existing file is overwritten: mode ``'wb'`` truncates it.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open(mode='wb') as f:
            pickle.dump(obj, f)

    def save_final(self):
        """Pickle the study to ``final_study_path``.

        Raises ValueError (from Optuna) when no trial has completed.
        """
        self._write_to_disk(self.study, self.final_study_path)

    def clean_up(self):
        """Delete this study's temporary per-trial snapshots from ``root``."""
        # Match only the exact TMP_<name>_trial*.pkl pattern, so that studies
        # whose names merely share a prefix are left alone.
        pattern = f'TMP_{self.study.study_name}_trial*.pkl'
        for temp_study_file in self.root.glob(pattern):
            temp_study_file.unlink(missing_ok=True)

    def __call__(self, study: 'optuna.study.Study', trial: 'optuna.trial.FrozenTrial') -> None:
        """Optuna callback hook: snapshot the study after ``trial`` finished."""
        self._write_to_disk(study, self._tmp_path_for(trial.number))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Write the final pickle, then remove the temporary snapshots."""
        try:
            final_path = self.final_study_path
        except ValueError:
            # No completed trial, so there is no best trial or score to put in
            # the file name. Keep the TMP snapshots instead of deleting them.
            return None
        self._write_to_disk(self.study, final_path)
        # Only clean up once the final pickle has been written successfully.
        self.clean_up()
        return None
def optimize(trial: optuna.Trial):
    """Optuna objective: sample hyper-parameters, train once, return the score.

    The sampled values are passed as ``overrides`` to
    ``parse_comandline_args_add_defaults('_parameters.ini', ...)``. The resulting
    configuration is trained with ``run_lightning_loop``, with a
    ``PyTorchLightningPruningCallback`` that monitors ``PL_recall_score``.

    Returns:
        ``results.best_model_score`` from the training run, or ``0`` if training
        raised any exception other than a pruning signal. (Presumably a tensor
        or float from Lightning's checkpointing — TODO confirm against
        ``run_lightning_loop``.)

    Raises:
        optuna.TrialPruned: re-raised so Optuna records the trial as pruned
            instead of as a completed trial with score 0.
    """
    # Optuna configuration
    Path('study').mkdir(parents=False, exist_ok=True)

    scheduler = trial.suggest_categorical('scheduler', [None, 'LambdaLR'])
    # The scheduler parameter is only sampled when a scheduler is actually used.
    if scheduler is not None:
        lr_scheduler_parameter = trial.suggest_float('lr_scheduler_parameter', 0.8, 1, step=0.01)
    else:
        lr_scheduler_parameter = None

    optuna_suggestions = dict(
        model_name='VisualTransformer',
        data_name='CCSLibrosaDatamodule',
        batch_size=trial.suggest_int('batch_size', 5, 50, step=5),
        max_epochs=200,
        target_mel_length_in_seconds=trial.suggest_float('target_mel_length_in_seconds', 0.2, 1.5, step=0.1),
        random_apply_chance=trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
        loudness_ratio=trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
        shift_ratio=trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
        noise_ratio=trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
        mask_ratio=trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
        # Equivalent to the deprecated suggest_loguniform('lr', 1e-5, 1e-3).
        lr=trial.suggest_float('lr', 1e-5, 1e-3, log=True),
        dropout=trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
        # Sampled as an exponent: lat_dim is a power of two in [2, 32].
        lat_dim=2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
        scheduler=scheduler,
        lr_scheduler_parameter=lr_scheduler_parameter,
        loss='ce_loss',
        sampler=trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
        study_name=trial.study.study_name
    )

    # Architecture-specific search space. model_name is fixed above, so only one
    # branch is live at a time; edit model_name to switch.
    if optuna_suggestions['model_name'] == 'CNNBaseline':
        model_depth = trial.suggest_int('model_depth', 1, 6, step=1)
        filters = [2 ** trial.suggest_int(f'filters_{layer_idx}', 2, 6, step=1)
                   for layer_idx in range(model_depth)]
        optuna_suggestions.update(filters=filters)
    elif optuna_suggestions['model_name'] in ['VisualTransformer', 'VerticalVisualTransformer']:
        transformer_dict = dict(
            mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
            head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
            patch_size=trial.suggest_int('patch_size', 6, 12, step=3),
            attn_depth=trial.suggest_int('attn_depth', 2, 14, step=4),
            heads=trial.suggest_int('heads', 2, 16, step=2),
            embedding_size=trial.suggest_int('embedding_size', 12, 64, step=12)
        )
        optuna_suggestions.update(**transformer_dict)

    # NOTE(review): the monitored key must match a metric the model logs — confirm.
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="PL_recall_score")

    # Parse comandline args, read config and get model
    h_params, found_data_class, found_model_class, seed = parse_comandline_args_add_defaults(
        '_parameters.ini', overrides=optuna_suggestions)
    h_params = Namespace(**h_params)

    try:
        results = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
                                     additional_callbacks=pruning_callback, seed=seed)
        best_score = results.best_model_score
    except optuna.TrialPruned:
        # The pruning callback signals by raising; it must reach Optuna, or the
        # trial would be logged as a completed run with score 0.
        raise
    except Exception as e:
        # Best-effort: a crashing configuration scores 0 so the study continues.
        print(f'Trial {trial.number} failed: {e!r}')
        best_score = 0
    return best_score
if __name__ == '__main__':
    # Seeded TPE sampler so that repeated runs propose the same trials.
    tpe_sampler = optuna.samplers.TPESampler(seed=1337)
    optuna_study = optuna.create_study(direction='maximize', sampler=tpe_sampler)

    # The callback pickles the study after every trial. On exit the context
    # manager calls save_final() and then clean_up().
    with ContiniousSavingCallback('study', optuna_study) as continious_save_callback:
        optuna_study.optimize(
            optimize,
            n_trials=200,
            show_progress_bar=True,
            callbacks=[continious_save_callback],
        )

    # Summary of the finished study.
    n_finished = len(optuna_study.trials)
    print(f"Number of finished trials: {n_finished}")
    print("Best trial:")
    best = optuna_study.best_trial
    print(f" Value: {best.value}")
    print(" Params: ")
    for param_name, param_value in best.params.items():
        print(f" {param_name}: {param_value}")