restructured
This commit is contained in:
@ -34,7 +34,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
|
||||
# kld_loss /= reduce(mul, self.in_shape)
|
||||
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
|
||||
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
|
||||
|
||||
loss = (kld_loss + element_wise_loss) / 2
|
||||
return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
|
||||
|
3
main.py
3
main.py
@ -48,6 +48,7 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False,
|
||||
main_arg_parser.add_argument("--train_epochs", type=int, default=20, help="")
|
||||
main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
|
||||
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
||||
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
|
||||
|
||||
# Model
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
|
||||
@ -109,7 +110,7 @@ def run_lightning_loop(config_obj):
|
||||
weights_save_path=logger.log_dir,
|
||||
gpus=[0] if torch.cuda.is_available() else None,
|
||||
check_val_every_n_epoch=1,
|
||||
num_sanity_val_steps=0,
|
||||
num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
|
||||
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
|
||||
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
|
Reference in New Issue
Block a user