Mnist Test

Steffen Illium 2020-03-15 21:35:18 +01:00
parent 2305c8e54a
commit defa232bf2
2 changed files with 19 additions and 27 deletions


@@ -40,7 +40,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
# kld_loss /= reduce(mul, self.in_shape)
-# kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
+kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
loss = kld_loss + element_wise_loss
else:
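
The change above enables the dataset_length / batch_size weighting on the summed KL term (the per-element resizing stays commented out). For reference, a minimal sketch of the standard closed-form VAE KL term with that weighting; the function and argument names are illustrative, not taken from the repository:

    import torch

    def weighted_kld(mu, logvar, dataset_length, batch_size):
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dims,
        # matching the kld_loss line in the hunk above.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Scale the per-batch KL up to dataset scale, as enabled by this commit.
        return kld * (dataset_length / batch_size)

The total loss is then this term plus the element-wise reconstruction loss, as in loss = kld_loss + element_wise_loss above.
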
@@ -131,7 +131,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
###################################################
#
# Utils
-self.activation = nn.ReLU()
+self.activation = nn.LeakyReLU()
self.sigmoid = nn.Sigmoid()
#
@@ -149,12 +149,8 @@ class CNNRouteGeneratorModel(LightningBaseModule):
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
-self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-conv_filters=self.hparams.model_param.filters[1],
-use_norm=self.hparams.model_param.use_norm,
-use_bias=self.hparams.model_param.use_bias)
-self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
+self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
@@ -162,12 +158,13 @@ class CNNRouteGeneratorModel(LightningBaseModule):
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
-self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
+self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
-last_conv_shape = self.enc_conv_2b.shape
+last_conv_shape = self.enc_conv_3a.shape
self.enc_flat = Flatten(last_conv_shape)
self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
@@ -196,14 +193,18 @@ class CNNRouteGeneratorModel(LightningBaseModule):
self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
-conv_padding=1, conv_kernel=9, conv_stride=1,
+conv_padding=0, conv_kernel=7, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
-conv_padding=1, conv_kernel=7, conv_stride=1,
+conv_padding=1, conv_kernel=5, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
-self.gen_deconv_out = DeConvModule(self.gen_deconv_2a.shape, self.out_channels, activation=None,
+self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
+conv_padding=0, conv_kernel=3, conv_stride=1,
+use_norm=self.hparams.model_param.use_norm)
+self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
conv_padding=0, conv_kernel=3, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
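
The generator hunk above narrows the first two deconv kernels from 9 and 7 to 7 and 5 and adds a gen_deconv_3a stage before the output deconv, while the encoder hunks add an enc_conv_3a with kernel 7 and padding 0. For a quick sanity check of the resulting spatial sizes, the standard stride-1, dilation-1 output-size arithmetic is (helper names are illustrative):

    def conv_out(size, kernel, stride=1, padding=0):
        # Output size of a standard Conv2d layer (dilation 1).
        return (size + 2 * padding - kernel) // stride + 1

    def deconv_out(size, kernel, stride=1, padding=0):
        # Output size of a standard ConvTranspose2d layer (dilation 1, output_padding 0).
        return (size - 1) * stride - 2 * padding + kernel

    # e.g. on a 28x28 MNIST-sized input: conv_out(28, 7) == 22 and deconv_out(22, 7) == 28,
    # so a kernel-7 / padding-0 transposed conv undoes a kernel-7 / padding-0 conv.
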
@@ -229,23 +230,15 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def encode(self, batch_x):
combined_tensor = self.enc_conv_0(batch_x)
combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_1a(combined_tensor)
-combined_tensor = self.enc_conv_1b(combined_tensor)
combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_2a(combined_tensor)
-combined_tensor = self.enc_conv_2b(combined_tensor)
-# combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
-# combined_tensor = self.enc_conv_3a(combined_tensor)
-# combined_tensor = self.enc_conv_3b(combined_tensor)
+combined_tensor = self.enc_conv_3a(combined_tensor)
combined_tensor = self.enc_flat(combined_tensor)
combined_tensor = self.enc_lin_1(combined_tensor)
combined_tensor = self.enc_norm(combined_tensor)
combined_tensor = self.activation(combined_tensor)
combined_tensor = self.enc_lin_2(combined_tensor)
combined_tensor = self.enc_norm(combined_tensor)
combined_tensor = self.activation(combined_tensor)
#
@@ -265,15 +258,14 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def generate(self, z):
alt_tensor = self.gen_lin_1(z)
alt_tensor = self.activation(alt_tensor)
# alt_tensor = self.gen_lin_2(alt_tensor)
# alt_tensor = self.activation(alt_tensor)
alt_tensor = self.reshape_to_last_conv(alt_tensor)
alt_tensor = self.gen_deconv_1a(alt_tensor)
alt_tensor = self.gen_deconv_2a(alt_tensor)
-# alt_tensor = self.gen_deconv_3a(alt_tensor)
-# alt_tensor = self.gen_deconv_3b(alt_tensor)
+alt_tensor = self.gen_deconv_3a(alt_tensor)
alt_tensor = self.gen_deconv_out(alt_tensor)
# alt_tensor = self.activation(alt_tensor)
# alt_tensor = self.sigmoid(alt_tensor)
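
With gen_deconv_3a now active here and enc_conv_3a active in encode(), both paths use three conv/deconv stages. For orientation only, a rough sketch of how an encode/generate pair like this is commonly wired together in a VAE forward pass; the reparameterisation step is not part of this diff, and the assumption that encode() ultimately yields mu and logvar is ours:

    def forward(self, batch_x):
        mu, logvar = self.encode(batch_x)     # assumed return signature, not shown in this diff
        std = torch.exp(0.5 * logvar)         # reparameterisation trick (requires torch)
        z = mu + std * torch.randn_like(std)
        return self.generate(z), mu, logvar
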


@@ -39,7 +39,7 @@ main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes',
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_mode", type=str, default='ae_no_label_in_map', help="")
main_arg_parser.add_argument("--data_mode", type=str, default='vae_no_label_in_map', help="")
# Transformations
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
@@ -55,7 +55,7 @@ main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0
# Model
main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 32]", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")