181 lines
7.4 KiB
Python

import os
import pickle
import tempfile
from argparse import Namespace
from pathlib import Path
from typing import Union

import optuna as optuna
from natsort import natsorted
from optuna.integration import PyTorchLightningPruningCallback

from main import run_lightning_loop
from ml_lib.utils.config import parse_comandline_args_add_defaults
class ContiniousSavingCallback:
    """Optuna callback that pickles the whole study to disk after every trial.

    Each call writes ``<root>/TMP_<study_name>_trial_<number>.pkl`` for the
    study's most recent trial. Used as a context manager, on exit it writes
    ``<root>/FINAL_<study_name>_best_<number>_score_<value>.pkl`` and then
    deletes this study's ``TMP_`` checkpoints. If no trial has completed, there
    is no best trial to name the final file after. In that case nothing is
    written and the ``TMP_`` checkpoints are kept so the study can be resumed.
    """

    @property
    def study(self):
        """The study being checkpointed."""
        return self._study

    @property
    def tmp_study_path(self):
        """Per-trial checkpoint path, named after the number of the study's latest trial."""
        return self.root / f'TMP_{self.study.study_name}_trial_{self.study.trials[-1].number}.pkl'

    @property
    def final_study_path(self):
        """Final checkpoint path, named after the best trial's number and value.

        optuna raises ``ValueError`` from ``study.best_trial`` while no trial
        has completed.
        """
        return self.root / (f'FINAL_{self.study.study_name}_'
                            f'best_{self.study.best_trial.number}_'
                            f'score_{self.study.best_value}.pkl')

    def __init__(self, root: Union[str, Path], study: 'optuna.Study'):
        self._study = study
        self.root = Path(root)

    @staticmethod
    def _write_to_disk(obj, path):
        """Pickle ``obj`` to ``path`` without ever exposing a partially written file.

        The pickle is first written to a uniquely named temporary file in the
        target directory, then moved over ``path`` with ``os.replace``. A crash
        mid-write therefore leaves ``path`` either absent or holding the previous
        complete pickle, never a truncated one. The temporary file's name is not
        derived from the study name, so it is not mistaken for a checkpoint.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        fd, tmp_name = tempfile.mkstemp(prefix='.partial_', suffix='.tmp', dir=path.parent)
        try:
            with os.fdopen(fd, 'wb') as tmp_file:
                pickle.dump(obj, tmp_file)
            os.replace(tmp_name, path)
        except BaseException:
            # Remove the incomplete temporary file, then let the error propagate.
            Path(tmp_name).unlink(missing_ok=True)
            raise

    def save_final(self):
        """Pickle the study to ``final_study_path`` (needs at least one completed trial)."""
        self._write_to_disk(self.study, self.final_study_path)

    def clean_up(self):
        """Delete this study's ``TMP_`` checkpoints from ``root``."""
        for temp_study_file in list(self.root.glob(f'TMP_{self.study.study_name}*')):
            temp_study_file.unlink(missing_ok=True)

    def __call__(self, study: 'optuna.study.Study', trial: 'optuna.trial.FrozenTrial') -> None:
        """Optuna callback hook: checkpoint the study after each finished trial."""
        self._write_to_disk(study, self.tmp_study_path)

    def _has_best_trial(self):
        """Return whether the study has a best trial to name the final checkpoint after."""
        try:
            self.study.best_trial
        except ValueError:
            # optuna raises ValueError while no trial has completed.
            return False
        return True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Save the final checkpoint and remove the temporary ones.

        Without a completed trial no final file can be named. Raising here would
        replace any exception coming from the ``with`` block, so the ``TMP_``
        checkpoints are kept instead. Returns None, so exceptions from the
        ``with`` block still propagate.
        """
        if not self._has_best_trial():
            print(f'No completed trial in study {self.study.study_name}; '
                  f'keeping temporary checkpoints in {self.root}.')
            return None
        self.save_final()
        self.clean_up()
        return None
class Objective(object):
    """Optuna objective: sample hyperparameters, train one model, return its best score.

    The fixed study settings passed to the constructor are stored on every
    trial as user attributes. They are merged into the sampled hyperparameters,
    which are passed as overrides to ``parse_comandline_args_add_defaults``.
    """

    def __init__(self, model_class_name, data_class_name, max_epochs, loss):
        # Fixed (not searched) settings; the keys become override names downstream.
        self.study_params = dict(model_name=model_class_name,
                                 data_name=data_class_name,
                                 max_epochs=max_epochs,
                                 loss=loss,
                                 )

    def __call__(self, trial):
        """Train a model with the hyperparameters suggested for ``trial``.

        Returns ``best_model_score`` of the object returned by
        ``run_lightning_loop``. A ``PyTorchLightningPruningCallback`` monitoring
        ``PL_recall_score`` is passed to the training loop so optuna can prune
        the trial.
        """
        # Make sure the (single-level) study output folder exists.
        Path('study').mkdir(parents=False, exist_ok=True)

        optuna_suggestions = self._suggest_hyperparameters(trial)

        pruning_callback = PyTorchLightningPruningCallback(trial, monitor="PL_recall_score")
        # Parse command line args, read the config and resolve the data/model classes.
        h_params, found_data_class, found_model_class, seed = parse_comandline_args_add_defaults(
            '_parameters.ini', overrides=optuna_suggestions)
        results = run_lightning_loop(Namespace(**h_params), data_class=found_data_class,
                                     model_class=found_model_class,
                                     additional_callbacks=pruning_callback, seed=seed)
        return results.best_model_score

    def _suggest_hyperparameters(self, trial):
        """Return the generic and model-specific suggestions plus the fixed study settings.

        The order of the ``suggest_*`` calls is kept exactly as it was, in case a
        seeded sampler's reproducibility depends on the call sequence.
        """
        scheduler = trial.suggest_categorical('scheduler', [None, 'LambdaLR'])
        # The scheduler parameter is only searched when a scheduler is used.
        if scheduler is not None:
            lr_scheduler_parameter = trial.suggest_float('lr_scheduler_parameter', 0.8, 1, step=0.01)
        else:
            lr_scheduler_parameter = None
        suggestions = dict(
            batch_size=trial.suggest_int('batch_size', 5, 50, step=5),
            target_mel_length_in_seconds=trial.suggest_float('target_mel_length_in_seconds', 0.2, 1.5, step=0.1),
            random_apply_chance=trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
            loudness_ratio=trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
            shift_ratio=trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
            noise_ratio=trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
            mask_ratio=trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
            # log=True samples the same log-uniform range as the deprecated suggest_loguniform.
            lr=trial.suggest_float('lr', 1e-5, 1e-3, log=True),
            dropout=trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
            # The exponent is searched; the latent dimension is a power of two.
            lat_dim=2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
            scheduler=scheduler,
            lr_scheduler_parameter=lr_scheduler_parameter,
            sampler=trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
        )
        # User defined parameters: recorded on the trial and merged into the overrides.
        for params_name, params_value in self.study_params.items():
            trial.set_user_attr(params_name, params_value)
        trial.set_user_attr('study_name', trial.study.study_name)
        suggestions.update(**trial.user_attrs)
        suggestions.update(self._suggest_model_parameters(trial, suggestions['model_name']))
        return suggestions

    @staticmethod
    def _suggest_model_parameters(trial, model_name):
        """Return architecture hyperparameters for ``model_name`` (an empty dict for other models)."""
        if model_name in ('CNNBaseline', 'BandwiseConvClassifier'):
            model_depth = trial.suggest_int('model_depth', 1, 6, step=1)
            # One power-of-two filter count per layer.
            filters = [2 ** trial.suggest_int(f'filters_{layer_idx}', 2, 6, step=1)
                       for layer_idx in range(model_depth)]
            return dict(filters=filters)
        if model_name in ('VisualTransformer', 'VerticalVisualTransformer'):
            return dict(
                mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
                head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
                patch_size=trial.suggest_int('patch_size', 6, 20, step=3),
                attn_depth=trial.suggest_int('attn_depth', 2, 20, step=4),
                heads=trial.suggest_int('heads', 2, 16, step=2),
                embedding_size=trial.suggest_int('embedding_size', 12, 64, step=12)
            )
        return {}
if __name__ == '__main__':
# Study Parameters
out_folder = Path('study')
model_name = 'CNNBaseline'
data_name = 'Urban8KLibrosaDatamodule'
loss = 'ce_loss'
max_epochs = 200
n_trials = 400
study_name = f'{model_name}_{max_epochs}_{data_name}'
# Create Study or load study:
try:
found_studys = [x for x in out_folder.iterdir() if study_name in x.name]
except FileNotFoundError:
found_studys = []
if found_studys:
latest_found_study = natsorted(found_studys, key=lambda x: x.stem[x.stem.find('_trial'):])[-1]
with latest_found_study.open('rb') as latest_found_study_file:
optuna_study = pickle.load(latest_found_study_file)
n_trials = n_trials - len(optuna_study.trials)
print(f'An old study has been found and loaded: {optuna_study.study_name}')
else:
print(f'A new Study will be created: {study_name}')
optuna_study = optuna.create_study(study_name=study_name,
direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
n_trials = n_trials
# Optimize it
with ContiniousSavingCallback(out_folder, optuna_study) as continious_save_callback:
# study.optimize(optimize, n_trials=50, callbacks=[opt_utils.NeptuneCallback(log_study=True, log_charts=True)])
optuna_study.optimize(Objective(model_name, data_name, max_epochs, loss), n_trials=n_trials,
show_progress_bar=True,
callbacks=[continious_save_callback], catch=(Exception, ))
print("Number of finished trials: {}".format(len(optuna_study.trials)))
print("Best trial:")
trial = optuna_study.best_trial
print(" Value: {}".format(trial.value))
print(" Params: ")
for key, value in trial.params.items():
print(" {}: {}".format(key, value))
exit()