From bb8151d9ba549dde976054f8d7e8a06ba5307f05 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Wed, 11 Mar 2020 22:15:40 +0100 Subject: [PATCH] restructured --- lib/models/generators/cnn.py | 2 +- main.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py index 6167b6e..6e2ad6c 100644 --- a/lib/models/generators/cnn.py +++ b/lib/models/generators/cnn.py @@ -34,7 +34,7 @@ class CNNRouteGeneratorModel(LightningBaseModule): kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) # Dimensional Resizing TODO: Does This make sense? Sanity Check it! # kld_loss /= reduce(mul, self.in_shape) - kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size + kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100 loss = (kld_loss + element_wise_loss) / 2 return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss)) diff --git a/main.py b/main.py index 7e4e522..f5b5a23 100644 --- a/main.py +++ b/main.py @@ -48,6 +48,7 @@ main_arg_parser.add_argument("--train_version", type=strtobool, required=False, main_arg_parser.add_argument("--train_epochs", type=int, default=20, help="") main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="") main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="") +main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="") # Model main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="") @@ -109,7 +110,7 @@ def run_lightning_loop(config_obj): weights_save_path=logger.log_dir, gpus=[0] if torch.cuda.is_available() else None, check_val_every_n_epoch=1, - num_sanity_val_steps=0, + num_sanity_val_steps=config_obj.train.num_sanity_val_steps, # row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting # log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting checkpoint_callback=checkpoint_callback,