eval written

This commit is contained in:
Si11ium
2020-03-05 16:58:23 +01:00
parent 8d06c179c9
commit 1f25bf599b
12 changed files with 127 additions and 74 deletions

View File

@ -7,29 +7,26 @@ warnings.filterwarnings('ignore', category=UserWarning)
# Imports
# =============================================================================
from pathlib import Path
import os
from main import run_training, args
if __name__ == '__main__':
# Model Settings
warnings.filterwarnings('ignore', category=FutureWarning)
config = Config().read_namespace(args)
# use_bias, activation, model, use_norm, max_epochs, filters
cnn_classifier = [True, 'leaky_relu', 'classifier_cnn', False, 2, [16, 32, 64]]
# use_bias, activation, model, use_norm, max_epochs, sr, lat_dim, filters
# Data Settings
data_shortcodes = ['mid', 'mid_5']
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
for seed in range(5):
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
model_use_bias=use_bias, model_use_norm=use_norm,
model_activation=activation, model_type=model,
model_filters=filters,
data_batch_size=512)
# Iteration over
for data_shortcode in data_shortcodes:
for use_bias, activation, model, use_norm, max_epochs, filters in [cnn_classifier]:
for seed in range(5):
arg_dict = dict(main_seed=seed, train_max_epochs=max_epochs,
model_use_bias=use_bias, model_use_norm=use_norm,
model_activation=activation, model_type=model,
model_filters=filters,
data_batch_size=512)
config = config.update(arg_dict)
os.system(f'/home/steffen/envs/traj_gen/bin/python main.py {arg_dict}')
run_training(config)