Final Train Runs

This commit is contained in:
Steffen Illium
2021-03-18 07:45:07 +01:00
parent ad254dae92
commit fecf4923c2
14 changed files with 672 additions and 362 deletions

View File

@ -10,23 +10,30 @@ import itertools
if __name__ == '__main__':
# Set new values
hparams_dict = dict(model_name=['VisualTransformer'],
max_epochs=[150],
hparams_dict = dict(seed=range(10),
model_name=['VisualTransformer'],
batch_size=[50],
random_apply_chance=[0.5],
loudness_ratio=[0],
shift_ratio=[0.3],
noise_ratio=[0.3],
mask_ratio=[0.3],
lr=[0.001],
dropout=[0.2],
lat_dim=[32, 64],
patch_size=[8, 12],
attn_depth=[12],
heads=[6],
embedding_size=[16, 32],
max_epochs=[250],
random_apply_chance=[0.3], # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
loudness_ratio=[0], # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
shift_ratio=[0.3], # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
noise_ratio=[0.3], # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
mask_ratio=[0.3], # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
lr=[5e-3], # trial.suggest_uniform('lr', 1e-3, 3e-3),
dropout=[0.2], # trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
lat_dim=[32], # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
mlp_dim=[16], # 2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
head_dim=[6], # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
patch_size=[12], # trial.suggest_int('patch_size', 6, 12, step=3),
attn_depth=[10], # trial.suggest_int('attn_depth', 2, 14, step=4),
heads=[6], # trial.suggest_int('heads', 2, 16, step=2),
scheduler=['CosineAnnealingWarmRestarts'], # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
lr_scheduler_parameter=[25], # [0.98],
embedding_size=[30], # trial.suggest_int('embedding_size', 12, 64, step=12),
loss=['ce_loss'],
sampler=['WeightedRandomSampler']
sampler=['WeightedRandomSampler'],
# rial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
weight_decay=[0], # trial.suggest_loguniform('weight_decay', 1e-20, 1e-1),
)
keys, values = zip(*hparams_dict.items())