From 934dadb5589bc9b20d481184ab5eb33c6bb66845 Mon Sep 17 00:00:00 2001
From: Si11ium
Date: Wed, 25 Mar 2020 09:39:59 +0100
Subject: [PATCH] VAE Debugged and Running

---
 datasets/mnist.py            |   8 +-
 lib/models/generators/cnn.py | 312 ++++++++++++++++-------------------
 lib/modules/blocks.py        |  23 +--
 lib/modules/utils.py         |   9 +-
 main.py                      |  12 +-
 5 files changed, 171 insertions(+), 193 deletions(-)

diff --git a/datasets/mnist.py b/datasets/mnist.py
index 66fe06f..7a1776e 100644
--- a/datasets/mnist.py
+++ b/datasets/mnist.py
@@ -1,5 +1,7 @@
 from torchvision.datasets import MNIST
+from torchvision.transforms import transforms
 import numpy as np
+import torch
 
 
 class MyMNIST(MNIST):
@@ -9,12 +11,12 @@ class MyMNIST(MNIST):
         return np.asarray(self.test_dataset[0][0]).shape
 
     def __init__(self, *args, **kwargs):
-        super(MyMNIST, self).__init__('res', train=False, download=True)
+        super(MyMNIST, self).__init__('res', train=False, download=True, transform=transforms.ToTensor())
         pass
 
     def __getitem__(self, item):
-        image = super(MyMNIST, self).__getitem__(item)
-        return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
+        image, label = super(MyMNIST, self).__getitem__(item)
+        return image, label
 
     @property
     def train_dataset(self):
diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py
index c2897e5..c86c809 100644
--- a/lib/models/generators/cnn.py
+++ b/lib/models/generators/cnn.py
@@ -1,17 +1,16 @@
 from functools import reduce
 from operator import mul
-from random import choices, choice
+from random import choice
 
 import torch
 from torch import nn
 from torch.optim import Adam
-from torchvision.datasets import MNIST
 
 from datasets.mnist import MyMNIST
 from datasets.trajectory_dataset import TrajData
 
-from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
+from lib.modules.blocks import ConvModule, DeConvModule
 from lib.modules.utils import LightningBaseModule, Flatten
 
 import matplotlib.pyplot as plt
@@ -20,55 +19,33 @@ from lib.visualization.generator_eval import GeneratorVisualizer
 
 
 class CNNRouteGeneratorModel(LightningBaseModule):
-
+    torch.autograd.set_detect_anomaly(True)
     name = 'CNNRouteGenerator'
 
     def configure_optimizers(self):
         return Adam(self.parameters(), lr=self.hparams.train_param.lr)
 
     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        batch_x, target = batch_xy
-        generated_alternative, z, mu, logvar = self(batch_x)
-        target = batch_x if 'ae' in self.hparams.data_param.mode else target
-        element_wise_loss = self.criterion(generated_alternative, target)
+        batch_x, _ = batch_xy
+        reconstruction, z, mu, logvar = self(batch_x)
 
-        if 'vae' in self.hparams.data_param.mode:
-            # see Appendix B from VAE paper:
-            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
-            # https://arxiv.org/abs/1312.6114
-            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
-            kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
-            # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
-            # kld_loss /= reduce(mul, self.in_shape)
-            kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
+        recon_loss = self.criterion(reconstruction, batch_x)
 
-            loss = kld_loss + element_wise_loss
-        else:
-            loss = element_wise_loss
-            kld_loss = 0
-        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
+        kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+
+        loss = recon_loss + kldivergence
+        return dict(loss=loss, log=dict(reconstruction_loss=recon_loss, loss=loss, kld_loss=kldivergence))
 
     def _test_val_step(self, batch_xy, batch_nb, *args):
         batch_x, _ = batch_xy
-        if 'vae' in self.hparams.data_param.mode:
-            z, mu, logvar = self.encode(batch_x)
-        else:
-            z = self.encode(batch_x)
-            mu, logvar = z, z
-        generated_alternative = self.generate(mu)
-        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative, z=z, mu=mu, logvar=logvar)
+        mu, logvar = self.encoder(batch_x)
+        z = self.reparameterize(mu, logvar)
 
-        if 'hom' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.HOMOTOPIC)
-        elif 'alt' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.ALTERNATIVE)
-        elif 'vae' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.ANY)
-        elif 'ae' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.ANY)
-        else:
-            labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
+        reconstruction = self.decoder(mu)
+        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=reconstruction, z=z, mu=mu, logvar=logvar)
+
+        labels = torch.full((batch_x.shape[0], 1), V.ANY)
 
         return_dict.update(labels=self._move_to_model_device(labels))
         return return_dict
