# Source viewer metadata: 56 lines, 3.0 KiB, Python
import itertools
from argparse import Namespace
from typing import Any, Dict, Iterable, List, Mapping

from tqdm import tqdm

from main import run_lightning_loop
from ml_lib.utils.config import parse_comandline_args_add_defaults

def expand_grid(grid: Mapping[str, Iterable[Any]]) -> List[Dict[str, Any]]:
    """Expand a hyperparameter grid into one override dict per combination.

    Args:
        grid: Maps each hyperparameter name to an iterable of candidate values
            (lists, ``range`` objects, ...).

    Returns:
        The Cartesian product of all candidate iterables, as one
        ``{name: value}`` dict per combination with keys in *grid* order.
        The rightmost key varies fastest (``itertools.product`` order).
        An empty *grid* yields ``[{}]`` (a single run with no overrides);
        any empty candidate iterable yields ``[]``.
    """
    names = list(grid)
    # Unlike ``zip(*grid.items())``, this does not raise on an empty grid.
    return [dict(zip(names, combo))
            for combo in itertools.product(*(grid[name] for name in names))]


def main() -> None:
    """Run ``run_lightning_loop`` once for every combination in the grid below."""
    # Each key maps to the candidate values to sweep; most are pinned to a
    # single value. The trailing ``trial.suggest_*`` comments record the
    # search ranges these values were presumably tuned over.
    hparams_grid = dict(
        seed=range(10),
        model_name=['VisualTransformer'],
        batch_size=[50],
        max_epochs=[200],
        random_apply_chance=[0.3],  # trial.suggest_float('random_apply_chance', 0.1, 0.5, step=0.1),
        loudness_ratio=[0],  # trial.suggest_float('loudness_ratio', 0.0, 0.5, step=0.1),
        shift_ratio=[0.3],  # trial.suggest_float('shift_ratio', 0.0, 0.5, step=0.1),
        noise_ratio=[0.3],  # trial.suggest_float('noise_ratio', 0.0, 0.5, step=0.1),
        mask_ratio=[0.3],  # trial.suggest_float('mask_ratio', 0.0, 0.5, step=0.1),
        lr=[2e-3],  # trial.suggest_uniform('lr', 1e-3, 3e-3),
        dropout=[0.2],  # trial.suggest_float('dropout', 0.0, 0.3, step=0.05),
        lat_dim=[32],  # 2 ** trial.suggest_int('lat_dim', 1, 5, step=1),
        mlp_dim=[16],  # 2 ** trial.suggest_int('mlp_dim', 1, 5, step=1),
        head_dim=[6],  # 2 ** trial.suggest_int('head_dim', 1, 5, step=1),
        patch_size=[12],  # trial.suggest_int('patch_size', 6, 12, step=3),
        attn_depth=[10],  # trial.suggest_int('attn_depth', 2, 14, step=4),
        heads=[6],  # trial.suggest_int('heads', 2, 16, step=2),
        scheduler=['CosineAnnealingWarmRestarts'],  # trial.suggest_categorical('scheduler', [None, 'LambdaLR']),
        lr_scheduler_parameter=[5],  # [0.98],
        embedding_size=[30],  # trial.suggest_int('embedding_size', 12, 64, step=12),
        loss=['ce_loss'],
        sampler=['WeightedRandomSampler'],  # trial.suggest_categorical('sampler', [None, 'WeightedRandomSampler']),
        weight_decay=[0],  # trial.suggest_loguniform('weight_decay', 1e-20, 1e-1),
    )

    runs = expand_grid(hparams_grid)
    # tqdm takes the total from len(runs), so no explicit ``total=`` is needed.
    for overrides in tqdm(runs):
        # Parse command-line args, read the config and get the data/model classes.
        cmd_args, found_data_class, found_model_class = parse_comandline_args_add_defaults(
            '_parameters.ini', overrides=overrides)

        # Re-apply the grid values on top of the parsed args so they take
        # precedence. NOTE(review): this may be redundant if the parser
        # already honours ``overrides`` — confirm.
        hparams = dict(**cmd_args)
        hparams.update(overrides)
        hparams = Namespace(**hparams)

        # RUN
        # ---------------------------------------
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f'Running Loop, parameters are: {overrides}')
        run_lightning_loop(hparams, found_data_class, found_model_class)
        tqdm.write(f'Done, parameters were: {overrides}')


if __name__ == '__main__':
    main()
