Mnist Test
This commit is contained in:
parent
2305c8e54a
commit
defa232bf2
@ -40,7 +40,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
|
||||
# kld_loss /= reduce(mul, self.in_shape)
|
||||
# kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
|
||||
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
|
||||
|
||||
loss = kld_loss + element_wise_loss
|
||||
else:
|
||||
@ -131,7 +131,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
###################################################
|
||||
#
|
||||
# Utils
|
||||
self.activation = nn.ReLU()
|
||||
self.activation = nn.LeakyReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
#
|
||||
@ -149,12 +149,8 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
conv_filters=self.hparams.model_param.filters[1],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
||||
conv_filters=self.hparams.model_param.filters[1],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
|
||||
self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
|
||||
self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
|
||||
conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
@ -162,12 +158,13 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
conv_filters=self.hparams.model_param.filters[2],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
|
||||
|
||||
self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
|
||||
conv_filters=self.hparams.model_param.filters[2],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
|
||||
last_conv_shape = self.enc_conv_2b.shape
|
||||
last_conv_shape = self.enc_conv_3a.shape
|
||||
self.enc_flat = Flatten(last_conv_shape)
|
||||
self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
|
||||
|
||||
@ -196,14 +193,18 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
|
||||
|
||||
self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
|
||||
conv_padding=1, conv_kernel=9, conv_stride=1,
|
||||
conv_padding=0, conv_kernel=7, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
|
||||
conv_padding=1, conv_kernel=7, conv_stride=1,
|
||||
conv_padding=1, conv_kernel=5, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
self.gen_deconv_out = DeConvModule(self.gen_deconv_2a.shape, self.out_channels, activation=None,
|
||||
self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
|
||||
conv_padding=0, conv_kernel=3, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
|
||||
conv_padding=0, conv_kernel=3, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
@ -229,23 +230,15 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
def encode(self, batch_x):
|
||||
combined_tensor = self.enc_conv_0(batch_x)
|
||||
combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor
|
||||
combined_tensor = self.enc_conv_1a(combined_tensor)
|
||||
combined_tensor = self.enc_conv_1b(combined_tensor)
|
||||
combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
|
||||
combined_tensor = self.enc_conv_2a(combined_tensor)
|
||||
combined_tensor = self.enc_conv_2b(combined_tensor)
|
||||
# combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
|
||||
# combined_tensor = self.enc_conv_3a(combined_tensor)
|
||||
# combined_tensor = self.enc_conv_3b(combined_tensor)
|
||||
combined_tensor = self.enc_conv_3a(combined_tensor)
|
||||
|
||||
combined_tensor = self.enc_flat(combined_tensor)
|
||||
combined_tensor = self.enc_lin_1(combined_tensor)
|
||||
combined_tensor = self.enc_norm(combined_tensor)
|
||||
combined_tensor = self.activation(combined_tensor)
|
||||
|
||||
combined_tensor = self.enc_lin_2(combined_tensor)
|
||||
combined_tensor = self.enc_norm(combined_tensor)
|
||||
combined_tensor = self.activation(combined_tensor)
|
||||
|
||||
#
|
||||
@ -265,15 +258,14 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
def generate(self, z):
|
||||
alt_tensor = self.gen_lin_1(z)
|
||||
alt_tensor = self.activation(alt_tensor)
|
||||
# alt_tensor = self.gen_lin_2(alt_tensor)
|
||||
# alt_tensor = self.activation(alt_tensor)
|
||||
|
||||
alt_tensor = self.reshape_to_last_conv(alt_tensor)
|
||||
alt_tensor = self.gen_deconv_1a(alt_tensor)
|
||||
|
||||
alt_tensor = self.gen_deconv_2a(alt_tensor)
|
||||
|
||||
# alt_tensor = self.gen_deconv_3a(alt_tensor)
|
||||
# alt_tensor = self.gen_deconv_3b(alt_tensor)
|
||||
alt_tensor = self.gen_deconv_3a(alt_tensor)
|
||||
|
||||
alt_tensor = self.gen_deconv_out(alt_tensor)
|
||||
# alt_tensor = self.activation(alt_tensor)
|
||||
# alt_tensor = self.sigmoid(alt_tensor)
|
||||
|
4
main.py
4
main.py
@ -39,7 +39,7 @@ main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes',
|
||||
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
|
||||
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
|
||||
|
||||
main_arg_parser.add_argument("--data_mode", type=str, default='ae_no_label_in_map', help="")
|
||||
main_arg_parser.add_argument("--data_mode", type=str, default='vae_no_label_in_map', help="")
|
||||
|
||||
# Transformations
|
||||
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
|
||||
@ -55,7 +55,7 @@ main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0
|
||||
# Model
|
||||
main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
|
||||
main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 32]", help="")
|
||||
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
|
||||
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
||||
main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
|
||||
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
|
||||
|
Loading…
x
Reference in New Issue
Block a user