VAE Debugged and Running

This commit is contained in:
Si11ium 2020-03-25 09:39:59 +01:00
parent defa232bf2
commit 934dadb558
5 changed files with 171 additions and 193 deletions

View File

@@ -1,5 +1,7 @@
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
import numpy as np
import torch
class MyMNIST(MNIST):
@@ -9,12 +11,12 @@ class MyMNIST(MNIST):
return np.asarray(self.test_dataset[0][0]).shape
def __init__(self, *args, **kwargs):
super(MyMNIST, self).__init__('res', train=False, download=True)
super(MyMNIST, self).__init__('res', train=False, download=True, transform=transforms.ToTensor())
pass
def __getitem__(self, item):
image = super(MyMNIST, self).__getitem__(item)
return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
image, label = super(MyMNIST, self).__getitem__(item)
return image, label
@property
def train_dataset(self):
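Taken together, the change to this dataset wrapper means each sample now arrives as a 1x28x28 float tensor in [0, 1] straight from transforms.ToTensor(), instead of being converted by hand with numpy. A minimal self-contained sketch of the resulting behaviour (the class name here is a placeholder, not the repo's file):

from torchvision.datasets import MNIST
from torchvision.transforms import transforms

class ToTensorMNIST(MNIST):
    def __init__(self, root='res'):
        # ToTensor replaces the old numpy expand_dims/astype conversion
        super().__init__(root, train=False, download=True,
                         transform=transforms.ToTensor())

    def __getitem__(self, item):
        image, label = super().__getitem__(item)   # image: 1x28x28 float tensor in [0, 1]
        return image, label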

View File

@@ -1,17 +1,16 @@
from functools import reduce
from operator import mul
from random import choices, choice
from random import choice
import torch
from torch import nn
from torch.optim import Adam
from torchvision.datasets import MNIST
from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.blocks import ConvModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
@@ -20,55 +19,33 @@ from lib.visualization.generator_eval import GeneratorVisualizer
class CNNRouteGeneratorModel(LightningBaseModule):
torch.autograd.set_detect_anomaly(True)
name = 'CNNRouteGenerator'
def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, target = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
target = batch_x if 'ae' in self.hparams.data_param.mode else target
element_wise_loss = self.criterion(generated_alternative, target)
batch_x, _ = batch_xy
reconstruction, z, mu, logvar = self(batch_x)
if 'vae' in self.hparams.data_param.mode:
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional resizing. TODO: Does this make sense? Sanity-check it!
# kld_loss /= reduce(mul, self.in_shape)
kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
recon_loss = self.criterion(reconstruction, batch_x)
loss = kld_loss + element_wise_loss
else:
loss = element_wise_loss
kld_loss = 0
return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_loss + kldivergence
return dict(loss=loss, log=dict(reconstruction_loss=recon_loss, loss=loss, kld_loss=kldivergence))
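For reference, a standalone sketch of the loss the new training_step computes (the helper name vae_loss is illustrative; it assumes the decoder output has already been squashed to (0, 1) by a Sigmoid):

import torch
from torch import nn

def vae_loss(reconstruction, batch_x, mu, logvar):
    # sum-reduced BCE, matching nn.BCELoss(reduction='sum') above
    recon_loss = nn.functional.binary_cross_entropy(reconstruction, batch_x, reduction='sum')
    # Appendix B, Kingma & Welling (2014): -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld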
def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, _ = batch_xy
if 'vae' in self.hparams.data_param.mode:
z, mu, logvar = self.encode(batch_x)
else:
z = self.encode(batch_x)
mu, logvar = z, z
generated_alternative = self.generate(mu)
return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative, z=z, mu=mu, logvar=logvar)
mu, logvar = self.encoder(batch_x)
z = self.reparameterize(mu, logvar)
if 'hom' in self.hparams.data_param.mode:
labels = torch.full((batch_x.shape[0], 1), V.HOMOTOPIC)
elif 'alt' in self.hparams.data_param.mode:
labels = torch.full((batch_x.shape[0], 1), V.ALTERNATIVE)
elif 'vae' in self.hparams.data_param.mode:
labels = torch.full((batch_x.shape[0], 1), V.ANY)
elif 'ae' in self.hparams.data_param.mode:
labels = torch.full((batch_x.shape[0], 1), V.ANY)
else:
labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
reconstruction = self.decoder(mu)
return_dict = dict(input=batch_x, batch_nb=batch_nb, output=reconstruction, z=z, mu=mu, logvar=logvar)
labels = torch.full((batch_x.shape[0], 1), V.ANY)
return_dict.update(labels=self._move_to_model_device(labels))
return return_dict
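Note that validation decodes the mean mu rather than a sampled z, so each input yields a deterministic reconstruction; a sampled variant would look like this (illustrative only, assuming a trained model instance named model):

mu, logvar = model.encoder(batch_x)
deterministic_recon = model.decoder(mu)                              # what _test_val_step returns
stochastic_recon = model.decoder(model.reparameterize(mu, logvar))   # adds sampling noise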
@@ -87,12 +64,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
return dict(epoch=self.current_epoch)
def on_epoch_start(self):
# self.dataset.seed(self.logger.version)
# torch.random.manual_seed(self.logger.version)
# np.random.seed(self.logger.version)
pass
def validation_step(self, *args):
return self._test_val_step(*args)
@@ -113,160 +84,163 @@ class CNNRouteGeneratorModel(LightningBaseModule):
self.dataset = TrajData(self.hparams.data_param.map_root,
mode=self.hparams.data_param.mode,
preprocessed=self.hparams.data_param.use_preprocessed,
length=self.hparams.data_param.dataset_length, normalized=True)
self.criterion = nn.MSELoss()
length=self.hparams.data_param.dataset_length)
self.criterion = nn.BCELoss(reduction='sum')
self.dataset = MyMNIST()
# Additional Attributes #
#######################################################
# Additional Attributes
###################################################
self.in_shape = self.dataset.map_shapes_max
self.use_res_net = self.hparams.model_param.use_res_net
self.lat_dim = self.hparams.model_param.lat_dim
self.feature_dim = self.lat_dim
self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]
########################################################
# NN Nodes
###################################################
self.encoder = Encoder(self.in_shape, self.hparams)
self.decoder = Decoder(self.out_channels, self.encoder.last_conv_shape, self.hparams)
def forward(self, batch_x):
# Encode
mu, logvar = self.encoder(batch_x)
# Bottleneck
z = self.reparameterize(mu, logvar)
# Decode
reconstruction = self.decoder(z)
return reconstruction, z, mu, logvar
@staticmethod
def reparameterize(mu, logvar):
std = torch.exp(0.5 * logvar)  # sigma = exp(0.5 * logvar), since logvar = log(sigma^2)
eps = torch.randn_like(mu)
z = mu + std * eps
return z
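A quick self-contained check of the reparameterization trick (tensor sizes are arbitrary): the random draw stays outside the graph, so gradients still reach mu and logvar.

import torch

mu = torch.zeros(4, 16, requires_grad=True)
logvar = torch.zeros(4, 16, requires_grad=True)   # sigma = exp(0.5 * 0) = 1
z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
z.sum().backward()
assert mu.grad is not None and logvar.grad is not None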
class Encoder(nn.Module):
def __init__(self, in_shape, hparams):
super(Encoder, self).__init__()
# Params
###################################################
self.hparams = hparams
# Additional Attributes
###################################################
self.in_shape = in_shape
self.use_res_net = self.hparams.model_param.use_res_net
self.lat_dim = self.hparams.model_param.lat_dim
self.feature_dim = self.lat_dim * 10
# NN Nodes
###################################################
#
# Utils
self.activation = nn.LeakyReLU()
self.sigmoid = nn.Sigmoid()
self.activation = self.hparams.activation()
#
# Map Encoder
self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
conv_filters=self.hparams.model_param.filters[0],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
# Encoder
self.conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
conv_filters=self.hparams.model_param.filters[0],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.conv_1 = ConvModule(self.conv_0.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[0],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.conv_2 = ConvModule(self.conv_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.conv_3 = ConvModule(self.conv_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
last_conv_shape = self.enc_conv_3a.shape
self.enc_flat = Flatten(last_conv_shape)
self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
#
# Mixed Encoder
self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
self.last_conv_shape = self.conv_3.shape
self.flat = Flatten(in_shape=self.last_conv_shape)
self.lin = nn.Linear(self.flat.shape, self.feature_dim)
#
# Variational Bottleneck
if 'vae' in self.hparams.data_param.mode:
self.mu = nn.Linear(self.feature_dim, self.lat_dim)
self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
#
# Linear Bottleneck
else:
self.z = nn.Linear(self.feature_dim, self.lat_dim)
#
# Alternative Generator
self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)
# self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
conv_padding=0, conv_kernel=7, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
conv_padding=1, conv_kernel=5, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
conv_padding=0, conv_kernel=3, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
conv_padding=0, conv_kernel=3, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.mu = nn.Linear(self.feature_dim, self.lat_dim)
self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
def forward(self, batch_x):
#
# Encode
if 'vae' in self.hparams.data_param.mode:
z, mu, logvar = self.encode(batch_x)
else:
z = self.encode(batch_x)
mu, logvar = z, z
tensor = self.conv_0(batch_x)
tensor = self.conv_1(tensor)
tensor = self.conv_2(tensor)
tensor = self.conv_3(tensor)
#
# Generate
alt_tensor = self.generate(z)
return alt_tensor, z, mu, logvar
@staticmethod
def reparameterize(mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def encode(self, batch_x):
combined_tensor = self.enc_conv_0(batch_x)
combined_tensor = self.enc_conv_1a(combined_tensor)
combined_tensor = self.enc_conv_2a(combined_tensor)
combined_tensor = self.enc_conv_3a(combined_tensor)
combined_tensor = self.enc_flat(combined_tensor)
combined_tensor = self.enc_lin_1(combined_tensor)
combined_tensor = self.activation(combined_tensor)
combined_tensor = self.enc_lin_2(combined_tensor)
combined_tensor = self.activation(combined_tensor)
tensor = self.flat(tensor)
tensor = self.lin(tensor)
tensor = self.activation(tensor)
#
# Variational
# Parameter and Sampling
if 'vae' in self.hparams.data_param.mode:
mu = self.mu(combined_tensor)
logvar = self.logvar(combined_tensor)
z = self.reparameterize(mu, logvar)
return z, mu, logvar
else:
#
# Linear Bottleneck
z = self.z(combined_tensor)
return z
# Parameter for Sampling
mu = self.mu(tensor)
logvar = self.logvar(tensor)
return mu, logvar
def generate(self, z):
alt_tensor = self.gen_lin_1(z)
alt_tensor = self.activation(alt_tensor)
alt_tensor = self.reshape_to_last_conv(alt_tensor)
alt_tensor = self.gen_deconv_1a(alt_tensor)
alt_tensor = self.gen_deconv_2a(alt_tensor)
alt_tensor = self.gen_deconv_3a(alt_tensor)
alt_tensor = self.gen_deconv_out(alt_tensor)
# alt_tensor = self.activation(alt_tensor)
# alt_tensor = self.sigmoid(alt_tensor)
return alt_tensor
class Decoder(nn.Module):
def __init__(self, out_channels, last_conv_shape, hparams):
super(Decoder, self).__init__()
# Params
###################################################
self.hparams = hparams
# Additional Attributes
###################################################
self.use_res_net = self.hparams.model_param.use_res_net
self.lat_dim = self.hparams.model_param.lat_dim
self.feature_dim = self.lat_dim
self.out_channels = out_channels
# NN Nodes
###################################################
#
# Utils
self.activation = self.hparams.activation()
#
# Alternative Generator
self.lin = nn.Linear(self.lat_dim, reduce(mul, last_conv_shape))
self.reshape = Flatten(in_shape=reduce(mul, last_conv_shape), to=last_conv_shape)
self.deconv_1 = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
conv_padding=0, conv_kernel=7, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.deconv_2 = DeConvModule(self.deconv_1.shape, self.hparams.model_param.filters[1],
conv_padding=1, conv_kernel=5, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.deconv_3 = DeConvModule(self.deconv_2.shape, self.hparams.model_param.filters[0],
conv_padding=0, conv_kernel=3, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.deconv_out = DeConvModule(self.deconv_3.shape, self.out_channels, activation=nn.Sigmoid,
conv_padding=0, conv_kernel=3, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
def forward(self, z):
tensor = self.lin(z)
tensor = self.activation(tensor)
tensor = self.reshape(tensor)
tensor = self.deconv_1(tensor)
tensor = self.deconv_2(tensor)
tensor = self.deconv_3(tensor)
reconstruction = self.deconv_out(tensor)
return reconstruction
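Once the encoder/decoder split is in place, new digits can be drawn from the prior instead of reconstructing an input. Illustrative only, assuming a trained CNNRouteGeneratorModel instance named model:

model.eval()
with torch.no_grad():
    z = torch.randn(16, model.lat_dim)   # z ~ N(0, I); lat_dim from --model_lat_dim
    samples = model.decoder(z)           # one image-shaped sample per latent code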

View File

@@ -17,9 +17,9 @@ class ConvModule(nn.Module):
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=False,
dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
conv_class=nn.Conv2d, conv_stride=1, conv_padding=0):
super(ConvModule, self).__init__()
# Module Parameters
@@ -30,12 +30,14 @@ class ConvModule(nn.Module):
# Convolution Parameters
self.padding = conv_padding
self.stride = conv_stride
self.conv_filters = conv_filters
self.conv_kernel = conv_kernel
# Modules
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride
)
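The shape property (a dummy forward pass returning output.shape[1:]) is what lets the Encoder chain modules without hand-computing sizes. A small usage sketch for the 28x28 MNIST case (filter counts are illustrative):

conv_0 = ConvModule((1, 28, 28), conv_filters=16, conv_kernel=3, conv_stride=1, conv_padding=1)
print(conv_0.shape)    # e.g. torch.Size([16, 28, 28])
conv_1 = ConvModule(conv_0.shape, conv_filters=16, conv_kernel=3, conv_stride=1, conv_padding=0)
print(conv_1.shape)    # e.g. torch.Size([16, 26, 26])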
@@ -57,22 +59,23 @@ class DeConvModule(nn.Module):
output = self(x)
return output.shape[1:]
def __init__(self, in_shape, conv_filters=3, conv_kernel=5, conv_stride=1, conv_padding=0,
dropout: Union[int, float] = 0, autopad=False,
activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=None,
def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
dropout: Union[int, float] = 0, autopad=0,
activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
use_bias=True, use_norm=False):
super(DeConvModule, self).__init__()
in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
self.padding = conv_padding
self.conv_kernel = conv_kernel
self.stride = conv_stride
self.in_shape = in_shape
self.conv_filters = conv_filters
self.autopad = AutoPad() if autopad else lambda x: x
self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
padding=self.padding, stride=self.stride)
self.activation = activation() if activation else lambda x: x
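For reference, a stride-1 ConvTranspose2d (no output padding, dilation 1) grows the spatial size by out = (in - 1) * stride - 2 * padding + kernel. With the 7/5/3/3 kernels the Decoder above uses, a 16x16 encoder map walks back up to 28x28 on MNIST:

def deconv_out_size(size_in, kernel, stride=1, padding=0):
    # ConvTranspose2d, dilation=1, output_padding=0
    return (size_in - 1) * stride - 2 * padding + kernel

sizes = [16]
for kernel, padding in [(7, 0), (5, 1), (3, 0), (3, 0)]:
    sizes.append(deconv_out_size(sizes[-1], kernel, padding=padding))
print(sizes)   # [16, 22, 24, 26, 28] for the MNIST case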

View File

@@ -6,8 +6,6 @@ from torch import nn
from torch import functional as F
from torch.utils.data import DataLoader
from lib.objects.map import MapStorage
import pytorch_lightning as pl
@@ -27,10 +25,11 @@ class Flatten(nn.Module):
print(e)
return -1
def __init__(self, in_shape, to=(-1, )):
def __init__(self, in_shape, to=-1):
assert isinstance(to, int) or isinstance(to, tuple)
super(Flatten, self).__init__()
self.in_shape = in_shape
self.to = to
self.to = (to,) if isinstance(to, int) else to
def forward(self, x):
return x.view(x.size(0), *self.to)
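With to normalized to a tuple, the same module flattens in the Encoder and un-flattens in the Decoder. Illustrative usage (shapes are arbitrary):

flat = Flatten(in_shape=(64, 10, 10))              # to defaults to -1 -> (-1,)
x = torch.randn(8, 64, 10, 10)
print(flat(x).shape)                               # torch.Size([8, 6400])
unflat = Flatten(in_shape=6400, to=(64, 10, 10))   # the Decoder's reshape direction
print(unflat(flat(x)).shape)                       # torch.Size([8, 64, 10, 10])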
@@ -107,7 +106,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
# Data loading
# =============================================================================
# Map Object
self.map_storage = MapStorage(self.hparams.data_param.map_root)
# self.map_storage = MapStorage(self.hparams.data_param.map_root)
def size(self):
return self.shape

main.py
View File

@@ -47,19 +47,19 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
# Transformations
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
# Model
main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")