paper preparations and notebooks, optuna callbacks, new plots
This commit is contained in:
parent 3955f5ccd0
commit 64ef0386d5
136  datasets/urban_8k.py  (new file)
@@ -0,0 +1,136 @@
from pathlib import Path

import multiprocessing as mp
import torch
from torch.utils.data import ConcatDataset
from torchvision.transforms import RandomApply, ToTensor, Compose

from ml_lib.audio_toolset.audio_io import NormalizeLocal
from ml_lib.audio_toolset.audio_to_mel_dataset import LibrosaAudioToMelDataset
from ml_lib.audio_toolset.mel_augmentation import NoiseInjection, LoudnessManipulator, ShiftTime, MaskAug
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_devel, DATA_OPTION_train, DATA_OPTION_test

data_options = [DATA_OPTION_test, DATA_OPTION_train, DATA_OPTION_devel]

try:
    torch.multiprocessing.set_sharing_strategy('file_system')
except AttributeError:
    pass


class UrbanSound8K(_BaseDataModule):

    _class_names = ['air_conditioner', 'car_horn', 'children_playing', 'dog_bark', 'drilling',
                    'engine_idling', 'gun_shot', 'jackhammer', 'siren', 'street_music']

    @property
    def class_names(self):
        return {key: val for val, key in enumerate(self._class_names)}

    @property
    def n_classes(self):
        return len(self.class_names)

    @property
    def sample_shape(self):
        return self[0][1].shape

    # Data Structures
    @property
    def mel_folder(self):
        return self.data_root / 'mel'

    @property
    def wav_folder(self):
        return self.data_root / self._wav_folder_name

    @property
    def _container_ext(self):
        return '.mel'

    def __init__(self, data_root, num_worker, sr, n_mels, n_fft, hop_length, sampler=None,
                 random_apply_chance=0.5, target_mel_length_in_seconds=1, fold=1, setting=DATA_OPTION_train,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3, rebuild=False):
        assert isinstance(setting, str), f'Setting has to be a string, but was: {type(setting)}.'
        assert fold in range(1, 11) if isinstance(fold, int) else all([f in range(1, 11) for f in fold])
        super(UrbanSound8K, self).__init__()
        self.num_worker = num_worker or 1
        self.sampler = sampler

        # Dataset Parameters
        self.fold = fold if isinstance(fold, list) else [fold]
        self.data_root = Path(data_root) / self.__class__.__name__
        self._wav_folder_name = 'audio'

        self.mel_length_in_seconds = target_mel_length_in_seconds

        # Mel Transforms - will be pushed with all other parameters by self.__dict__ to subdataset-class
        self.mel_kwargs = dict(sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)

        target_frames = self.mel_length_in_seconds * self.mel_kwargs['sr']
        self.sample_segment_length = target_frames // self.mel_kwargs['hop_length'] + 1

        # Utility
        self.utility_transforms = Compose([
            NormalizeLocal(),
            ToTensor()
        ])

        # Data Augmentations
        self.random_apply_chance = random_apply_chance
        self.mel_augmentations = Compose([
            RandomApply([NoiseInjection(noise_ratio)], p=random_apply_chance),
            RandomApply([LoudnessManipulator(loudness_ratio)], p=random_apply_chance),
            RandomApply([ShiftTime(shift_ratio)], p=random_apply_chance),
            RandomApply([MaskAug(mask_ratio)], p=random_apply_chance),
            self.utility_transforms])

        # Find all raw files and turn generator to persistent list:
        self._wav_files = list(self.wav_folder.rglob('*.wav'))

        # Build the Dataset
        self._dataset = self._build_dataset(rebuild)

    def _build_subdataset(self, row, build=False):
        slice_file_name, fs_id, start, end, salience, fold, class_id, class_name = row.strip().split(',')
        fold, class_id = (int(x) for x in (fold, class_id))
        if fold in self.fold:
            audio_file_path = self.wav_folder / f'fold{fold}' / slice_file_name
            kwargs = dict(sample_segment_len=self.sample_segment_length,
                          sample_hop_len=self.sample_segment_length // 2)
            mel_dataset = LibrosaAudioToMelDataset(audio_file_path, class_id, mel_kwargs=self.mel_kwargs, **kwargs)
            if build:
                assert mel_dataset.build_mel()
            return mel_dataset, class_id, slice_file_name
        else:
            return None

    def _build_dataset(self, build=False):
        dataset = list()
        with open(Path(self.data_root) / 'metadata' / 'UrbanSound8K.csv', mode='r') as f:
            # Exclude the header
            _ = next(f)
            all_rows = list(f)
            chunksize = len(all_rows) // max(self.num_worker, 1)
        with mp.Pool(processes=self.num_worker) as pool:
            from itertools import repeat
            results = pool.starmap_async(self._build_subdataset,
                                         zip(all_rows,
                                             repeat(build, len(all_rows))
                                             ),
                                         chunksize=chunksize)
            for sub_dataset in results.get():
                if sub_dataset is not None:
                    if sub_dataset[0] is not None:
                        dataset.append(sub_dataset[0])
        return ConcatDataset(dataset)

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, item):
        file_path, transformed_samples, label = self._dataset[item]
        label = torch.as_tensor(label, dtype=torch.int)
        return file_path, transformed_samples, label
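For orientation, a minimal usage sketch of the new dataset class; the mel settings and the 'data' root below are illustrative assumptions, not values taken from this commit:

    from datasets.urban_8k import UrbanSound8K

    # Hypothetical parameters; UrbanSound8K appends its class name to data_root
    # and expects metadata/UrbanSound8K.csv plus the audio/fold* directories there.
    dataset = UrbanSound8K(data_root='data', num_worker=4,
                           sr=22050, n_mels=64, n_fft=512, hop_length=256,
                           target_mel_length_in_seconds=1, fold=[1, 2])
    print(dataset.n_classes)                    # 10 urban sound classes
    file_path, mel_segment, label = dataset[0]  # __getitem__ yields (path, mel tensor, int label)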
62  datasets/urban_8k_datamodule.py  (new file)
@@ -0,0 +1,62 @@
from copy import deepcopy
from pathlib import Path

from torch.utils.data import DataLoader

from datasets.urban_8k import UrbanSound8K
from ml_lib.utils._basedatamodule import _BaseDataModule, DATA_OPTION_train, DATA_OPTION_devel
from ml_lib.utils.tools import add_argparse_args


class Urban8KLibrosaDatamodule(_BaseDataModule):

    def __init__(self, batch_size, num_worker, data_root, sr, n_mels, n_fft, hop_length,
                 sampler=None, val_fold=9,
                 random_apply_chance=0.5, target_mel_length_in_seconds=1,
                 loudness_ratio=0.3, shift_ratio=0.3, noise_ratio=0.3, mask_ratio=0.3, **kwargs):
        super(Urban8KLibrosaDatamodule, self).__init__()
        self.batch_size = batch_size
        self.num_worker = num_worker

        self.val_fold = val_fold

        self.kwargs = kwargs
        self.kwargs.update(data_root=data_root, num_worker=num_worker,
                           sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, sampler=sampler,
                           random_apply_chance=random_apply_chance,
                           target_mel_length_in_seconds=target_mel_length_in_seconds,
                           loudness_ratio=loudness_ratio, shift_ratio=shift_ratio, noise_ratio=noise_ratio,
                           mask_ratio=mask_ratio)

    @classmethod
    def add_argparse_args(cls, parent_parser):
        return add_argparse_args(UrbanSound8K, parent_parser)

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        val_fold = kwargs.get('val_fold', 10)
        kwargs.update(val_fold=val_fold)
        return super(Urban8KLibrosaDatamodule, cls).from_argparse_args(args, **kwargs)

    def train_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_train], shuffle=True,
                          batch_size=self.batch_size, pin_memory=True,
                          num_workers=self.num_worker)

    # Validation Dataloader
    def val_dataloader(self):
        return DataLoader(dataset=self.datasets[DATA_OPTION_devel], shuffle=False, pin_memory=True,
                          batch_size=self.batch_size, num_workers=self.num_worker)

    def prepare_data(self, stage=None):
        # Train Dataset
        self.datasets[DATA_OPTION_train] = UrbanSound8K(fold=[x for x in range(1, 11) if x != self.val_fold],
                                                        **self.kwargs)
        # Devel Dataset
        self.datasets[DATA_OPTION_devel] = UrbanSound8K(fold=self.val_fold, **self.kwargs)

    def manual_setup(self):
        UrbanSound8K(fold=list(range(1, 11)), rebuild=True, **self.kwargs)
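A sketch of the intended datamodule flow (argument values are assumptions; in a real run they come from argparse and _parameters.ini): manual_setup() builds the mel containers once for all ten folds, prepare_data() then splits them into train (all folds except val_fold) and devel (val_fold only):

    from datasets.urban_8k_datamodule import Urban8KLibrosaDatamodule

    # Illustrative values only.
    dm = Urban8KLibrosaDatamodule(batch_size=20, num_worker=4, data_root='data',
                                  sr=22050, n_mels=64, n_fft=512, hop_length=256,
                                  val_fold=9)
    dm.manual_setup()     # one-off: rebuild the .mel files for folds 1-10
    dm.prepare_data()     # folds != val_fold -> train, val_fold -> devel
    train_loader = dm.train_dataloader()
    val_loader = dm.val_dataloader()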
14  main.py
@@ -8,7 +8,7 @@ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 
 from ml_lib.utils.callbacks import BestScoresCallback
 from ml_lib.utils.config import parse_comandline_args_add_defaults
-from ml_lib.utils.loggers import Logger
+from ml_lib.utils.loggers import LightningLogger
 
 import variables as v
 from ml_lib.utils.tools import fix_all_random_seeds
@@ -21,7 +21,7 @@ def run_lightning_loop(h_params :Namespace, data_class, model_class, seed=69, ad
 
     fix_all_random_seeds(seed)
 
-    with Logger.from_argparse_args(h_params) as logger:
+    with LightningLogger.from_argparse_args(h_params) as logger:
         # Callbacks
         # =============================================================================
         # Checkpoint Saving
@@ -53,12 +53,18 @@ def run_lightning_loop(h_params :Namespace, data_class, model_class, seed=69, ad
     # START
     # =============================================================================
-    # Let Datamodule pull what it wants
+    # Let Datamodule pull what it wants and init
     datamodule = data_class.from_argparse_args(h_params)
 
     # Final h_params Setup:
     h_params = vars(h_params)
-    h_params.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes)
+    try:
+        h_params.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes)
+    except KeyError:
+        datamodule.manual_setup()
+        datamodule.prepare_data()
+        h_params.update(in_shape=datamodule.shape, n_classes=datamodule.n_classes)
 
     h_params = Namespace(**h_params)
 
     # Let Trainer pull what it wants and add callbacks
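The new try/except assumes that the datamodule's shape and n_classes properties raise KeyError as long as its datasets dict is still empty, and that one manual_setup() plus prepare_data() pass repairs that. A reduced sketch of that contract (the _Sketch* names are hypothetical, not code from this repo):

    # Hypothetical stand-in illustrating the fallback above.
    class _SketchDatamodule:
        def __init__(self):
            self.datasets = dict()         # empty on a fresh checkout

        @property
        def n_classes(self):
            return self.datasets['train']  # raises KeyError while empty

        def manual_setup(self):
            pass                           # would build the mel files here

        def prepare_data(self):
            self.datasets['train'] = 10    # data exists now

    dm = _SketchDatamodule()
    try:
        n_classes = dm.n_classes
    except KeyError:                       # first run on a clean machine
        dm.manual_setup()
        dm.prepare_data()
        n_classes = dm.n_classes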
50  multi_run.py
@@ -10,20 +10,20 @@ import itertools
 if __name__ == '__main__':
 
     # Set new values
-    hparams_dict = dict(seed=range(1, 6),
+    hparams_dict = dict(seed=range(13, 20),
                         # BandwiseConvClassifier, CNNBaseline, VisualTransformer, VerticalVisualTransformer
-                        model_name=['VisualTransformer'],
+                        model_name=['BandwiseConvClassifier'],
                         # CCSLibrosaDatamodule, PrimatesLibrosaDatamodule,
                         data_name=['PrimatesLibrosaDatamodule'],
-                        batch_size=[30],
-                        max_epochs=[150],
-                        target_mel_length_in_seconds=[0.5],
-                        outpath=['head_exp'],
+                        batch_size=[20],
+                        max_epochs=[200],
+                        target_mel_length_in_seconds=[0.4],
+                        outpath=['optuna_found_param_run'],
 
-                        dropout=[0.2],  # trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
+                        dropout=[0.0],  # trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
 
-                        scheduler=['LambdaLR'],  # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
-                        lr_scheduler_parameter=[0.95],  # [0.95],
+                        scheduler=[None],  # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
+                        lr_scheduler_parameter=[None],  # [0.95],
 
                         loss=['ce_loss'],
                         sampler=['WeightedRandomSampler'],
@@ -32,29 +32,29 @@ if __name__ == '__main__':
                         )
 
     # Data Aug Parameters
-    hparams_dict.update(random_apply_chance=[0.3],  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
-                        loudness_ratio=[0],  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
-                        shift_ratio=[0.2],  # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
+    hparams_dict.update(random_apply_chance=[0.1],  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
+                        loudness_ratio=[0.2],  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
+                        shift_ratio=[0.3],  # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
                         noise_ratio=[0.4],  # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
-                        mask_ratio=[0.2],  # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
+                        mask_ratio=[0.3],  # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
                         )
-    if False:
+    if True:
         # CNN Parameters:
-        hparams_dict.update(filters=[[16, 32, 64, 32]],
-                            lr=[1e-3],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
+        hparams_dict.update(filters=[[6, 6, 6]],
+                            lr=[0.0003414550170649836],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
                             variable_length=[False],  # THIS does not work
-                            lat_dim=[64],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
+                            lat_dim=[2 ** 3],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
                             )
     else:
         # Transformer Parameters:
-        hparams_dict.update(lr=[1e-3],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
-                            lat_dim=[32],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
-                            mlp_dim=[16],
-                            head_dim=[6],  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
-                            patch_size=[12],  # trial.suggest_int('patch_size', 6, 12, step=3),
-                            attn_depth=[14],  # trial.suggest_int('attn_depth', 2, 14, step=4),
-                            heads=[2, 4, 6, 8, 10],  # trial.suggest_int('heads', 2, 16, step=2),
-                            embedding_size=[30],  # trial.suggest_int('embedding_size', 12, 64, step=12),
+        hparams_dict.update(lr=[0.0008292481039683588],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
+                            lat_dim=[2 ** 4],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
+                            mlp_dim=[2 ** 4],
+                            head_dim=[2 ** 4],  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
+                            patch_size=[6],  # trial.suggest_int('patch_size', 6, 12, step=3),
+                            attn_depth=[10],  # trial.suggest_int('attn_depth', 2, 14, step=4),
+                            heads=[16],  # trial.suggest_int('heads', 2, 16, step=2),
+                            embedding_size=[60],  # trial.suggest_int('embedding_size', 12, 64, step=12),
                             variable_length=[False],  # THIS does not work
                             )
 
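multi_run.py treats every value list in hparams_dict as one axis of a grid (hence the itertools import); a sketch of that expansion, assuming one training run is launched per element of the Cartesian product:

    import itertools

    hparams_dict = dict(seed=range(13, 20),
                        model_name=['BandwiseConvClassifier'],
                        batch_size=[20])
    keys, axes = zip(*hparams_dict.items())
    for combination in itertools.product(*axes):
        run_params = dict(zip(keys, combination))
        print(run_params)  # {'seed': 13, 'model_name': 'BandwiseConvClassifier', 'batch_size': 20}, ...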
File diff suppressed because one or more lines are too long
@@ -36,6 +36,7 @@
     "out_path = Path('..') / Path('output')\n",
     "_model_name = 'VisualTransformer'\n",
     "_dataset_name = 'PrimatesLibrosaDatamodule'\n",
+    "breakpoint()\n",
     "_param_name = 'heads'"
    ],
    "metadata": {
288  notebooks/best_scores.ipynb  (new file)
File diff suppressed because one or more lines are too long
159  optuna_tune.py
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import Union
 
 import optuna as optuna
+from natsort import natsorted
 from optuna.integration import PyTorchLightningPruningCallback
 
 from main import run_lightning_loop
@@ -18,7 +19,7 @@ class ContiniousSavingCallback:
 
     @property
     def tmp_study_path(self):
-        return Path(self.root) / f'TMP_{self.study.study_name}_trial{self.study.trials[-1].number}.pkl'
+        return Path(self.root) / f'TMP_{self.study.study_name}_trial_{self.study.trials[-1].number}.pkl'
 
     @property
     def final_study_path(self):
@@ -41,7 +42,7 @@ class ContiniousSavingCallback:
             pickle.dump(object, f)
 
     def save_final(self):
-        self._write_to_disk(self.study, self.final_study_path())
+        self._write_to_disk(self.study, self.final_study_path)
 
     def clean_up(self):
         temp_study_files = self.root.glob(f'TMP_{self.study.study_name}*')
@@ -59,75 +60,112 @@ class ContiniousSavingCallback:
         self.clean_up()
 
 
-def optimize(trial: optuna.Trial):
-    # Optuna configuration
-    folder = Path('study')
-    folder.mkdir(parents=False, exist_ok=True)
-    scheduler = trial.suggest_categorical('scheduler', [None, 'LambdaLR'])
-    if scheduler is not None:
-        lr_scheduler_parameter = trial.suggest_float('lr_scheduler_parameter', 0.8, 1, step=0.01)
-    else:
-        lr_scheduler_parameter = None
-
-    optuna_suggestions = dict(
-        model_name='CNNBaseline',
-        data_name='PrimatesLibrosaDatamodule',
-        batch_size=trial.suggest_int('batch_size', 5, 50, step=5),
-        max_epochs=400,
-        target_mel_length_in_seconds=trial.suggest_float('target_mel_length_in_seconds', 0.2, 1.5, step=0.1),
-        random_apply_chance=trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
-        loudness_ratio=trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
-        shift_ratio=trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
-        noise_ratio=trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
-        mask_ratio=trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
-        lr=trial.suggest_loguniform('lr', 1e-5, 1e-3),
-        dropout=trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
-        lat_dim=2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
-        scheduler=scheduler,
-        lr_scheduler_parameter=lr_scheduler_parameter,
-        loss='ce_loss',
-        sampler=trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
-        study_name=trial.study.study_name
-    )
-    if optuna_suggestions['model_name'] == 'CNNBaseline':
-        model_depth = trial.suggest_int('model_depth', 1, 6, step=1)
-        filters = list()
-        for layer_idx in range(model_depth):
-            filters.append(2 ** trial.suggest_int(f'filters_{layer_idx}', 2, 6, step=1))
-        optuna_suggestions.update(filters=filters)
-    elif optuna_suggestions['model_name'] in ['VisualTransformer', 'VerticalVisualTransformer']:
-        transformer_dict = dict(
-            mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
-            head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
-            patch_size=trial.suggest_int('patch_size', 6, 20, step=3),
-            attn_depth=trial.suggest_int('attn_depth', 2, 20, step=4),
-            heads=trial.suggest_int('heads', 2, 16, step=2),
-            embedding_size=trial.suggest_int('embedding_size', 12, 64, step=12)
-        )
-        optuna_suggestions.update(**transformer_dict)
-
-    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="PL_recall_score")
-
-    # Parse commandline args, read config and get model
-    h_params, found_data_class, found_model_class, seed = parse_comandline_args_add_defaults(
-        '_parameters.ini', overrides=optuna_suggestions)
-    h_params = Namespace(**h_params)
-    try:
-        results = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
-                                     additional_callbacks=pruning_callback, seed=seed)
-        best_score = results.best_model_score
-    except Exception as e:
-        print(e)
-        best_score = 0
-    return best_score
+class Objective(object):
+    def __init__(self, model_class_name, data_class_name, max_epochs, loss):
+        self.study_params = dict(model_name=model_class_name,
+                                 data_name=data_class_name,
+                                 max_epochs=max_epochs,
+                                 loss=loss,
+                                 )
+
+    def __call__(self, trial):
+        # Optuna configuration
+        folder = Path('study')
+        folder.mkdir(parents=False, exist_ok=True)
+
+        # Suggested Parameters:
+        scheduler = trial.suggest_categorical('scheduler', [None, 'LambdaLR'])
+        if scheduler is not None:
+            lr_scheduler_parameter = trial.suggest_float('lr_scheduler_parameter', 0.8, 1, step=0.01)
+        else:
+            lr_scheduler_parameter = None
+
+        optuna_suggestions = dict(
+            batch_size=trial.suggest_int('batch_size', 5, 50, step=5),
+            target_mel_length_in_seconds=trial.suggest_float('target_mel_length_in_seconds', 0.2, 1.5, step=0.1),
+            random_apply_chance=trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
+            loudness_ratio=trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
+            shift_ratio=trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
+            noise_ratio=trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
+            mask_ratio=trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
+            lr=trial.suggest_loguniform('lr', 1e-5, 1e-3),
+            dropout=trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
+            lat_dim=2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
+            scheduler=scheduler,
+            lr_scheduler_parameter=lr_scheduler_parameter,
+            sampler=trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
+        )
+
+        # User defined Parameters:
+        for params_name in self.study_params.keys():
+            trial.set_user_attr(params_name, self.study_params[params_name])
+        trial.set_user_attr('study_name', trial.study.study_name)
+        optuna_suggestions.update(**trial.user_attrs)
+
+        if optuna_suggestions['model_name'] in ['CNNBaseline', 'BandwiseConvClassifier']:
+            model_depth = trial.suggest_int('model_depth', 1, 6, step=1)
+            filters = list()
+            for layer_idx in range(model_depth):
+                filters.append(2 ** trial.suggest_int(f'filters_{layer_idx}', 2, 6, step=1))
+            optuna_suggestions.update(filters=filters)
+        elif optuna_suggestions['model_name'] in ['VisualTransformer', 'VerticalVisualTransformer']:
+            transformer_dict = dict(
+                mlp_dim=2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
+                head_dim=2 ** trial.suggest_int('head_dim', 1, 5, step=1),
+                patch_size=trial.suggest_int('patch_size', 6, 20, step=3),
+                attn_depth=trial.suggest_int('attn_depth', 2, 20, step=4),
+                heads=trial.suggest_int('heads', 2, 16, step=2),
+                embedding_size=trial.suggest_int('embedding_size', 12, 64, step=12)
+            )
+            optuna_suggestions.update(**transformer_dict)
+
+        pruning_callback = PyTorchLightningPruningCallback(trial, monitor="PL_recall_score")
+
+        # Parse commandline args, read config and get model
+        h_params, found_data_class, found_model_class, seed = parse_comandline_args_add_defaults(
+            '_parameters.ini', overrides=optuna_suggestions)
+        h_params = Namespace(**h_params)
+        try:
+            results = run_lightning_loop(h_params, data_class=found_data_class, model_class=found_model_class,
+                                         additional_callbacks=pruning_callback, seed=seed)
+            best_score = results.best_model_score
+        except Exception as e:
+            print(e)
+            best_score = 0
+        return best_score
 
 
 if __name__ == '__main__':
-    optuna_study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
-    with ContiniousSavingCallback('study', optuna_study) as continious_save_callback:
+    # Study Parameters
+    out_folder = Path('study')
+    model_name = 'CNNBaseline'
+    data_name = 'Urban8KLibrosaDatamodule'
+    loss = 'ce_loss'
+    max_epochs = 200
+    n_trials = 400
+    study_name = f'{model_name}_{max_epochs}_{data_name}'
+
+    # Create Study or load study:
+    try:
+        found_studys = [x for x in out_folder.iterdir() if study_name in x.name]
+    except FileNotFoundError:
+        found_studys = []
+    if found_studys:
+        latest_found_study = natsorted(found_studys, key=lambda x: x.stem[x.stem.find('_trial'):])[-1]
+        with latest_found_study.open('rb') as latest_found_study_file:
+            optuna_study = pickle.load(latest_found_study_file)
+        n_trials = n_trials - len(optuna_study.trials)
+        print(f'An old study has been found and loaded: {optuna_study.study_name}')
+    else:
+        print(f'A new Study will be created: {study_name}')
+        optuna_study = optuna.create_study(study_name=study_name,
+                                           direction='maximize', sampler=optuna.samplers.TPESampler(seed=1337))
+
+    # Optimize it
+    with ContiniousSavingCallback(out_folder, optuna_study) as continious_save_callback:
         # study.optimize(optimize, n_trials=50, callbacks=[opt_utils.NeptuneCallback(log_study=True, log_charts=True)])
-        optuna_study.optimize(optimize, n_trials=200, show_progress_bar=True, callbacks=[continious_save_callback])
+        optuna_study.optimize(Objective(model_name, data_name, max_epochs, loss), n_trials=n_trials,
+                              show_progress_bar=True,
+                              callbacks=[continious_save_callback], catch=(Exception, ))
 
     print("Number of finished trials: {}".format(len(optuna_study.trials)))
@@ -139,3 +177,4 @@ if __name__ == '__main__':
     print("  Params: ")
     for key, value in trial.params.items():
         print("    {}: {}".format(key, value))
+    exit()
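The body of ContiniousSavingCallback is mostly outside this diff; what it relies on is Optuna's callback interface, in which every entry passed via callbacks= is invoked as callback(study, trial) after each finished trial. A minimal stand-in under that assumption (PickleEveryTrial is hypothetical, not the repo's class), pickling the study after every trial so an interrupted run can be resumed the way the __main__ block above does:

    import pickle
    from pathlib import Path

    import optuna


    class PickleEveryTrial:
        # Assumed behaviour of a per-trial saving callback.
        def __init__(self, root):
            self.root = Path(root)

        def __call__(self, study, trial):
            # Same naming scheme as the fixed tmp_study_path above.
            out_file = self.root / f'TMP_{study.study_name}_trial_{trial.number}.pkl'
            with out_file.open('wb') as f:
                pickle.dump(study, f)


    study = optuna.create_study(direction='maximize')
    study.optimize(lambda t: t.suggest_float('x', 0, 1), n_trials=3,
                   callbacks=[PickleEveryTrial('.')])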