Python 3.8 branch merged

some small template fixes
This commit is contained in:
Steffen Illium
2020-05-17 22:11:21 +02:00
parent fc93f71608
commit e423d6fe31
2 changed files with 6 additions and 12 deletions

View File

@ -44,11 +44,6 @@ def run_lightning_loop(config_obj):
# Init
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
model.init_weights(torch.nn.init.xavier_normal_)
if model.name == 'CNNRouteGeneratorDiscriminated':
# ToDo: Make this dependent on the used seed
path = logger.outpath / 'classifier_cnn' / 'version_0'
disc_model = SavedLightningModels.load_checkpoint(path).restore()
model.set_discriminator(disc_model)
# Trainer
# =============================================================================
@ -70,8 +65,8 @@ def run_lightning_loop(config_obj):
trainer.fit(model)
# Save the last state & all parameters
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
model.save_to_disk(logger.log_dir)
trainer.save_checkpoint(config_obj.exp_path.log_dir / 'weights.ckpt')
model.save_to_disk(config_obj.exp_path)
# Evaluate It
if config_obj.main.eval: