restructured

parent 1f4edae95c · commit 7b795c2f7b
@@ -84,7 +84,7 @@ class TrajData(object):
     def map_shapes_max(self):
         shapes = self.map_shapes
         shape_list = list(map(max, zip(*shapes)))
-        if self.mode == 'all_in_map':
+        if self.mode in ['separated_arrays', 'all_in_map']:
             shape_list[0] += 2
         return shape_list
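
Note on the unchanged helper above: list(map(max, zip(*shapes))) transposes the collection of per-map shapes and takes the maximum along each dimension, i.e. the smallest shape that fits every map in the dataset. A minimal standalone sketch with made-up shapes:

    # Hypothetical (channels, height, width) shapes for three maps
    shapes = [(1, 21, 42), (1, 30, 20), (1, 10, 64)]
    shape_list = list(map(max, zip(*shapes)))  # zip(*shapes) -> (1, 1, 1), (21, 30, 10), (42, 20, 64)
    print(shape_list)                          # [1, 30, 64]
    shape_list[0] += 2                         # extra channels, as for the two modes above
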
@@ -1,5 +1,3 @@
 from statistics import mean
-
-from random import choice

 import torch
@@ -36,6 +34,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
         # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
         # kld_loss /= reduce(mul, self.in_shape)
+        kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size

         loss = (kld_loss + element_wise_loss) / 2
         return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
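
The added line is the usual minibatch weighting for a summed KL term: kld_loss is summed over the batch, so multiplying by dataset_length / batch_size scales it up to an estimate over the whole dataset, matching the scale of the reconstruction term. A self-contained sketch of the combination (function and argument names are illustrative, not from this repository):

    import torch

    def combined_loss(element_wise_loss, mu, logvar, dataset_length, batch_size):
        # KL divergence between N(mu, sigma^2) and the standard normal prior, summed over the batch
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Minibatch weighting: rescale the batch KL to dataset scale
        kld_loss = kld_loss * dataset_length / batch_size
        return (kld_loss + element_wise_loss) / 2
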
@@ -83,7 +82,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         self.in_shape = self.dataset.map_shapes_max
         # Todo: Better naming and size in Parameters
         self.feature_dim = self.hparams.model_param.lat_dim * 10
-        self.feature_mixed_dim = self.feature_dim + self.feature_dim + 1

         # NN Nodes
         ###################################################
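
feature_mixed_dim (map features + trajectory features + one label scalar) can go because, as the encode() hunk further down shows, map, trajectory, and label are now concatenated channel-wise and pushed through a single encoder stack, leaving only one feature_dim-sized vector.
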
@@ -99,81 +97,64 @@ class CNNRouteGeneratorModel(LightningBaseModule):
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

-        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[0],
+        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
+                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
                                         use_norm=self.hparams.model_param.use_norm,
                                         use_bias=self.hparams.model_param.use_bias)
-        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[1],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

-        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[1],
+        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
+                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
                                         use_norm=self.hparams.model_param.use_norm,
                                         use_bias=self.hparams.model_param.use_bias)
-        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

-        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[2],
+        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
+                                        conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
                                         use_norm=self.hparams.model_param.use_norm,
                                         use_bias=self.hparams.model_param.use_bias)
-        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
-                                     conv_filters=self.hparams.model_param.filters[2]*2,
+        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=11, conv_stride=1, conv_padding=0,
+                                     conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

         self.map_flat = Flatten(self.map_conv_3.shape)
         self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)

-        #
-        # Trajectory Encoder
-        self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-
-        self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-
-        self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-
-        self.traj_flat = Flatten(self.traj_conv_3.shape)
-        self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
-
         #
         # Mixed Encoder
-        self.mixed_lin = nn.Linear(self.feature_mixed_dim, self.feature_mixed_dim)
-        self.mixed_norm = nn.BatchNorm1d(self.feature_mixed_dim) if self.hparams.model_param.use_norm else lambda x: x
+        self.mixed_lin = nn.Linear(self.feature_dim, self.feature_dim)
+        self.mixed_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x

         #
         # Variational Bottleneck
-        self.mu = nn.Linear(self.feature_mixed_dim, self.hparams.model_param.lat_dim)
-        self.logvar = nn.Linear(self.feature_mixed_dim, self.hparams.model_param.lat_dim)
+        self.mu = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)
+        self.logvar = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)

         #
         # Alternative Generator
         self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
-        self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, self.traj_conv_3.shape))
-
-        self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)
-
-        self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
-                                         conv_padding=0, conv_kernel=5, conv_stride=1,
+        # Todo Fix This Hack!!!!
+        reshape_shape = (1, self.map_conv_3.shape[1], self.map_conv_3.shape[2])
+
+        self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, reshape_shape))
+        self.reshape_to_map = Flatten(reduce(mul, reshape_shape), reshape_shape)
+
+        self.alt_deconv_1 = DeConvModule(reshape_shape, self.hparams.model_param.filters[2],
+                                         conv_padding=0, conv_kernel=9, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)
         self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
-                                         conv_padding=0, conv_kernel=3, conv_stride=1,
+                                         conv_padding=0, conv_kernel=5, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)
         self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
-                                         conv_padding=1, conv_kernel=3, conv_stride=1,
+                                         conv_padding=1, conv_kernel=5, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)
         self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
                                            conv_padding=1, conv_kernel=3, conv_stride=1,
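
A pattern worth noting in the new kernel sizes: inside the ResidualModule branches the padding always pairs with the kernel as (kernel - 1) // 2 (5 -> 2, 7 -> 3), which at stride 1 preserves spatial size so the skip connection can be added, while the plain ConvModules keep padding 0 and shrink the feature map. A quick sanity check in plain PyTorch (assuming ResidualModule adds its input to the branch output; the layers below are not from this repository):

    import torch
    import torch.nn as nn

    x = torch.randn(1, 4, 30, 64)
    same = nn.Conv2d(4, 4, kernel_size=5, stride=1, padding=2)    # (5 - 1) // 2 = 2
    shrink = nn.Conv2d(4, 4, kernel_size=5, stride=1, padding=0)
    print(same(x).shape)    # torch.Size([1, 4, 30, 64]) -> residual add stays valid
    print(shrink(x).shape)  # torch.Size([1, 4, 26, 60]) -> spatial size reduced by kernel - 1
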
@@ -214,34 +195,33 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         return alt_tensor

     def encode(self, map_array, trajectory, label):
-        map_tensor = self.map_conv_0(map_array)
-        map_tensor = self.map_res_1(map_tensor)
-        map_tensor = self.map_conv_1(map_tensor)
-        map_tensor = self.map_res_2(map_tensor)
-        map_tensor = self.map_conv_2(map_tensor)
-        map_tensor = self.map_res_3(map_tensor)
-        map_tensor = self.map_conv_3(map_tensor)
-        map_tensor = self.map_flat(map_tensor)
-        map_tensor = self.map_lin(map_tensor)
-
-        traj_tensor = self.traj_conv_1(trajectory)
-        traj_tensor = self.traj_conv_2(traj_tensor)
-        traj_tensor = self.traj_conv_3(traj_tensor)
-        traj_tensor = self.traj_flat(traj_tensor)
-        traj_tensor = self.traj_lin(traj_tensor)
-
-        mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
-        mixed_tensor = self.mixed_norm(mixed_tensor)
-        mixed_tensor = self.activation(mixed_tensor)
-        mixed_tensor = self.mixed_lin(mixed_tensor)
-        mixed_tensor = self.mixed_norm(mixed_tensor)
-        mixed_tensor = self.activation(mixed_tensor)
+        label_array = torch.cat([torch.full((1, 1, self.in_shape[1], self.in_shape[2]), x.item())
+                                 for x in label], dim=0)
+        label_array = self._move_to_model_device(label_array)
+        combined_tensor = torch.cat((map_array, trajectory, label_array), dim=1)
+        combined_tensor = self.map_conv_0(combined_tensor)
+        combined_tensor = self.map_res_1(combined_tensor)
+        combined_tensor = self.map_conv_1(combined_tensor)
+        combined_tensor = self.map_res_2(combined_tensor)
+        combined_tensor = self.map_conv_2(combined_tensor)
+        combined_tensor = self.map_res_3(combined_tensor)
+        combined_tensor = self.map_conv_3(combined_tensor)
+
+        combined_tensor = self.map_flat(combined_tensor)
+        combined_tensor = self.map_lin(combined_tensor)
+
+        combined_tensor = self.mixed_lin(combined_tensor)
+        combined_tensor = self.mixed_norm(combined_tensor)
+        combined_tensor = self.activation(combined_tensor)
+        combined_tensor = self.mixed_lin(combined_tensor)
+        combined_tensor = self.mixed_norm(combined_tensor)
+        combined_tensor = self.activation(combined_tensor)

         #
         # Parameter and Sampling
-        mu = self.mu(mixed_tensor)
-        logvar = self.logvar(mixed_tensor)
-        # logvar = torch.clamp(logvar, min=0, max=10)
+        mu = self.mu(combined_tensor)
+        logvar = self.logvar(combined_tensor)
         z = self.reparameterize(mu, logvar)
         return z, mu, logvar
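
Two remarks on the new encode(). First, the label now enters as torch.full((1, 1, H, W), x.item()) constant planes concatenated along the channel dimension, instead of a scalar appended to the feature vector. Second, reparameterize itself is not part of this diff; its call signature matches the standard VAE reparameterization trick, for which a common implementation (an assumption about what the base class provides) looks like:

    import torch

    def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        # z = mu + sigma * eps keeps sampling differentiable w.r.t. mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
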
main.py
@@ -21,7 +21,7 @@ warnings.filterwarnings('ignore', category=UserWarning)

 _ROOT = Path(__file__).parent

-# Paramter Configuration
+# Parameter Configuration
 # =============================================================================
 # Argument Parser
 main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
@@ -52,7 +52,7 @@ main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
 # Model
 main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
-main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
+main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
 main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
 main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
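
--model_filters is declared as a string, so the new default "[16, 32, 64]" still has to be converted into a list before it reaches model_param.filters. The conversion is not shown in this diff; one safe, standard way is ast.literal_eval:

    from ast import literal_eval

    filters = literal_eval("[16, 32, 64]")  # parses the string into a real Python list
    assert filters == [16, 32, 64]
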