Python 3.8 branch merged
some small template fixes
This commit is contained in:
@ -44,11 +44,6 @@ def run_lightning_loop(config_obj):
|
||||
# Init
|
||||
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
|
||||
model.init_weights(torch.nn.init.xavier_normal_)
|
||||
if model.name == 'CNNRouteGeneratorDiscriminated':
|
||||
# ToDo: Make this dependent on the used seed
|
||||
path = logger.outpath / 'classifier_cnn' / 'version_0'
|
||||
disc_model = SavedLightningModels.load_checkpoint(path).restore()
|
||||
model.set_discriminator(disc_model)
|
||||
|
||||
# Trainer
|
||||
# =============================================================================
|
||||
@ -70,8 +65,8 @@ def run_lightning_loop(config_obj):
|
||||
trainer.fit(model)
|
||||
|
||||
# Save the last state & all parameters
|
||||
trainer.save_checkpoint(logger.log_dir / 'weights.ckpt')
|
||||
model.save_to_disk(logger.log_dir)
|
||||
trainer.save_checkpoint(config_obj.exp_path.log_dir / 'weights.ckpt')
|
||||
model.save_to_disk(config_obj.exp_path)
|
||||
|
||||
# Evaluate It
|
||||
if config_obj.main.eval:
|
||||
|
Reference in New Issue
Block a user