Mnist Test

This commit is contained in:
Steffen Illium
2020-03-15 21:35:18 +01:00
parent 2305c8e54a
commit defa232bf2
2 changed files with 19 additions and 27 deletions

View File

@ -40,7 +40,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
# kld_loss /= reduce(mul, self.in_shape)
# kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
loss = kld_loss + element_wise_loss
else:
@ -131,7 +131,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
###################################################
#
# Utils
self.activation = nn.ReLU()
self.activation = nn.LeakyReLU()
self.sigmoid = nn.Sigmoid()
#
@ -149,12 +149,8 @@ class CNNRouteGeneratorModel(LightningBaseModule):
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
@ -162,12 +158,13 @@ class CNNRouteGeneratorModel(LightningBaseModule):
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
last_conv_shape = self.enc_conv_2b.shape
last_conv_shape = self.enc_conv_3a.shape
self.enc_flat = Flatten(last_conv_shape)
self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
@ -196,14 +193,18 @@ class CNNRouteGeneratorModel(LightningBaseModule):
self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
conv_padding=1, conv_kernel=9, conv_stride=1,
conv_padding=0, conv_kernel=7, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
conv_padding=1, conv_kernel=7, conv_stride=1,
conv_padding=1, conv_kernel=5, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_out = DeConvModule(self.gen_deconv_2a.shape, self.out_channels, activation=None,
self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
conv_padding=0, conv_kernel=3, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
conv_padding=0, conv_kernel=3, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
@ -229,23 +230,15 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def encode(self, batch_x):
combined_tensor = self.enc_conv_0(batch_x)
combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_1a(combined_tensor)
combined_tensor = self.enc_conv_1b(combined_tensor)
combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_2a(combined_tensor)
combined_tensor = self.enc_conv_2b(combined_tensor)
# combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
# combined_tensor = self.enc_conv_3a(combined_tensor)
# combined_tensor = self.enc_conv_3b(combined_tensor)
combined_tensor = self.enc_conv_3a(combined_tensor)
combined_tensor = self.enc_flat(combined_tensor)
combined_tensor = self.enc_lin_1(combined_tensor)
combined_tensor = self.enc_norm(combined_tensor)
combined_tensor = self.activation(combined_tensor)
combined_tensor = self.enc_lin_2(combined_tensor)
combined_tensor = self.enc_norm(combined_tensor)
combined_tensor = self.activation(combined_tensor)
#
@ -265,15 +258,14 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def generate(self, z):
alt_tensor = self.gen_lin_1(z)
alt_tensor = self.activation(alt_tensor)
# alt_tensor = self.gen_lin_2(alt_tensor)
# alt_tensor = self.activation(alt_tensor)
alt_tensor = self.reshape_to_last_conv(alt_tensor)
alt_tensor = self.gen_deconv_1a(alt_tensor)
alt_tensor = self.gen_deconv_2a(alt_tensor)
# alt_tensor = self.gen_deconv_3a(alt_tensor)
# alt_tensor = self.gen_deconv_3b(alt_tensor)
alt_tensor = self.gen_deconv_3a(alt_tensor)
alt_tensor = self.gen_deconv_out(alt_tensor)
# alt_tensor = self.activation(alt_tensor)
# alt_tensor = self.sigmoid(alt_tensor)