Python 3.8 branch merged

some small template fixes
This commit is contained in:
Steffen Illium 2020-05-17 22:11:21 +02:00
parent fc93f71608
commit e423d6fe31
2 changed files with 6 additions and 12 deletions

View File

@ -44,11 +44,6 @@ def run_lightning_loop(config_obj):
# Init
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
model.init_weights(torch.nn.init.xavier_normal_)
if model.name == 'CNNRouteGeneratorDiscriminated':
# ToDo: Make this dependent on the used seed
path = logger.outpath / 'classifier_cnn' / 'version_0'
disc_model = SavedLightningModels.load_checkpoint(path).restore()
model.set_discriminator(disc_model)
# Trainer
# =============================================================================
@ -70,8 +65,8 @@ def run_lightning_loop(config_obj):
trainer.fit(model)
# Save the last state & all parameters
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
model.save_to_disk(logger.log_dir)
trainer.save_checkpoint(config_obj.exp_path.log_dir / 'weights.ckpt')
model.save_to_disk(config_obj.exp_path)
# Evaluate It
if config_obj.main.eval:

View File

@ -1,6 +1,6 @@
import warnings
from utils.config import Config
from _templates.new_project.utils.project_config import Config
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@ -15,10 +15,9 @@ if __name__ == '__main__':
# Model Settings
config = Config().read_namespace(args)
# bias, activation, model, norm, max_epochs, filters
cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, model_activation='leaky_relu',
model_type='classifier_cnn', model_filters=[16, 32, 64], data_batchsize=512)
# bias, activation, model, norm, max_epochs, sr, feature_mixed_dim, filters
# bias, activation, model, norm, max_epochs
cnn_classifier = dict(train_epochs=10, model_use_bias=True, model_use_norm=True, data_batchsize=512)
# bias, activation, model, norm, max_epochs
for arg_dict in [cnn_classifier]:
for seed in range(5):