@@ -87,12 +64,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
 
         return dict(epoch=self.current_epoch)
 
-    def on_epoch_start(self):
-        # self.dataset.seed(self.logger.version)
-        # torch.random.manual_seed(self.logger.version)
-        # np.random.seed(self.logger.version)
-        pass
-
     def validation_step(self, *args):
         return self._test_val_step(*args)
 
@@ -113,160 +84,163 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         self.dataset = TrajData(self.hparams.data_param.map_root, mode=self.hparams.data_param.mode,
                                 preprocessed=self.hparams.data_param.use_preprocessed,
-                                length=self.hparams.data_param.dataset_length, normalized=True)
-        self.criterion = nn.MSELoss()
+                                length=self.hparams.data_param.dataset_length)
+        self.criterion = nn.BCELoss(reduction='sum')
         self.dataset = MyMNIST()
 
-        # Additional Attributes #
-        #######################################################
+        # Additional Attributes
+        ###################################################
         self.in_shape = self.dataset.map_shapes_max
         self.use_res_net = self.hparams.model_param.use_res_net
         self.lat_dim = self.hparams.model_param.lat_dim
         self.feature_dim = self.lat_dim
         self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]
-        ########################################################
+
+        # NN Nodes
+        ###################################################
+        self.encoder = Encoder(self.in_shape, self.hparams)
+        self.decoder = Decoder(self.out_channels, self.encoder.last_conv_shape, self.hparams)
+
+    def forward(self, batch_x):
+        # Encode
+        mu, logvar = self.encoder(batch_x)
+
+        # Bottleneck
+        z = self.reparameterize(mu, logvar)
+
+        # Decode
+        reconstruction = self.decoder(z)
+        return reconstruction, z, mu, logvar
+
+    @staticmethod
+    def reparameterize(mu, logvar):
+        # std is sigma = exp(0.5 * logvar), not 0.5 * exp(logvar)
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(mu)
+        z = mu + std * eps
+        return z
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_shape, hparams):
+        super(Encoder, self).__init__()
+        # Params
+        ###################################################
+        self.hparams = hparams
+
+        # Additional Attributes
+        ###################################################
+        self.in_shape = in_shape
+        self.use_res_net = self.hparams.model_param.use_res_net
+        self.lat_dim = self.hparams.model_param.lat_dim
+        self.feature_dim = self.lat_dim * 10
 
         # NN Nodes
         ###################################################
         #
         # Utils
-        self.activation = nn.LeakyReLU()
-        self.sigmoid = nn.Sigmoid()
+        self.activation = self.hparams.activation()
 
         #
-        # Map Encoder
-        self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
-                                     conv_filters=self.hparams.model_param.filters[0],
-                                     use_norm=self.hparams.model_param.use_norm,
-                                     use_bias=self.hparams.model_param.use_bias)
+        # Encoder
+        self.conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
+                                 conv_filters=self.hparams.model_param.filters[0],
+                                 use_norm=self.hparams.model_param.use_norm,
+                                 use_bias=self.hparams.model_param.use_bias)
 
-        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
-                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
-                                        use_norm=self.hparams.model_param.use_norm,
-                                        use_bias=self.hparams.model_param.use_bias)
-        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[1],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
+        self.conv_1 = ConvModule(self.conv_0.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+                                 conv_filters=self.hparams.model_param.filters[0],
+                                 use_norm=self.hparams.model_param.use_norm,
+                                 use_bias=self.hparams.model_param.use_bias)
 
-        self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
-                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
-                                        use_norm=self.hparams.model_param.use_norm,
-                                        use_bias=self.hparams.model_param.use_bias)
-        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[2],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
+        self.conv_2 = ConvModule(self.conv_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
+                                 conv_filters=self.hparams.model_param.filters[1],
+                                 use_norm=self.hparams.model_param.use_norm,
+                                 use_bias=self.hparams.model_param.use_bias)
 
-        self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[2],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
+        self.conv_3 = ConvModule(self.conv_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
+                                 conv_filters=self.hparams.model_param.filters[2],
+                                 use_norm=self.hparams.model_param.use_norm,
+                                 use_bias=self.hparams.model_param.use_bias)
 
-        last_conv_shape = self.enc_conv_3a.shape
-        self.enc_flat = Flatten(last_conv_shape)
-        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
-
-        #
-        # Mixed Encoder
-        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
-        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
+        self.last_conv_shape = self.conv_3.shape
+        self.flat = Flatten(in_shape=self.last_conv_shape)
+        self.lin = nn.Linear(self.flat.shape, self.feature_dim)
 
         #
         # Variational Bottleneck
