train running dataset fixed

This commit is contained in:
steffen
2020-03-05 20:50:07 +01:00
parent 1f25bf599b
commit 05033bed75
11 changed files with 41 additions and 49 deletions

View File

@ -8,7 +8,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
# Imports
# =============================================================================
from main import run_training, args
from main import run_lightning_loop, args
if __name__ == '__main__':
@ -16,17 +16,14 @@ if __name__ == '__main__':
# Model Settings
config = Config().read_namespace(args)
# use_bias, activation, model, use_norm, max_epochs, filters
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
cnn_classifier = dict(train_epochs=100, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
for arg_dict in [cnn_classifier]:
for seed in range(5):
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
model_use_bias=use_bias, model_use_norm=use_norm,
model_activation=activation, model_type=model,
model_filters=filters,
data_batch_size=512)
arg_dict.update(main_seed=seed)
config = config.update(arg_dict)
run_training(config)
run_lightning_loop(config)