diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py index 076f1a6..c2897e5 100644 --- a/lib/models/generators/cnn.py +++ b/lib/models/generators/cnn.py @@ -40,7 +40,7 @@ class CNNRouteGeneratorModel(LightningBaseModule): kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) # Dimensional Resizing TODO: Does This make sense? Sanity Check it! # kld_loss /= reduce(mul, self.in_shape) - # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size + kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size loss = kld_loss + element_wise_loss else: @@ -131,7 +131,7 @@ class CNNRouteGeneratorModel(LightningBaseModule): ################################################### # # Utils - self.activation = nn.ReLU() + self.activation = nn.LeakyReLU() self.sigmoid = nn.Sigmoid() # @@ -149,12 +149,8 @@ class CNNRouteGeneratorModel(LightningBaseModule): conv_filters=self.hparams.model_param.filters[1], use_norm=self.hparams.model_param.use_norm, use_bias=self.hparams.model_param.use_bias) - self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=1, conv_padding=0, - conv_filters=self.hparams.model_param.filters[1], - use_norm=self.hparams.model_param.use_norm, - use_bias=self.hparams.model_param.use_bias) - self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1, + self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1, conv_padding=2, conv_filters=self.hparams.model_param.filters[1], use_norm=self.hparams.model_param.use_norm, use_bias=self.hparams.model_param.use_bias) @@ -162,12 +158,13 @@ class CNNRouteGeneratorModel(LightningBaseModule): conv_filters=self.hparams.model_param.filters[2], use_norm=self.hparams.model_param.use_norm, use_bias=self.hparams.model_param.use_bias) - self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0, + + self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0, conv_filters=self.hparams.model_param.filters[2], use_norm=self.hparams.model_param.use_norm, use_bias=self.hparams.model_param.use_bias) - last_conv_shape = self.enc_conv_2b.shape + last_conv_shape = self.enc_conv_3a.shape self.enc_flat = Flatten(last_conv_shape) self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim) @@ -196,14 +193,18 @@ class CNNRouteGeneratorModel(LightningBaseModule): self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape) self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2], - conv_padding=1, conv_kernel=9, conv_stride=1, + conv_padding=0, conv_kernel=7, conv_stride=1, use_norm=self.hparams.model_param.use_norm) self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1], - conv_padding=1, conv_kernel=7, conv_stride=1, + conv_padding=1, conv_kernel=5, conv_stride=1, use_norm=self.hparams.model_param.use_norm) - self.gen_deconv_out = DeConvModule(self.gen_deconv_2a.shape, self.out_channels, activation=None, + self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0], + conv_padding=0, conv_kernel=3, conv_stride=1, + use_norm=self.hparams.model_param.use_norm) + + self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None, conv_padding=0, conv_kernel=3, conv_stride=1, use_norm=self.hparams.model_param.use_norm) @@ -229,23 +230,15 @@ class CNNRouteGeneratorModel(LightningBaseModule): def encode(self, batch_x): combined_tensor = self.enc_conv_0(batch_x) - combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor combined_tensor = self.enc_conv_1a(combined_tensor) - combined_tensor = self.enc_conv_1b(combined_tensor) - combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor combined_tensor = self.enc_conv_2a(combined_tensor) - combined_tensor = self.enc_conv_2b(combined_tensor) - # combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor - # combined_tensor = self.enc_conv_3a(combined_tensor) - # combined_tensor = self.enc_conv_3b(combined_tensor) + combined_tensor = self.enc_conv_3a(combined_tensor) combined_tensor = self.enc_flat(combined_tensor) combined_tensor = self.enc_lin_1(combined_tensor) - combined_tensor = self.enc_norm(combined_tensor) combined_tensor = self.activation(combined_tensor) combined_tensor = self.enc_lin_2(combined_tensor) - combined_tensor = self.enc_norm(combined_tensor) combined_tensor = self.activation(combined_tensor) # @@ -265,15 +258,14 @@ class CNNRouteGeneratorModel(LightningBaseModule): def generate(self, z): alt_tensor = self.gen_lin_1(z) alt_tensor = self.activation(alt_tensor) - # alt_tensor = self.gen_lin_2(alt_tensor) - # alt_tensor = self.activation(alt_tensor) + alt_tensor = self.reshape_to_last_conv(alt_tensor) alt_tensor = self.gen_deconv_1a(alt_tensor) alt_tensor = self.gen_deconv_2a(alt_tensor) - # alt_tensor = self.gen_deconv_3a(alt_tensor) - # alt_tensor = self.gen_deconv_3b(alt_tensor) + alt_tensor = self.gen_deconv_3a(alt_tensor) + alt_tensor = self.gen_deconv_out(alt_tensor) # alt_tensor = self.activation(alt_tensor) # alt_tensor = self.sigmoid(alt_tensor) diff --git a/main.py b/main.py index 5b20fe4..c7af73c 100644 --- a/main.py +++ b/main.py @@ -39,7 +39,7 @@ main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="") main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") -main_arg_parser.add_argument("--data_mode", type=str, default='ae_no_label_in_map', help="") +main_arg_parser.add_argument("--data_mode", type=str, default='vae_no_label_in_map', help="") # Transformations main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="") @@ -55,7 +55,7 @@ main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0 # Model main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="") main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="") -main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 32]", help="") +main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="") main_arg_parser.add_argument("--model_classes", type=int, default=2, help="") main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="") main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")