-        if 'vae' in self.hparams.data_param.mode:
-            self.mu = nn.Linear(self.feature_dim, self.lat_dim)
-            self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
-
-        #
-        # Linear Bottleneck
-        else:
-            self.z = nn.Linear(self.feature_dim, self.lat_dim)
-
-        #
-        # Alternative Generator
-        self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)
-
-        # self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
-
-        self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
-
-        self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
-                                          conv_padding=0, conv_kernel=7, conv_stride=1,
-                                          use_norm=self.hparams.model_param.use_norm)
-
-        self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
-                                          conv_padding=1, conv_kernel=5, conv_stride=1,
-                                          use_norm=self.hparams.model_param.use_norm)
-
-        self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
-                                          conv_padding=0, conv_kernel=3, conv_stride=1,
-                                          use_norm=self.hparams.model_param.use_norm)
-
-        self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
-                                           conv_padding=0, conv_kernel=3, conv_stride=1,
-                                           use_norm=self.hparams.model_param.use_norm)
+        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
+        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
 
     def forward(self, batch_x):
-        #
-        # Encode
-        if 'vae' in self.hparams.data_param.mode:
-            z, mu, logvar = self.encode(batch_x)
-        else:
-            z = self.encode(batch_x)
-            mu, logvar = z, z
+        tensor = self.conv_0(batch_x)
+        tensor = self.conv_1(tensor)
+        tensor = self.conv_2(tensor)
+        tensor = self.conv_3(tensor)
 
-        #
-        # Generate
-        alt_tensor = self.generate(z)
-        return alt_tensor, z, mu, logvar
-
-    @staticmethod
-    def reparameterize(mu, logvar):
-        std = torch.exp(0.5 * logvar)
-        eps = torch.randn_like(std)
-        return mu + eps * std
-
-    def encode(self, batch_x):
-        combined_tensor = self.enc_conv_0(batch_x)
-        combined_tensor = self.enc_conv_1a(combined_tensor)
-        combined_tensor = self.enc_conv_2a(combined_tensor)
-        combined_tensor = self.enc_conv_3a(combined_tensor)
-
-        combined_tensor = self.enc_flat(combined_tensor)
-        combined_tensor = self.enc_lin_1(combined_tensor)
-        combined_tensor = self.activation(combined_tensor)
-
-        combined_tensor = self.enc_lin_2(combined_tensor)
-        combined_tensor = self.activation(combined_tensor)
+        tensor = self.flat(tensor)
+        tensor = self.lin(tensor)
+        tensor = self.activation(tensor)
 
         #
         # Variational
-        # Parameter and Sampling
-        if 'vae' in self.hparams.data_param.mode:
-            mu = self.mu(combined_tensor)
-            logvar = self.logvar(combined_tensor)
-            z = self.reparameterize(mu, logvar)
-            return z, mu, logvar
-        else:
-            #
-            # Linear Bottleneck
-            z = self.z(combined_tensor)
-            return z
+        # Parameter for Sampling
+        mu = self.mu(tensor)
+        logvar = self.logvar(tensor)
+        return mu, logvar
 
-    def generate(self, z):
-        alt_tensor = self.gen_lin_1(z)
-        alt_tensor = self.activation(alt_tensor)
-        alt_tensor = self.reshape_to_last_conv(alt_tensor)
-        alt_tensor = self.gen_deconv_1a(alt_tensor)
+class Decoder(nn.Module):
 
-        alt_tensor = self.gen_deconv_2a(alt_tensor)
+    def __init__(self, out_channels, last_conv_shape, hparams):
+        super(Decoder, self).__init__()
+        # Params
+        ###################################################
+        self.hparams = hparams
 
-        alt_tensor = self.gen_deconv_3a(alt_tensor)
+        # Additional Attributes
+        ###################################################
+        self.use_res_net = self.hparams.model_param.use_res_net
+        self.lat_dim = self.hparams.model_param.lat_dim
+        self.feature_dim = self.lat_dim
+        self.out_channels = out_channels
 
