VAE Debugged and Running
parent defa232bf2
commit 934dadb558
@@ -1,5 +1,7 @@
 from torchvision.datasets import MNIST
+from torchvision.transforms import transforms
 import numpy as np
+import torch
 
 
 class MyMNIST(MNIST):
@@ -9,12 +11,12 @@ class MyMNIST(MNIST):
         return np.asarray(self.test_dataset[0][0]).shape
 
     def __init__(self, *args, **kwargs):
-        super(MyMNIST, self).__init__('res', train=False, download=True)
+        super(MyMNIST, self).__init__('res', train=False, download=True, transform=transforms.ToTensor())
         pass
 
     def __getitem__(self, item):
-        image = super(MyMNIST, self).__getitem__(item)
-        return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
+        image, label = super(MyMNIST, self).__getitem__(item)
+        return image, label
 
     @property
     def train_dataset(self):
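
Note: with transform=transforms.ToTensor() passed to the MNIST constructor, __getitem__ now yields a float tensor scaled to [0, 1] with shape (1, 28, 28) plus an integer label, instead of the manually expanded NumPy array returned before. A minimal sanity check, assuming the datasets.mnist module from this repository is importable and the 'res' download directory is writable:

    # Sketch only: relies on the MyMNIST class shown above.
    from datasets.mnist import MyMNIST

    dataset = MyMNIST()
    image, label = dataset[0]
    print(image.shape, image.dtype, float(image.min()), float(image.max()), label)
    # expected roughly: torch.Size([1, 28, 28]) torch.float32 0.0 1.0 <digit label>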

@@ -1,17 +1,16 @@
 from functools import reduce
 from operator import mul
 
-from random import choices, choice
+from random import choice
 
 import torch
 
 from torch import nn
 from torch.optim import Adam
-from torchvision.datasets import MNIST
 
 from datasets.mnist import MyMNIST
 from datasets.trajectory_dataset import TrajData
-from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
+from lib.modules.blocks import ConvModule, DeConvModule
 from lib.modules.utils import LightningBaseModule, Flatten
 
 import matplotlib.pyplot as plt
@@ -20,55 +19,33 @@ from lib.visualization.generator_eval import GeneratorVisualizer
 
 
 class CNNRouteGeneratorModel(LightningBaseModule):
+    torch.autograd.set_detect_anomaly(True)
     name = 'CNNRouteGenerator'
 
     def configure_optimizers(self):
         return Adam(self.parameters(), lr=self.hparams.train_param.lr)
 
     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        batch_x, target = batch_xy
-        generated_alternative, z, mu, logvar = self(batch_x)
-        target = batch_x if 'ae' in self.hparams.data_param.mode else target
-        element_wise_loss = self.criterion(generated_alternative, target)
+        batch_x, _ = batch_xy
+        reconstruction, z, mu, logvar = self(batch_x)
 
-        if 'vae' in self.hparams.data_param.mode:
-            # see Appendix B from VAE paper:
-            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
-            # https://arxiv.org/abs/1312.6114
-            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
-            kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
-            # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
-            # kld_loss /= reduce(mul, self.in_shape)
-            kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
+        recon_loss = self.criterion(reconstruction, batch_x)
 
-            loss = kld_loss + element_wise_loss
-        else:
-            loss = element_wise_loss
-            kld_loss = 0
-        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
+        kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+        loss = recon_loss + kldivergence
+        return dict(loss=loss, log=dict(reconstruction_loss=recon_loss, loss=loss, kld_loss=kldivergence))
 
     def _test_val_step(self, batch_xy, batch_nb, *args):
         batch_x, _ = batch_xy
-        if 'vae' in self.hparams.data_param.mode:
-            z, mu, logvar = self.encode(batch_x)
-        else:
-            z = self.encode(batch_x)
-            mu, logvar = z, z
 
-        generated_alternative = self.generate(mu)
-        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative, z=z, mu=mu, logvar=logvar)
+        mu, logvar = self.encoder(batch_x)
+        z = self.reparameterize(mu, logvar)
 
-        if 'hom' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.HOMOTOPIC)
-        elif 'alt' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.ALTERNATIVE)
-        elif 'vae' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.ANY)
-        elif 'ae' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.ANY)
-        else:
-            labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
+        reconstruction = self.decoder(mu)
+        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=reconstruction, z=z, mu=mu, logvar=logvar)
+        labels = torch.full((batch_x.shape[0], 1), V.ANY)
 
         return_dict.update(labels=self._move_to_model_device(labels))
         return return_dict
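
Note: the training objective after this change is the standard negative ELBO, i.e. a sum-reduced binary cross-entropy reconstruction term plus the closed-form KL divergence between N(mu, sigma^2) and N(0, I) from Kingma and Welling (2014), -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2). A minimal, self-contained sketch of the same computation on dummy tensors (shapes are illustrative, not taken from the repository):

    import torch
    import torch.nn.functional as F

    # stand-ins for a batch of reconstructions/targets in [0, 1] and the encoder outputs
    target = torch.rand(8, 1, 28, 28)
    reconstruction = torch.rand(8, 1, 28, 28)
    mu, logvar = torch.randn(8, 16), torch.randn(8, 16)

    recon_loss = F.binary_cross_entropy(reconstruction, target, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss = recon_loss + kld  # negative ELBO, summed over the batch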
@@ -87,12 +64,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
 
         return dict(epoch=self.current_epoch)
 
-    def on_epoch_start(self):
-        # self.dataset.seed(self.logger.version)
-        # torch.random.manual_seed(self.logger.version)
-        # np.random.seed(self.logger.version)
-        pass
-
     def validation_step(self, *args):
         return self._test_val_step(*args)
 
@@ -113,160 +84,163 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         self.dataset = TrajData(self.hparams.data_param.map_root,
                 mode=self.hparams.data_param.mode,
                 preprocessed=self.hparams.data_param.use_preprocessed,
-                length=self.hparams.data_param.dataset_length, normalized=True)
-        self.criterion = nn.MSELoss()
+                length=self.hparams.data_param.dataset_length)
+        self.criterion = nn.BCELoss(reduction='sum')
 
         self.dataset = MyMNIST()
 
-        # Additional Attributes #
-        #######################################################
+        # Additional Attributes
+        ###################################################
         self.in_shape = self.dataset.map_shapes_max
         self.use_res_net = self.hparams.model_param.use_res_net
         self.lat_dim = self.hparams.model_param.lat_dim
         self.feature_dim = self.lat_dim
         self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]
-        ########################################################
+        # NN Nodes
+        ###################################################
+        self.encoder = Encoder(self.in_shape, self.hparams)
+        self.decoder = Decoder(self.out_channels, self.encoder.last_conv_shape, self.hparams)
 
+    def forward(self, batch_x):
+        # Encode
+        mu, logvar = self.encoder(batch_x)
 
+        # Bottleneck
+        z = self.reparameterize(mu, logvar)
 
+        # Decode
+        reconstruction = self.decoder(z)
+        return reconstruction, z, mu, logvar
 
+    @staticmethod
+    def reparameterize(mu, logvar):
+        std = 0.5 * torch.exp(logvar)
+        eps = torch.randn_like(mu)
+        z = mu + std * eps
+        return z
 
 
+class Encoder(nn.Module):
+    def __init__(self, in_shape, hparams):
+        super(Encoder, self).__init__()
+        # Params
+        ###################################################
+        self.hparams = hparams
 
+        # Additional Attributes
+        ###################################################
+        self.in_shape = in_shape
+        self.use_res_net = self.hparams.model_param.use_res_net
+        self.lat_dim = self.hparams.model_param.lat_dim
+        self.feature_dim = self.lat_dim * 10
 
         # NN Nodes
         ###################################################
         #
         # Utils
-        self.activation = nn.LeakyReLU()
-        self.sigmoid = nn.Sigmoid()
+        self.activation = self.hparams.activation()
 
         #
-        # Map Encoder
-        self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
+        # Encoder
+        self.conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                 conv_filters=self.hparams.model_param.filters[0],
                 use_norm=self.hparams.model_param.use_norm,
                 use_bias=self.hparams.model_param.use_bias)
 
-        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
-                conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
+        self.conv_1 = ConvModule(self.conv_0.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+                conv_filters=self.hparams.model_param.filters[0],
                 use_norm=self.hparams.model_param.use_norm,
                 use_bias=self.hparams.model_param.use_bias)
-        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                conv_filters=self.hparams.model_param.filters[1],
-                use_norm=self.hparams.model_param.use_norm,
-                use_bias=self.hparams.model_param.use_bias)
 
-        self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
-                conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
+        self.conv_2 = ConvModule(self.conv_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
+                conv_filters=self.hparams.model_param.filters[1],
                 use_norm=self.hparams.model_param.use_norm,
                 use_bias=self.hparams.model_param.use_bias)
-        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
-                conv_filters=self.hparams.model_param.filters[2],
-                use_norm=self.hparams.model_param.use_norm,
-                use_bias=self.hparams.model_param.use_bias)
 
-        self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
+        self.conv_3 = ConvModule(self.conv_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                 conv_filters=self.hparams.model_param.filters[2],
                 use_norm=self.hparams.model_param.use_norm,
                 use_bias=self.hparams.model_param.use_bias)
 
-        last_conv_shape = self.enc_conv_3a.shape
-        self.enc_flat = Flatten(last_conv_shape)
-        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
+        self.last_conv_shape = self.conv_3.shape
+        self.flat = Flatten(in_shape=self.last_conv_shape)
+        self.lin = nn.Linear(self.flat.shape, self.feature_dim)
 
-        #
-        # Mixed Encoder
-        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
-        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
 
         #
         # Variational Bottleneck
-        if 'vae' in self.hparams.data_param.mode:
-            self.mu = nn.Linear(self.feature_dim, self.lat_dim)
-            self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
+        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
+        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
 
-        #
-        # Linear Bottleneck
-        else:
-            self.z = nn.Linear(self.feature_dim, self.lat_dim)
 
-        #
-        # Alternative Generator
-        self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)
 
-        # self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
 
-        self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
 
-        self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
-                conv_padding=0, conv_kernel=7, conv_stride=1,
-                use_norm=self.hparams.model_param.use_norm)
 
-        self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
-                conv_padding=1, conv_kernel=5, conv_stride=1,
-                use_norm=self.hparams.model_param.use_norm)
 
-        self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
-                conv_padding=0, conv_kernel=3, conv_stride=1,
-                use_norm=self.hparams.model_param.use_norm)
 
-        self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
-                conv_padding=0, conv_kernel=3, conv_stride=1,
-                use_norm=self.hparams.model_param.use_norm)
 
     def forward(self, batch_x):
-        #
-        # Encode
-        if 'vae' in self.hparams.data_param.mode:
-            z, mu, logvar = self.encode(batch_x)
-        else:
-            z = self.encode(batch_x)
-            mu, logvar = z, z
+        tensor = self.conv_0(batch_x)
+        tensor = self.conv_1(tensor)
+        tensor = self.conv_2(tensor)
+        tensor = self.conv_3(tensor)
 
-        #
-        # Generate
-        alt_tensor = self.generate(z)
-        return alt_tensor, z, mu, logvar
+        tensor = self.flat(tensor)
+        tensor = self.lin(tensor)
+        tensor = self.activation(tensor)
 
-    @staticmethod
-    def reparameterize(mu, logvar):
-        std = torch.exp(0.5 * logvar)
-        eps = torch.randn_like(std)
-        return mu + eps * std
 
-    def encode(self, batch_x):
-        combined_tensor = self.enc_conv_0(batch_x)
-        combined_tensor = self.enc_conv_1a(combined_tensor)
-        combined_tensor = self.enc_conv_2a(combined_tensor)
-        combined_tensor = self.enc_conv_3a(combined_tensor)
 
-        combined_tensor = self.enc_flat(combined_tensor)
-        combined_tensor = self.enc_lin_1(combined_tensor)
-        combined_tensor = self.activation(combined_tensor)
 
-        combined_tensor = self.enc_lin_2(combined_tensor)
-        combined_tensor = self.activation(combined_tensor)
 
         #
         # Variational
-        # Parameter and Sampling
-        if 'vae' in self.hparams.data_param.mode:
-            mu = self.mu(combined_tensor)
-            logvar = self.logvar(combined_tensor)
-            z = self.reparameterize(mu, logvar)
-            return z, mu, logvar
-        else:
-            #
-            # Linear Bottleneck
-            z = self.z(combined_tensor)
-            return z
+        # Parameter for Sampling
+        mu = self.mu(tensor)
+        logvar = self.logvar(tensor)
+        return mu, logvar
 
-    def generate(self, z):
-        alt_tensor = self.gen_lin_1(z)
-        alt_tensor = self.activation(alt_tensor)
 
-        alt_tensor = self.reshape_to_last_conv(alt_tensor)
-        alt_tensor = self.gen_deconv_1a(alt_tensor)
+class Decoder(nn.Module):
 
-        alt_tensor = self.gen_deconv_2a(alt_tensor)
+    def __init__(self, out_channels, last_conv_shape, hparams):
+        super(Decoder, self).__init__()
+        # Params
+        ###################################################
+        self.hparams = hparams
 
-        alt_tensor = self.gen_deconv_3a(alt_tensor)
+        # Additional Attributes
+        ###################################################
+        self.use_res_net = self.hparams.model_param.use_res_net
+        self.lat_dim = self.hparams.model_param.lat_dim
+        self.feature_dim = self.lat_dim
+        self.out_channels = out_channels
 
-        alt_tensor = self.gen_deconv_out(alt_tensor)
-        # alt_tensor = self.activation(alt_tensor)
-        # alt_tensor = self.sigmoid(alt_tensor)
-        return alt_tensor
+        # NN Nodes
+        ###################################################
+        #
+        # Utils
+        self.activation = self.hparams.activation()
 
+        #
+        # Alternative Generator
+        self.lin = nn.Linear(self.lat_dim, reduce(mul, last_conv_shape))
 
+        self.reshape = Flatten(in_shape=reduce(mul, last_conv_shape), to=last_conv_shape)
 
+        self.deconv_1 = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
+                conv_padding=0, conv_kernel=7, conv_stride=1,
+                use_norm=self.hparams.model_param.use_norm)
 
+        self.deconv_2 = DeConvModule(self.deconv_1.shape, self.hparams.model_param.filters[1],
+                conv_padding=1, conv_kernel=5, conv_stride=1,
+                use_norm=self.hparams.model_param.use_norm)
 
+        self.deconv_3 = DeConvModule(self.deconv_2.shape, self.hparams.model_param.filters[0],
+                conv_padding=0, conv_kernel=3, conv_stride=1,
+                use_norm=self.hparams.model_param.use_norm)
 
+        self.deconv_out = DeConvModule(self.deconv_3.shape, self.out_channels, activation=nn.Sigmoid,
+                conv_padding=0, conv_kernel=3, conv_stride=1,
+                use_norm=self.hparams.model_param.use_norm)
 
+    def forward(self, z):
+        tensor = self.lin(z)
+        tensor = self.activation(tensor)
+        tensor = self.reshape(tensor)
 
+        tensor = self.deconv_1(tensor)
+        tensor = self.deconv_2(tensor)
+        tensor = self.deconv_3(tensor)
+        reconstruction = self.deconv_out(tensor)
+        return reconstruction
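
Note: the new reparameterize computes std = 0.5 * torch.exp(logvar), which is half the variance rather than the standard deviation; the conventional form (used by the reparameterize method removed above, and consistent with the KL term in training_step) is std = torch.exp(0.5 * logvar). A two-line comparison, independent of the repository code:

    import torch

    logvar = torch.tensor([0.3])
    std_conventional = torch.exp(0.5 * logvar)  # sigma = exp(logvar / 2) ~= 1.162
    std_as_committed = 0.5 * torch.exp(logvar)  # 0.5 * sigma^2 ~= 0.675, not sigma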
@@ -17,9 +17,9 @@ class ConvModule(nn.Module):
         output = self(x)
         return output.shape[1:]
 
-    def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=False,
-                 dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
-                 conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
+    def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
+                 use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
+                 conv_class=nn.Conv2d, conv_stride=1, conv_padding=0):
         super(ConvModule, self).__init__()
 
         # Module Parameters
@@ -30,12 +30,14 @@ class ConvModule(nn.Module):
         # Convolution Parameters
         self.padding = conv_padding
         self.stride = conv_stride
+        self.conv_filters = conv_filters
+        self.conv_kernel = conv_kernel
 
         # Modules
         self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
         self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
-        self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
+        self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
                                padding=self.padding, stride=self.stride
                                )
 
@@ -57,22 +59,23 @@ class DeConvModule(nn.Module):
         output = self(x)
         return output.shape[1:]
 
-    def __init__(self, in_shape, conv_filters=3, conv_kernel=5, conv_stride=1, conv_padding=0,
-                 dropout: Union[int, float] = 0, autopad=False,
-                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=None,
+    def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
+                 dropout: Union[int, float] = 0, autopad=0,
+                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
                  use_bias=True, use_norm=False):
         super(DeConvModule, self).__init__()
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
         self.padding = conv_padding
+        self.conv_kernel = conv_kernel
         self.stride = conv_stride
         self.in_shape = in_shape
         self.conv_filters = conv_filters
 
         self.autopad = AutoPad() if autopad else lambda x: x
         self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
         self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
-        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
+        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
                                           padding=self.padding, stride=self.stride)
 
         self.activation = activation() if activation else lambda x: x
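
Note: ConvModule and DeConvModule now take conv_filters and conv_kernel as required arguments, and dropping affine=False means the BatchNorm2d layers gain learnable scale and shift parameters (PyTorch's default is affine=True). A minimal usage sketch, assuming the ConvModule constructor and shape property exactly as shown above and a (1, 28, 28) input map:

    # Sketch only: relies on lib.modules.blocks.ConvModule from this commit.
    from lib.modules.blocks import ConvModule

    conv = ConvModule((1, 28, 28), conv_filters=16, conv_kernel=3,
                      conv_stride=1, conv_padding=1, use_norm=False)
    print(conv.shape)  # per-sample output shape, computed via a dry forward pass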
@@ -6,8 +6,6 @@ from torch import nn
 from torch import functional as F
 from torch.utils.data import DataLoader
 
-from lib.objects.map import MapStorage
-
 import pytorch_lightning as pl
 
 
@@ -27,10 +25,11 @@ class Flatten(nn.Module):
             print(e)
             return -1
 
-    def __init__(self, in_shape, to=(-1, )):
+    def __init__(self, in_shape, to=-1):
+        assert isinstance(to, int) or isinstance(to, tuple)
         super(Flatten, self).__init__()
         self.in_shape = in_shape
-        self.to = to
+        self.to = (to,) if isinstance(to, int) else to
 
     def forward(self, x):
         return x.view(x.size(0), *self.to)
@@ -107,7 +106,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
         # Data loading
         # =============================================================================
         # Map Object
-        self.map_storage = MapStorage(self.hparams.data_param.map_root)
+        # self.map_storage = MapStorage(self.hparams.data_param.map_root)
 
     def size(self):
         return self.shape

main.py
@@ -47,19 +47,19 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
 # Transformations
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
-main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
+main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
 main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
 
 # Model
 main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
-main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
+main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
 main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
-main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
+main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
 main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
-main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
+main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
 