VAE Debugged and Running
parent defa232bf2
commit 934dadb558
@@ -1,5 +1,7 @@
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
import numpy as np
import torch


class MyMNIST(MNIST):
@@ -9,12 +11,12 @@ class MyMNIST(MNIST):
        return np.asarray(self.test_dataset[0][0]).shape

    def __init__(self, *args, **kwargs):
        super(MyMNIST, self).__init__('res', train=False, download=True)
        super(MyMNIST, self).__init__('res', train=False, download=True, transform=transforms.ToTensor())
        pass

    def __getitem__(self, item):
        image = super(MyMNIST, self).__getitem__(item)
        return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
        image, label = super(MyMNIST, self).__getitem__(item)
        return image, label

    @property
    def train_dataset(self):
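With the added ToTensor() transform, the wrapped dataset now yields a normalized image tensor and an integer label directly, instead of converting the PIL image by hand in `__getitem__`. A short usage sketch (the 'res' download root is the one hard-coded above):

```python
from datasets.mnist import MyMNIST

dataset = MyMNIST()
image, label = dataset[0]
# ToTensor() yields a float tensor of shape (1, 28, 28) with values in [0, 1];
# the label is a plain Python int.
print(image.shape, image.min().item(), image.max().item(), label)
```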
@@ -1,17 +1,16 @@
from functools import reduce
from operator import mul

from random import choices, choice
from random import choice

import torch

from torch import nn
from torch.optim import Adam
from torchvision.datasets import MNIST

from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.blocks import ConvModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten

import matplotlib.pyplot as plt
@@ -20,55 +19,33 @@ from lib.visualization.generator_eval import GeneratorVisualizer


class CNNRouteGeneratorModel(LightningBaseModule):

    torch.autograd.set_detect_anomaly(True)
    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, target = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        target = batch_x if 'ae' in self.hparams.data_param.mode else target
        element_wise_loss = self.criterion(generated_alternative, target)
        batch_x, _ = batch_xy
        reconstruction, z, mu, logvar = self(batch_x)

        if 'vae' in self.hparams.data_param.mode:
            # see Appendix B from VAE paper:
            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
            # https://arxiv.org/abs/1312.6114
            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
            # kld_loss /= reduce(mul, self.in_shape)
            kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
        recon_loss = self.criterion(reconstruction, batch_x)

            loss = kld_loss + element_wise_loss
        else:
            loss = element_wise_loss
            kld_loss = 0
        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
        kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        loss = recon_loss + kldivergence
        return dict(loss=loss, log=dict(reconstruction_loss=recon_loss, loss=loss, kld_loss=kldivergence))
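For reference, the training objective above is the standard negative ELBO: a reconstruction term plus the closed-form KL divergence between the approximate posterior N(mu, sigma^2) and a standard normal prior (Kingma & Welling, 2014). A minimal, self-contained sketch of that loss, independent of the model classes in this commit:

```python
import torch
import torch.nn.functional as F

def vae_loss(reconstruction, target, mu, logvar):
    # Reconstruction term: summed binary cross-entropy, matching nn.BCELoss(reduction='sum').
    # Assumes `reconstruction` has already passed through a sigmoid and `target` lies in [0, 1].
    recon = F.binary_cross_entropy(reconstruction, target, reduction='sum')
    # KL(q(z|x) || N(0, I)) in closed form: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld, recon, kld
```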
    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, _ = batch_xy
        if 'vae' in self.hparams.data_param.mode:
            z, mu, logvar = self.encode(batch_x)
        else:
            z = self.encode(batch_x)
            mu, logvar = z, z

        generated_alternative = self.generate(mu)
        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative, z=z, mu=mu, logvar=logvar)
        mu, logvar = self.encoder(batch_x)
        z = self.reparameterize(mu, logvar)

        if 'hom' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.HOMOTOPIC)
        elif 'alt' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.ALTERNATIVE)
        elif 'vae' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.ANY)
        elif 'ae' in self.hparams.data_param.mode:
            labels = torch.full((batch_x.shape[0], 1), V.ANY)
        else:
            labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
        reconstruction = self.decoder(mu)
        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=reconstruction, z=z, mu=mu, logvar=logvar)

        labels = torch.full((batch_x.shape[0], 1), V.ANY)

        return_dict.update(labels=self._move_to_model_device(labels))
        return return_dict
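Note that the validation/test step above decodes the posterior mean (`self.decoder(mu)`) rather than a freshly sampled `z`, which gives deterministic reconstructions for logging; `z` is still returned for latent-space visualization. A hedged sketch of the same pattern outside the Lightning module:

```python
import torch

@torch.no_grad()
def eval_reconstruct(encoder, decoder, batch_x):
    # Deterministic evaluation: decode the mean of q(z|x) instead of a random sample.
    mu, logvar = encoder(batch_x)
    return decoder(mu), mu, logvar
```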
@@ -87,12 +64,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):

        return dict(epoch=self.current_epoch)

    def on_epoch_start(self):
        # self.dataset.seed(self.logger.version)
        # torch.random.manual_seed(self.logger.version)
        # np.random.seed(self.logger.version)
        pass

    def validation_step(self, *args):
        return self._test_val_step(*args)
@@ -113,160 +84,163 @@ class CNNRouteGeneratorModel(LightningBaseModule):
        self.dataset = TrajData(self.hparams.data_param.map_root,
            mode=self.hparams.data_param.mode,
            preprocessed=self.hparams.data_param.use_preprocessed,
            length=self.hparams.data_param.dataset_length, normalized=True)
        self.criterion = nn.MSELoss()
            length=self.hparams.data_param.dataset_length)
        self.criterion = nn.BCELoss(reduction='sum')

        self.dataset = MyMNIST()

        # Additional Attributes #
        #######################################################
        # Additional Attributes
        ###################################################
        self.in_shape = self.dataset.map_shapes_max
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]
        ########################################################

        # NN Nodes
        ###################################################
        self.encoder = Encoder(self.in_shape, self.hparams)
        self.decoder = Decoder(self.out_channels, self.encoder.last_conv_shape, self.hparams)
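Switching the criterion from nn.MSELoss() to nn.BCELoss(reduction='sum') only works if both the targets and the reconstructions lie in [0, 1]; the ToTensor() transform added to MyMNIST and the Sigmoid on the decoder's output layer (further down) provide exactly that. A minimal sanity check, assuming a `model` with the forward signature used in this commit:

```python
import torch
from torch import nn

criterion = nn.BCELoss(reduction='sum')

def check_batch(model, batch_x):
    # BCE expects inputs and targets in [0, 1]; fail fast if either is out of range.
    reconstruction, z, mu, logvar = model(batch_x)
    assert 0.0 <= batch_x.min() and batch_x.max() <= 1.0
    assert 0.0 <= reconstruction.min() and reconstruction.max() <= 1.0
    return criterion(reconstruction, batch_x)
```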
    def forward(self, batch_x):
        # Encode
        mu, logvar = self.encoder(batch_x)

        # Bottleneck
        z = self.reparameterize(mu, logvar)

        # Decode
        reconstruction = self.decoder(z)
        return reconstruction, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        std = 0.5 * torch.exp(logvar)
        eps = torch.randn_like(mu)
        z = mu + std * eps
        return z
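The reparameterization trick keeps sampling differentiable: instead of drawing z directly from N(mu, sigma^2), draw eps from N(0, I) and form z = mu + sigma * eps, so gradients flow through mu and logvar. Note that the two reparameterize variants appearing in this diff are not equivalent: `torch.exp(0.5 * logvar)` is the usual sigma = exp(0.5 * log(sigma^2)), whereas `0.5 * torch.exp(logvar)` works out to sigma^2 / 2. A minimal sketch of the standard form:

```python
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # sigma = exp(0.5 * log(sigma^2)); sampling eps keeps the graph differentiable w.r.t. mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
```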
class Encoder(nn.Module):
    def __init__(self, in_shape, hparams):
        super(Encoder, self).__init__()
        # Params
        ###################################################
        self.hparams = hparams

        # Additional Attributes
        ###################################################
        self.in_shape = in_shape
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim * 10

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()
        self.activation = self.hparams.activation()

        #
        # Map Encoder
        self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
            conv_filters=self.hparams.model_param.filters[0],
            use_norm=self.hparams.model_param.use_norm,
            use_bias=self.hparams.model_param.use_bias)
        # Encoder
        self.conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
            conv_filters=self.hparams.model_param.filters[0],
            use_norm=self.hparams.model_param.use_norm,
            use_bias=self.hparams.model_param.use_bias)

        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
            conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
            use_norm=self.hparams.model_param.use_norm,
            use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
            conv_filters=self.hparams.model_param.filters[1],
            use_norm=self.hparams.model_param.use_norm,
            use_bias=self.hparams.model_param.use_bias)
        self.conv_1 = ConvModule(self.conv_0.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
            conv_filters=self.hparams.model_param.filters[0],
            use_norm=self.hparams.model_param.use_norm,
            use_bias=self.hparams.model_param.use_bias)

        self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
            conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
            use_norm=self.hparams.model_param.use_norm,
            use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
            conv_filters=self.hparams.model_param.filters[2],
            use_norm=self.hparams.model_param.use_norm,
            use_bias=self.hparams.model_param.use_bias)
        self.conv_2 = ConvModule(self.conv_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
            conv_filters=self.hparams.model_param.filters[1],
            use_norm=self.hparams.model_param.use_norm,
            use_bias=self.hparams.model_param.use_bias)

        self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
            conv_filters=self.hparams.model_param.filters[2],
            use_norm=self.hparams.model_param.use_norm,
            use_bias=self.hparams.model_param.use_bias)
        self.conv_3 = ConvModule(self.conv_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
            conv_filters=self.hparams.model_param.filters[2],
            use_norm=self.hparams.model_param.use_norm,
            use_bias=self.hparams.model_param.use_bias)

        last_conv_shape = self.enc_conv_3a.shape
        self.enc_flat = Flatten(last_conv_shape)
        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)

        #
        # Mixed Encoder
        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
        self.last_conv_shape = self.conv_3.shape
        self.flat = Flatten(in_shape=self.last_conv_shape)
        self.lin = nn.Linear(self.flat.shape, self.feature_dim)

        #
        # Variational Bottleneck
        if 'vae' in self.hparams.data_param.mode:
            self.mu = nn.Linear(self.feature_dim, self.lat_dim)
            self.logvar = nn.Linear(self.feature_dim, self.lat_dim)

        #
        # Linear Bottleneck
        else:
            self.z = nn.Linear(self.feature_dim, self.lat_dim)

        #
        # Alternative Generator
        self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)

        # self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)

        self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)

        self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
            conv_padding=0, conv_kernel=7, conv_stride=1,
            use_norm=self.hparams.model_param.use_norm)

        self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
            conv_padding=1, conv_kernel=5, conv_stride=1,
            use_norm=self.hparams.model_param.use_norm)

        self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
            conv_padding=0, conv_kernel=3, conv_stride=1,
            use_norm=self.hparams.model_param.use_norm)

        self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
            conv_padding=0, conv_kernel=3, conv_stride=1,
            use_norm=self.hparams.model_param.use_norm)
        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
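The Encoder keeps a single feature vector of size `feature_dim = lat_dim * 10` and puts two separate linear heads on top of it, one for `mu` and one for `logvar`, each of width `lat_dim`. A stripped-down sketch of that head pattern, with made-up sizes:

```python
from torch import nn

lat_dim, feature_dim = 16, 160   # feature_dim = lat_dim * 10, as above

head_mu = nn.Linear(feature_dim, lat_dim)
head_logvar = nn.Linear(feature_dim, lat_dim)
# Given a feature vector h, the approximate posterior is N(head_mu(h), exp(head_logvar(h))).
```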
    def forward(self, batch_x):
        #
        # Encode
        if 'vae' in self.hparams.data_param.mode:
            z, mu, logvar = self.encode(batch_x)
        else:
            z = self.encode(batch_x)
            mu, logvar = z, z
        tensor = self.conv_0(batch_x)
        tensor = self.conv_1(tensor)
        tensor = self.conv_2(tensor)
        tensor = self.conv_3(tensor)

        #
        # Generate
        alt_tensor = self.generate(z)
        return alt_tensor, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, batch_x):
        combined_tensor = self.enc_conv_0(batch_x)
        combined_tensor = self.enc_conv_1a(combined_tensor)
        combined_tensor = self.enc_conv_2a(combined_tensor)
        combined_tensor = self.enc_conv_3a(combined_tensor)

        combined_tensor = self.enc_flat(combined_tensor)
        combined_tensor = self.enc_lin_1(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        combined_tensor = self.enc_lin_2(combined_tensor)
        combined_tensor = self.activation(combined_tensor)
        tensor = self.flat(tensor)
        tensor = self.lin(tensor)
        tensor = self.activation(tensor)

        #
        # Variational
        # Parameter and Sampling
        if 'vae' in self.hparams.data_param.mode:
            mu = self.mu(combined_tensor)
            logvar = self.logvar(combined_tensor)
            z = self.reparameterize(mu, logvar)
            return z, mu, logvar
        else:
            #
            # Linear Bottleneck
            z = self.z(combined_tensor)
            return z
        # Parameter for Sampling
        mu = self.mu(tensor)
        logvar = self.logvar(tensor)
        return mu, logvar

    def generate(self, z):
        alt_tensor = self.gen_lin_1(z)
        alt_tensor = self.activation(alt_tensor)

        alt_tensor = self.reshape_to_last_conv(alt_tensor)
        alt_tensor = self.gen_deconv_1a(alt_tensor)
class Decoder(nn.Module):

        alt_tensor = self.gen_deconv_2a(alt_tensor)
    def __init__(self, out_channels, last_conv_shape, hparams):
        super(Decoder, self).__init__()
        # Params
        ###################################################
        self.hparams = hparams

        alt_tensor = self.gen_deconv_3a(alt_tensor)
        # Additional Attributes
        ###################################################
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = out_channels

        alt_tensor = self.gen_deconv_out(alt_tensor)
        # alt_tensor = self.activation(alt_tensor)
        # alt_tensor = self.sigmoid(alt_tensor)
        return alt_tensor
        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = self.hparams.activation()

        #
        # Alternative Generator
        self.lin = nn.Linear(self.lat_dim, reduce(mul, last_conv_shape))

        self.reshape = Flatten(in_shape=reduce(mul, last_conv_shape), to=last_conv_shape)

        self.deconv_1 = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
            conv_padding=0, conv_kernel=7, conv_stride=1,
            use_norm=self.hparams.model_param.use_norm)

        self.deconv_2 = DeConvModule(self.deconv_1.shape, self.hparams.model_param.filters[1],
            conv_padding=1, conv_kernel=5, conv_stride=1,
            use_norm=self.hparams.model_param.use_norm)

        self.deconv_3 = DeConvModule(self.deconv_2.shape, self.hparams.model_param.filters[0],
            conv_padding=0, conv_kernel=3, conv_stride=1,
            use_norm=self.hparams.model_param.use_norm)

        self.deconv_out = DeConvModule(self.deconv_3.shape, self.out_channels, activation=nn.Sigmoid,
            conv_padding=0, conv_kernel=3, conv_stride=1,
            use_norm=self.hparams.model_param.use_norm)

    def forward(self, z):

        tensor = self.lin(z)
        tensor = self.activation(tensor)
        tensor = self.reshape(tensor)

        tensor = self.deconv_1(tensor)
        tensor = self.deconv_2(tensor)
        tensor = self.deconv_3(tensor)
        reconstruction = self.deconv_out(tensor)
        return reconstruction
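Because the Decoder ends in a Sigmoid DeConvModule, sampling new digits after training is just a matter of decoding latent vectors drawn from the prior. A hedged sketch, assuming a trained model exposing `decoder` and the `lat_dim` hyperparameter used above:

```python
import torch

@torch.no_grad()
def sample_images(decoder, lat_dim, n=16):
    # Draw latents from the standard normal prior and decode them into [0, 1] images.
    z = torch.randn(n, lat_dim)
    return decoder(z)
```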
@@ -17,9 +17,9 @@ class ConvModule(nn.Module):
        output = self(x)
        return output.shape[1:]

    def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=False,
                 dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
                 conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
    def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
                 use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
                 conv_class=nn.Conv2d, conv_stride=1, conv_padding=0):
        super(ConvModule, self).__init__()

        # Module Parameters
@@ -30,12 +30,14 @@ class ConvModule(nn.Module):
        # Convolution Parameters
        self.padding = conv_padding
        self.stride = conv_stride
        self.conv_filters = conv_filters
        self.conv_kernel = conv_kernel

        # Modules
        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
        self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
        self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
        self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
                               padding=self.padding, stride=self.stride
                               )
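The `shape` property shown at the top of these modules runs a dummy forward pass and reports `output.shape[1:]`, which is what lets the Encoder and Decoder chain modules via `self.conv_0.shape` into the next layer's `in_shape`. A minimal sketch of the same idea, with a hypothetical `ShapeMixin` name and assuming the class also subclasses nn.Module and stores `in_shape`:

```python
import torch

class ShapeMixin:
    @property
    def shape(self):
        # Infer the output shape by pushing a single dummy batch through the module.
        with torch.no_grad():
            x = torch.randn(1, *self.in_shape)
            return self(x).shape[1:]
```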
@@ -57,22 +59,23 @@ class DeConvModule(nn.Module):
        output = self(x)
        return output.shape[1:]

    def __init__(self, in_shape, conv_filters=3, conv_kernel=5, conv_stride=1, conv_padding=0,
                 dropout: Union[int, float] = 0, autopad=False,
                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=None,
    def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
                 dropout: Union[int, float] = 0, autopad=0,
                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
                 use_bias=True, use_norm=False):
        super(DeConvModule, self).__init__()
        in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
        self.padding = conv_padding
        self.conv_kernel = conv_kernel
        self.stride = conv_stride
        self.in_shape = in_shape
        self.conv_filters = conv_filters

        self.autopad = AutoPad() if autopad else lambda x: x
        self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
        self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
                                          padding=self.padding, stride=self.stride)

        self.activation = activation() if activation else lambda x: x
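The `lambda x: x` fallbacks above keep the forward path uniform when a sub-module is disabled; an equivalent, slightly more idiomatic choice (and one that plays better with nn.Module introspection and scripting) would be nn.Identity(). A sketch of the pattern:

```python
from torch import nn

use_norm = False
# nn.Identity() behaves like `lambda x: x` but is a registered nn.Module.
norm = nn.BatchNorm2d(16, eps=1e-04) if use_norm else nn.Identity()
```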
@@ -6,8 +6,6 @@ from torch import nn
from torch import functional as F
from torch.utils.data import DataLoader

from lib.objects.map import MapStorage

import pytorch_lightning as pl
@@ -27,10 +25,11 @@ class Flatten(nn.Module):
            print(e)
            return -1

    def __init__(self, in_shape, to=(-1, )):
    def __init__(self, in_shape, to=-1):
        assert isinstance(to, int) or isinstance(to, tuple)
        super(Flatten, self).__init__()
        self.in_shape = in_shape
        self.to = to
        self.to = (to,) if isinstance(to, int) else to

    def forward(self, x):
        return x.view(x.size(0), *self.to)
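With the updated constructor, Flatten covers both directions: the default `to=-1` collapses feature maps into a vector, while passing a shape tuple (as the Decoder's `reshape` does) restores the convolutional layout. A short usage sketch, assuming the project layout above:

```python
import torch
from lib.modules.utils import Flatten

features = torch.randn(8, 64, 4, 4)             # (batch, C, H, W)
flat = Flatten(in_shape=(64, 4, 4))             # to=-1 -> collapse to a vector
vec = flat(features)                            # shape: (8, 1024)
unflat = Flatten(in_shape=1024, to=(64, 4, 4))  # reshape back to conv layout
restored = unflat(vec)                          # shape: (8, 64, 4, 4)
```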
@@ -107,7 +106,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
        # Data loading
        # =============================================================================
        # Map Object
        self.map_storage = MapStorage(self.hparams.data_param.map_root)
        # self.map_storage = MapStorage(self.hparams.data_param.map_root)

    def size(self):
        return self.shape
main.py
@@ -47,19 +47,19 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
# Transformations
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")

# Model
main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
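The encoder and decoder call `self.hparams.activation()`, so the string passed via `--model_activation` presumably gets resolved to an nn.Module class somewhere in the hparams handling, which is not shown in this diff. A hypothetical lookup, assuming the names map one-to-one onto torch activations:

```python
from torch import nn

# Hypothetical mapping from the --model_activation string to an activation class.
ACTIVATIONS = {'elu': nn.ELU, 'relu': nn.ReLU, 'leaky_relu': nn.LeakyReLU}

activation_cls = ACTIVATIONS['leaky_relu']   # the new default in this commit
activation = activation_cls()                # what self.hparams.activation() would return
```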