Mnist Test
This commit is contained in:
@ -40,7 +40,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
|
||||
# kld_loss /= reduce(mul, self.in_shape)
|
||||
# kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
|
||||
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
|
||||
|
||||
loss = kld_loss + element_wise_loss
|
||||
else:
|
||||
@ -131,7 +131,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
###################################################
|
||||
#
|
||||
# Utils
|
||||
self.activation = nn.ReLU()
|
||||
self.activation = nn.LeakyReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
#
|
||||
@ -149,12 +149,8 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
conv_filters=self.hparams.model_param.filters[1],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
|
||||
conv_filters=self.hparams.model_param.filters[1],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
|
||||
self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
|
||||
self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
|
||||
conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
@ -162,12 +158,13 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
conv_filters=self.hparams.model_param.filters[2],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
|
||||
|
||||
self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
|
||||
conv_filters=self.hparams.model_param.filters[2],
|
||||
use_norm=self.hparams.model_param.use_norm,
|
||||
use_bias=self.hparams.model_param.use_bias)
|
||||
|
||||
last_conv_shape = self.enc_conv_2b.shape
|
||||
last_conv_shape = self.enc_conv_3a.shape
|
||||
self.enc_flat = Flatten(last_conv_shape)
|
||||
self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
|
||||
|
||||
@ -196,14 +193,18 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
|
||||
|
||||
self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
|
||||
conv_padding=1, conv_kernel=9, conv_stride=1,
|
||||
conv_padding=0, conv_kernel=7, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
|
||||
conv_padding=1, conv_kernel=7, conv_stride=1,
|
||||
conv_padding=1, conv_kernel=5, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
self.gen_deconv_out = DeConvModule(self.gen_deconv_2a.shape, self.out_channels, activation=None,
|
||||
self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
|
||||
conv_padding=0, conv_kernel=3, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
|
||||
conv_padding=0, conv_kernel=3, conv_stride=1,
|
||||
use_norm=self.hparams.model_param.use_norm)
|
||||
|
||||
@ -229,23 +230,15 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
|
||||
def encode(self, batch_x):
|
||||
combined_tensor = self.enc_conv_0(batch_x)
|
||||
combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor
|
||||
combined_tensor = self.enc_conv_1a(combined_tensor)
|
||||
combined_tensor = self.enc_conv_1b(combined_tensor)
|
||||
combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
|
||||
combined_tensor = self.enc_conv_2a(combined_tensor)
|
||||
combined_tensor = self.enc_conv_2b(combined_tensor)
|
||||
# combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
|
||||
# combined_tensor = self.enc_conv_3a(combined_tensor)
|
||||
# combined_tensor = self.enc_conv_3b(combined_tensor)
|
||||
combined_tensor = self.enc_conv_3a(combined_tensor)
|
||||
|
||||
combined_tensor = self.enc_flat(combined_tensor)
|
||||
combined_tensor = self.enc_lin_1(combined_tensor)
|
||||
combined_tensor = self.enc_norm(combined_tensor)
|
||||
combined_tensor = self.activation(combined_tensor)
|
||||
|
||||
combined_tensor = self.enc_lin_2(combined_tensor)
|
||||
combined_tensor = self.enc_norm(combined_tensor)
|
||||
combined_tensor = self.activation(combined_tensor)
|
||||
|
||||
#
|
||||
@ -265,15 +258,14 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
def generate(self, z):
|
||||
alt_tensor = self.gen_lin_1(z)
|
||||
alt_tensor = self.activation(alt_tensor)
|
||||
# alt_tensor = self.gen_lin_2(alt_tensor)
|
||||
# alt_tensor = self.activation(alt_tensor)
|
||||
|
||||
alt_tensor = self.reshape_to_last_conv(alt_tensor)
|
||||
alt_tensor = self.gen_deconv_1a(alt_tensor)
|
||||
|
||||
alt_tensor = self.gen_deconv_2a(alt_tensor)
|
||||
|
||||
# alt_tensor = self.gen_deconv_3a(alt_tensor)
|
||||
# alt_tensor = self.gen_deconv_3b(alt_tensor)
|
||||
alt_tensor = self.gen_deconv_3a(alt_tensor)
|
||||
|
||||
alt_tensor = self.gen_deconv_out(alt_tensor)
|
||||
# alt_tensor = self.activation(alt_tensor)
|
||||
# alt_tensor = self.sigmoid(alt_tensor)
|
||||
|
Reference in New Issue
Block a user