-        alt_tensor = self.gen_deconv_out(alt_tensor)
-        # alt_tensor = self.activation(alt_tensor)
-        # alt_tensor = self.sigmoid(alt_tensor)
-        return alt_tensor
+        # NN Nodes
+        ###################################################
+        #
+        # Utils
+        self.activation = self.hparams.activation()
+
+        #
+        # Alternative Generator
+        self.lin = nn.Linear(self.lat_dim, reduce(mul, last_conv_shape))
+
+        self.reshape = Flatten(in_shape=reduce(mul, last_conv_shape), to=last_conv_shape)
+
+        self.deconv_1 = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
+                                     conv_padding=0, conv_kernel=7, conv_stride=1,
+                                     use_norm=self.hparams.model_param.use_norm)
+
+        self.deconv_2 = DeConvModule(self.deconv_1.shape, self.hparams.model_param.filters[1],
+                                     conv_padding=1, conv_kernel=5, conv_stride=1,
+                                     use_norm=self.hparams.model_param.use_norm)
+
+        self.deconv_3 = DeConvModule(self.deconv_2.shape, self.hparams.model_param.filters[0],
+                                     conv_padding=0, conv_kernel=3, conv_stride=1,
+                                     use_norm=self.hparams.model_param.use_norm)
+
+        self.deconv_out = DeConvModule(self.deconv_3.shape, self.out_channels, activation=nn.Sigmoid,
+                                       conv_padding=0, conv_kernel=3, conv_stride=1,
+                                       use_norm=self.hparams.model_param.use_norm)
+
+    def forward(self, z):
+
+        tensor = self.lin(z)
+        tensor = self.activation(tensor)
+        tensor = self.reshape(tensor)
+
+        tensor = self.deconv_1(tensor)
+        tensor = self.deconv_2(tensor)
+        tensor = self.deconv_3(tensor)
+        reconstruction = self.deconv_out(tensor)
+        return reconstruction
diff --git a/lib/modules/blocks.py b/lib/modules/blocks.py
index a1e81bd..2713906 100644
--- a/lib/modules/blocks.py
+++ b/lib/modules/blocks.py
@@ -17,9 +17,9 @@ class ConvModule(nn.Module):
         output = self(x)
         return output.shape[1:]
 
-    def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=False,
-                 dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
-                 conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
+    def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
+                 use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
+                 conv_class=nn.Conv2d, conv_stride=1, conv_padding=0):
         super(ConvModule, self).__init__()
 
         # Module Parameters
@@ -30,12 +30,14 @@ class ConvModule(nn.Module):
         # Convolution Parameters
         self.padding = conv_padding
         self.stride = conv_stride
+        self.conv_filters = conv_filters
+        self.conv_kernel = conv_kernel
 
         # Modules
         self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
         self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
-        self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
+        self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
                                padding=self.padding, stride=self.stride
                                )
 
@@ -57,22 +59,23 @@ class DeConvModule(nn.Module):
         output = self(x)
         return output.shape[1:]
 
-    def __init__(self, in_shape, conv_filters=3, conv_kernel=5, conv_stride=1, conv_padding=0,
-                 dropout: Union[int, float] = 0, autopad=False,
-                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=None,
+    def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
+                 dropout: Union[int, float] = 0, autopad=0,
+                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
                  use_bias=True, use_norm=False):
         super(DeConvModule, self).__init__()
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
         self.padding = conv_padding
+        self.conv_kernel = conv_kernel
         self.stride = conv_stride
         self.in_shape = in_shape
        self.conv_filters = conv_filters
 
         self.autopad = AutoPad() if autopad else lambda x: x
         self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
         self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
-        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
+        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
                                           padding=self.padding, stride=self.stride)
         self.activation = activation() if activation else lambda x: x
diff --git a/lib/modules/utils.py b/lib/modules/utils.py
index 53d86dd..6fe0338 100644
--- a/lib/modules/utils.py
+++ b/lib/modules/utils.py
@@ -6,8 +6,6 @@ from torch import nn
 from torch import functional as F
 from torch.utils.data import DataLoader
 
-from lib.objects.map import MapStorage
-
 import pytorch_lightning as pl
 
 
@@ -27,10 +25,11 @@ class Flatten(nn.Module):
             print(e)
             return -1
 
-    def __init__(self, in_shape, to=(-1, )):
+    def __init__(self, in_shape, to=-1):
+        assert isinstance(to, int) or isinstance(to, tuple)
         super(Flatten, self).__init__()
         self.in_shape = in_shape
-        self.to = to
+        self.to = (to,) if isinstance(to, int) else to
 
     def forward(self, x):
         return x.view(x.size(0), *self.to)
@@ -107,7 +106,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
         # Data loading
         # =============================================================================
         # Map Object
-        self.map_storage = MapStorage(self.hparams.data_param.map_root)
+        # self.map_storage = MapStorage(self.hparams.data_param.map_root)
 
     def size(self):
         return self.shape
diff --git a/main.py b/main.py
index c7af73c..d2a0db9 100644
--- a/main.py
+++ b/main.py
@@ -47,19 +47,19 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
 # Transformations
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
-main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
+main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
 main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
 
 # Model
 main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
-main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
+main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="") main_arg_parser.add_argument("--model_classes", type=int, default=2, help="") -main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="") +main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="") main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="") -main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="") +main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=False, help="") main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="") main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")