restructured

This commit is contained in:
Steffen Illium
2020-03-11 22:15:40 +01:00
parent 2a6100296f
commit bb8151d9ba
2 changed files with 3 additions and 2 deletions

View File

@ -34,7 +34,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
# kld_loss /= reduce(mul, self.in_shape)
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
loss = (kld_loss + element_wise_loss) / 2
return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))

View File

@ -48,6 +48,7 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False,
main_arg_parser.add_argument("--train_epochs", type=int, default=20, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Model
main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
@ -109,7 +110,7 @@ def run_lightning_loop(config_obj):
weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None,
check_val_every_n_epoch=1,
num_sanity_val_steps=0,
num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback,