import pickle
from argparse import Namespace
from pathlib import Path

import optuna
from optuna.integration import PyTorchLightningPruningCallback

from main import run_lightning_loop
from ml_lib.utils.config import parse_comandline_args_add_defaults

import neptunecontrib.monitoring.optuna as opt_utils

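
# Objective function: Optuna calls this once per trial and maximizes the returned score.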
def optimize(trial: optuna.Trial):
    # Optuna configuration
    folder = Path('study')
    folder.mkdir(parents=False, exist_ok=True)
    optuna_suggestions = dict(
        model_name='VisualTransformer',
        # Note: (100 - 30) is not divisible by step=32, so Optuna narrows the effective upper bound.
        batch_size=trial.suggest_int('batch_size', 30, 100, step=32),
        lr_scheduler_parameter=trial.suggest_float('lr_scheduler_parameter', 0.8, 1, step=0.01),
        max_epochs=100,
        random_apply_chance=0.1,  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
        loudness_ratio=0.1,  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
        shift_ratio=0.1,  # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
        noise_ratio=0,  # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
        mask_ratio=0.2,  # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
        lr=trial.suggest_uniform('lr', 1e-3, 3e-3),
        dropout=0.05,  # trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
        lat_dim=32,  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
        mlp_dim=16,  # 2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
        head_dim=8,  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
        patch_size=12,  # trial.suggest_int('patch_size', 6, 12, step=3),
        attn_depth=10,  # trial.suggest_int('attn_depth', 2, 14, step=4),
        heads=16,  # trial.suggest_int('heads', 2, 16, step=2),
        scheduler='LambdaLR',  # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
        embedding_size=48,  # trial.suggest_int('embedding_size', 12, 64, step=12),
        loss='ce_loss',
        sampler='WeightedRandomSampler',  # trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
        weight_decay=trial.suggest_loguniform('weight_decay', 1e-20, 1e-1),
        study_name=trial.study.study_name
    )

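    # Reports the monitored "PL_recall_score" metric to Optuna so unpromising trials can be pruned early.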
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="PL_recall_score")

    # Parse command line args, read config and get model
    cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults('_parameters.ini')

    h_params = dict(**cmd_args)
    h_params.update(optuna_suggestions)
    h_params = Namespace(**h_params)
    try:
        best_score = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
                                        additional_callbacks=pruning_callback)
    except Exception as e:
        # A failed trial is reported as score 0 instead of aborting the whole study.
        print(e)
        best_score = 0
    return best_score


if __name__ == '__main__':
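    # TPE sampler with a fixed seed keeps the suggested hyperparameters reproducible across runs.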
    study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
    # study.optimize(optimize, n_trials=50, callbacks=[opt_utils.NeptuneCallback(log_study=True, log_charts=True)])
    study.optimize(optimize, n_trials=50)

    print("Number of finished trials: {}".format(len(study.trials)))

    print("Best trial:")
    trial = study.best_trial

    print(" Value: {}".format(trial.value))

    print(" Params: ")
    for key, value in trial.params.items():
        print("  {}: {}".format(key, value))

    optuna_study_object = Path('study') / 'study.pkl'
    optuna_study_object.parent.mkdir(exist_ok=True)
    with optuna_study_object.open(mode='wb') as f:
        pickle.dump(study, f)
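
    # The saved study can be unpickled later for inspection, e.g.:
    #   with (Path('study') / 'study.pkl').open(mode='rb') as f:
    #       study = pickle.load(f)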