VAE Debugged and Running

Si11ium 2020-03-25 09:39:59 +01:00
parent defa232bf2
commit 934dadb558
5 changed files with 171 additions and 193 deletions

View File

@@ -1,5 +1,7 @@
 from torchvision.datasets import MNIST
+from torchvision.transforms import transforms
 import numpy as np
+import torch


 class MyMNIST(MNIST):
@@ -9,12 +11,12 @@ class MyMNIST(MNIST):
         return np.asarray(self.test_dataset[0][0]).shape

     def __init__(self, *args, **kwargs):
-        super(MyMNIST, self).__init__('res', train=False, download=True)
+        super(MyMNIST, self).__init__('res', train=False, download=True, transform=transforms.ToTensor())
         pass

     def __getitem__(self, item):
-        image = super(MyMNIST, self).__getitem__(item)
-        return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
+        image, label = super(MyMNIST, self).__getitem__(item)
+        return image, label

     @property
     def train_dataset(self):
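
With `transform=transforms.ToTensor()` set in the constructor, `MyMNIST.__getitem__` now returns float tensors of shape (1, 28, 28) scaled to [0, 1], so the manual numpy conversion is gone and the samples already match what the sum-reduced BCE reconstruction loss below expects. A minimal usage sketch; the batch size and shuffling are illustrative, not taken from the repo:

    from torch.utils.data import DataLoader
    from datasets.mnist import MyMNIST

    dataset = MyMNIST()                                         # downloads MNIST into 'res' on first use
    loader = DataLoader(dataset, batch_size=32, shuffle=True)   # illustrative settings
    images, labels = next(iter(loader))                         # images: (32, 1, 28, 28), values in [0, 1]
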

View File

@@ -1,17 +1,16 @@
 from functools import reduce
 from operator import mul
-from random import choices, choice
+from random import choice

 import torch
 from torch import nn
 from torch.optim import Adam
-from torchvision.datasets import MNIST

 from datasets.mnist import MyMNIST
 from datasets.trajectory_dataset import TrajData
-from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
+from lib.modules.blocks import ConvModule, DeConvModule
 from lib.modules.utils import LightningBaseModule, Flatten

 import matplotlib.pyplot as plt
@@ -20,55 +19,33 @@ from lib.visualization.generator_eval import GeneratorVisualizer

 class CNNRouteGeneratorModel(LightningBaseModule):
-    torch.autograd.set_detect_anomaly(True)
     name = 'CNNRouteGenerator'

     def configure_optimizers(self):
         return Adam(self.parameters(), lr=self.hparams.train_param.lr)

     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        batch_x, target = batch_xy
-        generated_alternative, z, mu, logvar = self(batch_x)
-        target = batch_x if 'ae' in self.hparams.data_param.mode else target
-        element_wise_loss = self.criterion(generated_alternative, target)
-        if 'vae' in self.hparams.data_param.mode:
-            # see Appendix B from VAE paper:
-            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
-            # https://arxiv.org/abs/1312.6114
-            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
-            kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
-            # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
-            # kld_loss /= reduce(mul, self.in_shape)
-            kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
-            loss = kld_loss + element_wise_loss
-        else:
-            loss = element_wise_loss
-            kld_loss = 0
-        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
+        batch_x, _ = batch_xy
+        reconstruction, z, mu, logvar = self(batch_x)
+        recon_loss = self.criterion(reconstruction, batch_x)
+        kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+        loss = recon_loss + kldivergence
+        return dict(loss=loss, log=dict(reconstruction_loss=recon_loss, loss=loss, kld_loss=kldivergence))

     def _test_val_step(self, batch_xy, batch_nb, *args):
         batch_x, _ = batch_xy
-        if 'vae' in self.hparams.data_param.mode:
-            z, mu, logvar = self.encode(batch_x)
-        else:
-            z = self.encode(batch_x)
-            mu, logvar = z, z
-        generated_alternative = self.generate(mu)
-        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative, z=z, mu=mu, logvar=logvar)
-
-        if 'hom' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.HOMOTOPIC)
-        elif 'alt' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.ALTERNATIVE)
-        elif 'vae' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.ANY)
-        elif 'ae' in self.hparams.data_param.mode:
-            labels = torch.full((batch_x.shape[0], 1), V.ANY)
-        else:
-            labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
+        mu, logvar = self.encoder(batch_x)
+        z = self.reparameterize(mu, logvar)
+        reconstruction = self.decoder(mu)
+        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=reconstruction, z=z, mu=mu, logvar=logvar)
+        labels = torch.full((batch_x.shape[0], 1), V.ANY)

         return_dict.update(labels=self._move_to_model_device(labels))
         return return_dict
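
For reference, the `kldivergence` term above is the closed-form KL divergence between the approximate posterior and a standard normal prior from Kingma and Welling, "Auto-Encoding Variational Bayes" (ICLR 2014, https://arxiv.org/abs/1312.6114), which the removed comment cited. With logvar = \log\sigma^2:

    D_{\mathrm{KL}}\!\left(q(z \mid x)\,\|\,\mathcal{N}(0, I)\right) = -\tfrac{1}{2} \sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)

Both this term and the `nn.BCELoss(reduction='sum')` reconstruction term are summed rather than averaged over the batch, which is why the old per-batch rescaling of the KL by dataset_length / batch_size is no longer applied.
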
@@ -87,12 +64,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         return dict(epoch=self.current_epoch)

-    def on_epoch_start(self):
-        # self.dataset.seed(self.logger.version)
-        # torch.random.manual_seed(self.logger.version)
-        # np.random.seed(self.logger.version)
-        pass
-
     def validation_step(self, *args):
         return self._test_val_step(*args)
@@ -113,160 +84,163 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         self.dataset = TrajData(self.hparams.data_param.map_root,
                                 mode=self.hparams.data_param.mode,
                                 preprocessed=self.hparams.data_param.use_preprocessed,
-                                length=self.hparams.data_param.dataset_length, normalized=True)
-        self.criterion = nn.MSELoss()
+                                length=self.hparams.data_param.dataset_length)
+        self.criterion = nn.BCELoss(reduction='sum')
         self.dataset = MyMNIST()

         # Additional Attributes
         ###################################################
         self.in_shape = self.dataset.map_shapes_max
         self.use_res_net = self.hparams.model_param.use_res_net
         self.lat_dim = self.hparams.model_param.lat_dim
         self.feature_dim = self.lat_dim
         self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]

-        ########################################################
-        # NN Nodes
-        ###################################################
-        #
-        # Utils
-        self.activation = nn.LeakyReLU()
-        self.sigmoid = nn.Sigmoid()
-        #
-        # Map Encoder
-        self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
-                                     conv_filters=self.hparams.model_param.filters[0],
-                                     use_norm=self.hparams.model_param.use_norm,
-                                     use_bias=self.hparams.model_param.use_bias)
-        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
-                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
-                                        use_norm=self.hparams.model_param.use_norm,
-                                        use_bias=self.hparams.model_param.use_bias)
-        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[1],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-        self.enc_res_2 = ResidualModule(self.enc_conv_1a.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
-                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
-                                        use_norm=self.hparams.model_param.use_norm,
-                                        use_bias=self.hparams.model_param.use_bias)
-        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[2],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-        self.enc_conv_3a = ConvModule(self.enc_conv_2a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[2],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-
-        last_conv_shape = self.enc_conv_3a.shape
-        self.enc_flat = Flatten(last_conv_shape)
-        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
-
-        #
-        # Mixed Encoder
-        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
-        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
-
-        #
-        # Variational Bottleneck
-        if 'vae' in self.hparams.data_param.mode:
-            self.mu = nn.Linear(self.feature_dim, self.lat_dim)
-            self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
-        #
-        # Linear Bottleneck
-        else:
-            self.z = nn.Linear(self.feature_dim, self.lat_dim)
-
-        #
-        # Alternative Generator
-        self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)
-        # self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
-        self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
-        self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
-                                          conv_padding=0, conv_kernel=7, conv_stride=1,
-                                          use_norm=self.hparams.model_param.use_norm)
-        self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
-                                          conv_padding=1, conv_kernel=5, conv_stride=1,
-                                          use_norm=self.hparams.model_param.use_norm)
-        self.gen_deconv_3a = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[0],
-                                          conv_padding=0, conv_kernel=3, conv_stride=1,
-                                          use_norm=self.hparams.model_param.use_norm)
-        self.gen_deconv_out = DeConvModule(self.gen_deconv_3a.shape, self.out_channels, activation=None,
-                                           conv_padding=0, conv_kernel=3, conv_stride=1,
-                                           use_norm=self.hparams.model_param.use_norm)
-
-    def forward(self, batch_x):
-        #
-        # Encode
-        if 'vae' in self.hparams.data_param.mode:
-            z, mu, logvar = self.encode(batch_x)
-        else:
-            z = self.encode(batch_x)
-            mu, logvar = z, z
-        #
-        # Generate
-        alt_tensor = self.generate(z)
-        return alt_tensor, z, mu, logvar
-
-    @staticmethod
-    def reparameterize(mu, logvar):
-        std = torch.exp(0.5 * logvar)
-        eps = torch.randn_like(std)
-        return mu + eps * std
-
-    def encode(self, batch_x):
-        combined_tensor = self.enc_conv_0(batch_x)
-        combined_tensor = self.enc_conv_1a(combined_tensor)
-        combined_tensor = self.enc_conv_2a(combined_tensor)
-        combined_tensor = self.enc_conv_3a(combined_tensor)
-
-        combined_tensor = self.enc_flat(combined_tensor)
-        combined_tensor = self.enc_lin_1(combined_tensor)
-        combined_tensor = self.activation(combined_tensor)
-        combined_tensor = self.enc_lin_2(combined_tensor)
-        combined_tensor = self.activation(combined_tensor)
-
-        #
-        # Variational
-        # Parameter and Sampling
-        if 'vae' in self.hparams.data_param.mode:
-            mu = self.mu(combined_tensor)
-            logvar = self.logvar(combined_tensor)
-            z = self.reparameterize(mu, logvar)
-            return z, mu, logvar
-        #
-        # Linear Bottleneck
-        else:
-            z = self.z(combined_tensor)
-            return z
-
-    def generate(self, z):
-        alt_tensor = self.gen_lin_1(z)
-        alt_tensor = self.activation(alt_tensor)
-        alt_tensor = self.reshape_to_last_conv(alt_tensor)
-        alt_tensor = self.gen_deconv_1a(alt_tensor)
-        alt_tensor = self.gen_deconv_2a(alt_tensor)
-        alt_tensor = self.gen_deconv_3a(alt_tensor)
-        alt_tensor = self.gen_deconv_out(alt_tensor)
-        # alt_tensor = self.activation(alt_tensor)
-        # alt_tensor = self.sigmoid(alt_tensor)
-        return alt_tensor
+        # NN Nodes
+        ###################################################
+        self.encoder = Encoder(self.in_shape, self.hparams)
+        self.decoder = Decoder(self.out_channels, self.encoder.last_conv_shape, self.hparams)
+
+    def forward(self, batch_x):
+        # Encode
+        mu, logvar = self.encoder(batch_x)
+        # Bottleneck
+        z = self.reparameterize(mu, logvar)
+        # Decode
+        reconstruction = self.decoder(z)
+        return reconstruction, z, mu, logvar
+
+    @staticmethod
+    def reparameterize(mu, logvar):
+        std = 0.5 * torch.exp(logvar)
+        eps = torch.randn_like(mu)
+        z = mu + std * eps
+        return z
+
+
+class Encoder(nn.Module):
+    def __init__(self, in_shape, hparams):
+        super(Encoder, self).__init__()
+        # Params
+        ###################################################
+        self.hparams = hparams
+
+        # Additional Attributes
+        ###################################################
+        self.in_shape = in_shape
+        self.use_res_net = self.hparams.model_param.use_res_net
+        self.lat_dim = self.hparams.model_param.lat_dim
+        self.feature_dim = self.lat_dim * 10
+
+        # NN Nodes
+        ###################################################
+        #
+        # Utils
+        self.activation = self.hparams.activation()
+
+        #
+        # Encoder
+        self.conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
+                                 conv_filters=self.hparams.model_param.filters[0],
+                                 use_norm=self.hparams.model_param.use_norm,
+                                 use_bias=self.hparams.model_param.use_bias)
+        self.conv_1 = ConvModule(self.conv_0.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+                                 conv_filters=self.hparams.model_param.filters[0],
+                                 use_norm=self.hparams.model_param.use_norm,
+                                 use_bias=self.hparams.model_param.use_bias)
+        self.conv_2 = ConvModule(self.conv_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
+                                 conv_filters=self.hparams.model_param.filters[1],
+                                 use_norm=self.hparams.model_param.use_norm,
+                                 use_bias=self.hparams.model_param.use_bias)
+        self.conv_3 = ConvModule(self.conv_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
+                                 conv_filters=self.hparams.model_param.filters[2],
+                                 use_norm=self.hparams.model_param.use_norm,
+                                 use_bias=self.hparams.model_param.use_bias)
+
+        self.last_conv_shape = self.conv_3.shape
+        self.flat = Flatten(in_shape=self.last_conv_shape)
+        self.lin = nn.Linear(self.flat.shape, self.feature_dim)
+
+        #
+        # Variational Bottleneck
+        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
+        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
+
+    def forward(self, batch_x):
+        tensor = self.conv_0(batch_x)
+        tensor = self.conv_1(tensor)
+        tensor = self.conv_2(tensor)
+        tensor = self.conv_3(tensor)
+
+        tensor = self.flat(tensor)
+        tensor = self.lin(tensor)
+        tensor = self.activation(tensor)
+
+        #
+        # Variational
+        # Parameter for Sampling
+        mu = self.mu(tensor)
+        logvar = self.logvar(tensor)
+        return mu, logvar
+
+
+class Decoder(nn.Module):
+    def __init__(self, out_channels, last_conv_shape, hparams):
+        super(Decoder, self).__init__()
+        # Params
+        ###################################################
+        self.hparams = hparams
+
+        # Additional Attributes
+        ###################################################
+        self.use_res_net = self.hparams.model_param.use_res_net
+        self.lat_dim = self.hparams.model_param.lat_dim
+        self.feature_dim = self.lat_dim
+        self.out_channels = out_channels
+
+        # NN Nodes
+        ###################################################
+        #
+        # Utils
+        self.activation = self.hparams.activation()
+
+        #
+        # Alternative Generator
+        self.lin = nn.Linear(self.lat_dim, reduce(mul, last_conv_shape))
+        self.reshape = Flatten(in_shape=reduce(mul, last_conv_shape), to=last_conv_shape)
+
+        self.deconv_1 = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
+                                     conv_padding=0, conv_kernel=7, conv_stride=1,
+                                     use_norm=self.hparams.model_param.use_norm)
+        self.deconv_2 = DeConvModule(self.deconv_1.shape, self.hparams.model_param.filters[1],
+                                     conv_padding=1, conv_kernel=5, conv_stride=1,
+                                     use_norm=self.hparams.model_param.use_norm)
+        self.deconv_3 = DeConvModule(self.deconv_2.shape, self.hparams.model_param.filters[0],
+                                     conv_padding=0, conv_kernel=3, conv_stride=1,
+                                     use_norm=self.hparams.model_param.use_norm)
+        self.deconv_out = DeConvModule(self.deconv_3.shape, self.out_channels, activation=nn.Sigmoid,
+                                       conv_padding=0, conv_kernel=3, conv_stride=1,
+                                       use_norm=self.hparams.model_param.use_norm)
+
+    def forward(self, z):
+        tensor = self.lin(z)
+        tensor = self.activation(tensor)
+        tensor = self.reshape(tensor)
+
+        tensor = self.deconv_1(tensor)
+        tensor = self.deconv_2(tensor)
+        tensor = self.deconv_3(tensor)
+        reconstruction = self.deconv_out(tensor)
+        return reconstruction
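
The model now delegates to separate `Encoder` and `Decoder` modules and keeps the sampling step in `reparameterize`. For comparison, the textbook reparameterization trick from the same paper draws z = mu + sigma * eps with sigma = exp(0.5 * logvar); a minimal stand-alone sketch (not the repo's helper) is:

    import torch

    def reparameterize_reference(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Reference VAE sampling: z = mu + sigma * eps, eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)   # logvar = log(sigma^2), so exp(0.5 * logvar) = sigma
        eps = torch.randn_like(std)     # standard normal noise, same shape/device as std
        return mu + eps * std

The committed helper instead computes `std = 0.5 * torch.exp(logvar)`, i.e. half the variance rather than the standard deviation; the pre-refactor version of this method used the `exp(0.5 * logvar)` form shown above.
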

View File

@@ -17,9 +17,9 @@ class ConvModule(nn.Module):
         output = self(x)
         return output.shape[1:]

-    def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=False,
-                 dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
-                 conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
+    def __init__(self, in_shape, conv_filters, conv_kernel, activation: nn.Module = nn.ELU, pooling_size=None,
+                 use_bias=True, use_norm=False, dropout: Union[int, float] = 0,
+                 conv_class=nn.Conv2d, conv_stride=1, conv_padding=0):
         super(ConvModule, self).__init__()

         # Module Parameters
@@ -30,12 +30,14 @@ class ConvModule(nn.Module):
         # Convolution Parameters
         self.padding = conv_padding
         self.stride = conv_stride
+        self.conv_filters = conv_filters
+        self.conv_kernel = conv_kernel

         # Modules
         self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
         self.pooling = nn.MaxPool2d(pooling_size) if pooling_size else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
-        self.conv = conv_class(in_channels, conv_filters, conv_kernel, bias=use_bias,
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
+        self.conv = conv_class(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
                                padding=self.padding, stride=self.stride
                                )
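
`conv_filters` and `conv_kernel` are now required arguments of `ConvModule` (the old defaults of 64 and 5 are gone) and are stored on the module before being handed to the convolution, so every call site must name them explicitly, as the new `Encoder` does. A minimal construction sketch; the input shape and filter count here are illustrative, not taken from the repo config:

    from lib.modules.blocks import ConvModule

    block = ConvModule((1, 28, 28), conv_filters=16, conv_kernel=3,
                       conv_stride=1, conv_padding=1,
                       use_norm=False, use_bias=True)
    out_shape = block.shape   # the shape helper runs a dummy forward pass and returns output.shape[1:]
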
@@ -57,22 +59,23 @@ class DeConvModule(nn.Module):
         output = self(x)
         return output.shape[1:]

-    def __init__(self, in_shape, conv_filters=3, conv_kernel=5, conv_stride=1, conv_padding=0,
-                 dropout: Union[int, float] = 0, autopad=False,
-                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=None,
+    def __init__(self, in_shape, conv_filters, conv_kernel, conv_stride=1, conv_padding=0,
+                 dropout: Union[int, float] = 0, autopad=0,
+                 activation: Union[None, nn.Module] = nn.ReLU, interpolation_scale=0,
                  use_bias=True, use_norm=False):
         super(DeConvModule, self).__init__()
         in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
         self.padding = conv_padding
+        self.conv_kernel = conv_kernel
         self.stride = conv_stride
         self.in_shape = in_shape
         self.conv_filters = conv_filters

         self.autopad = AutoPad() if autopad else lambda x: x
         self.interpolation = Interpolate(scale_factor=interpolation_scale) if interpolation_scale else lambda x: x
-        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04, affine=False) if use_norm else lambda x: x
+        self.norm = nn.BatchNorm2d(in_channels, eps=1e-04) if use_norm else lambda x: x
         self.dropout = nn.Dropout2d(dropout) if dropout else lambda x: x
-        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, conv_kernel, bias=use_bias,
+        self.de_conv = nn.ConvTranspose2d(in_channels, self.conv_filters, self.conv_kernel, bias=use_bias,
                                           padding=self.padding, stride=self.stride)

         self.activation = activation() if activation else lambda x: x
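
Because the activation is stored as `activation() if activation else lambda x: x`, callers pass the class itself and the module instantiates it, or pass None to get an identity. Passing `nn.Sigmoid` for the final `deconv_out` (as the new `Decoder` does) squashes the reconstruction into (0, 1), the range `nn.BCELoss` expects. A tiny sketch of that pattern, with an illustrative helper name:

    import torch
    from torch import nn

    def make_activation(activation=nn.ReLU):
        # pass a class such as nn.Sigmoid to get an instance, or None for identity
        return activation() if activation else (lambda x: x)

    assert isinstance(make_activation(nn.Sigmoid), nn.Sigmoid)
    assert make_activation(None)(torch.ones(1)).item() == 1.0
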

View File

@@ -6,8 +6,6 @@ from torch import nn
 from torch import functional as F
 from torch.utils.data import DataLoader

-from lib.objects.map import MapStorage
-
 import pytorch_lightning as pl
@@ -27,10 +25,11 @@ class Flatten(nn.Module):
             print(e)
             return -1

-    def __init__(self, in_shape, to=(-1, )):
+    def __init__(self, in_shape, to=-1):
+        assert isinstance(to, int) or isinstance(to, tuple)
         super(Flatten, self).__init__()
         self.in_shape = in_shape
-        self.to = to
+        self.to = (to,) if isinstance(to, int) else to

     def forward(self, x):
         return x.view(x.size(0), *self.to)
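
`Flatten` now accepts either an int or a tuple for `to`, so the same module can flatten activations (the default `to=-1`) or reshape a flat vector back into a convolutional shape, which is how the new `Decoder.reshape` undoes the encoder's flattening. A small sketch in both directions; the tensor sizes are illustrative:

    import torch
    from lib.modules.utils import Flatten

    flat = Flatten(in_shape=(4, 6, 6))                   # view as (batch, -1)
    unflat = Flatten(in_shape=4 * 6 * 6, to=(4, 6, 6))   # view as (batch, 4, 6, 6)

    x = torch.randn(8, 4, 6, 6)
    assert flat(x).shape == (8, 4 * 6 * 6)
    assert unflat(flat(x)).shape == (8, 4, 6, 6)
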
@@ -107,7 +106,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
         # Data loading
         # =============================================================================
         # Map Object
-        self.map_storage = MapStorage(self.hparams.data_param.map_root)
+        # self.map_storage = MapStorage(self.hparams.data_param.map_root)

     def size(self):
         return self.shape

main.py (12 changed lines)
View File

@@ -47,19 +47,19 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
 # Transformations
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
-main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
-main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=500, help="")
+main_arg_parser.add_argument("--train_batch_size", type=int, default=200, help="")
+main_arg_parser.add_argument("--train_lr", type=float, default=1e-3, help="")
 main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")

 # Model
 main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
-main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
+main_arg_parser.add_argument("--model_activation", type=str, default="leaky_relu", help="")
 main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
-main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
+main_arg_parser.add_argument("--model_lat_dim", type=int, default=16, help="")
 main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
-main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
+main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="")
 main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")