diff --git a/datasets/trajectory_dataset.py b/datasets/trajectory_dataset.py
index 19ef289..ff567b6 100644
--- a/datasets/trajectory_dataset.py
+++ b/datasets/trajectory_dataset.py
@@ -84,7 +84,7 @@ class TrajData(object):
     def map_shapes_max(self):
         shapes = self.map_shapes
         shape_list = list(map(max, zip(*shapes)))
-        if self.mode == 'all_in_map':
+        if self.mode in ['separated_arrays', 'all_in_map']:
             shape_list[0] += 2
         return shape_list
 
diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py
index 82cbac8..3383029 100644
--- a/lib/models/generators/cnn.py
+++ b/lib/models/generators/cnn.py
@@ -1,5 +1,3 @@
-from statistics import mean
-
 from random import choice
 
 import torch
@@ -36,6 +34,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
         # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
         # kld_loss /= reduce(mul, self.in_shape)
+        kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
 
         loss = (kld_loss + element_wise_loss) / 2
         return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
@@ -83,7 +82,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         self.in_shape = self.dataset.map_shapes_max
         # Todo: Better naming and size in Parameters
         self.feature_dim = self.hparams.model_param.lat_dim * 10
-        self.feature_mixed_dim = self.feature_dim + self.feature_dim + 1
 
         # NN Nodes
         ###################################################
@@ -99,81 +97,64 @@ class CNNRouteGeneratorModel(LightningBaseModule):
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)
 
-        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[0],
+        self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
+                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
                                         use_norm=self.hparams.model_param.use_norm,
                                         use_bias=self.hparams.model_param.use_bias)
 
-        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+        self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[1],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)
 
-        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[1],
+        self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
+                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
                                         use_norm=self.hparams.model_param.use_norm,
                                         use_bias=self.hparams.model_param.use_bias)
 
-        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
+        self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)
 
-        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[2],
+        self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
+                                        conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
                                         use_norm=self.hparams.model_param.use_norm,
                                         use_bias=self.hparams.model_param.use_bias)
 
-        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
-                                     conv_filters=self.hparams.model_param.filters[2]*2,
+        self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=11, conv_stride=1, conv_padding=0,
+                                     conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)
 
         self.map_flat = Flatten(self.map_conv_3.shape)
         self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)
 
-        #
-        # Trajectory Encoder
-        self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-
-        self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-
-        self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-
-        self.traj_flat = Flatten(self.traj_conv_3.shape)
-        self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)
-
         #
         # Mixed Encoder
-        self.mixed_lin = nn.Linear(self.feature_mixed_dim, self.feature_mixed_dim)
-        self.mixed_norm = nn.BatchNorm1d(self.feature_mixed_dim) if self.hparams.model_param.use_norm else lambda x: x
+        self.mixed_lin = nn.Linear(self.feature_dim, self.feature_dim)
+        self.mixed_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
 
         #
         # Variational Bottleneck
-        self.mu = nn.Linear(self.feature_mixed_dim, self.hparams.model_param.lat_dim)
-        self.logvar = nn.Linear(self.feature_mixed_dim, self.hparams.model_param.lat_dim)
+        self.mu = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)
+        self.logvar = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)
 
         #
         # Alternative Generator
         self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
-        self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, self.traj_conv_3.shape))
+        # Todo Fix This Hack!!!!
+        reshape_shape = (1, self.map_conv_3.shape[1], self.map_conv_3.shape[2])
 
-        self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)
+        self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, reshape_shape))
 
-        self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
-                                         conv_padding=0, conv_kernel=5, conv_stride=1,
+        self.reshape_to_map = Flatten(reduce(mul, reshape_shape), reshape_shape)
+
+        self.alt_deconv_1 = DeConvModule(reshape_shape, self.hparams.model_param.filters[2],
+                                         conv_padding=0, conv_kernel=9, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)
 
         self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
-                                         conv_padding=0, conv_kernel=3, conv_stride=1,
+                                         conv_padding=0, conv_kernel=5, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)
 
         self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
-                                         conv_padding=1, conv_kernel=3, conv_stride=1,
+                                         conv_padding=1, conv_kernel=5, conv_stride=1,
                                          use_norm=self.hparams.model_param.use_norm)
 
         self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
                                            conv_padding=1, conv_kernel=3, conv_stride=1,
@@ -214,34 +195,33 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         return alt_tensor
 
     def encode(self, map_array, trajectory, label):
-        map_tensor = self.map_conv_0(map_array)
-        map_tensor = self.map_res_1(map_tensor)
-        map_tensor = self.map_conv_1(map_tensor)
-        map_tensor = self.map_res_2(map_tensor)
-        map_tensor = self.map_conv_2(map_tensor)
-        map_tensor = self.map_res_3(map_tensor)
-        map_tensor = self.map_conv_3(map_tensor)
-        map_tensor = self.map_flat(map_tensor)
-        map_tensor = self.map_lin(map_tensor)
+        label_array = torch.cat([torch.full((1, 1, self.in_shape[1], self.in_shape[2]), x.item())
+                                 for x in label], dim=0)
+        label_array = self._move_to_model_device(label_array)
+        combined_tensor = torch.cat((map_array, trajectory, label_array), dim=1)
+        combined_tensor = self.map_conv_0(combined_tensor)
+        combined_tensor = self.map_res_1(combined_tensor)
+        combined_tensor = self.map_conv_1(combined_tensor)
+        combined_tensor = self.map_res_2(combined_tensor)
+        combined_tensor = self.map_conv_2(combined_tensor)
+        combined_tensor = self.map_res_3(combined_tensor)
+        combined_tensor = self.map_conv_3(combined_tensor)
 
-        traj_tensor = self.traj_conv_1(trajectory)
-        traj_tensor = self.traj_conv_2(traj_tensor)
-        traj_tensor = self.traj_conv_3(traj_tensor)
-        traj_tensor = self.traj_flat(traj_tensor)
-        traj_tensor = self.traj_lin(traj_tensor)
+        combined_tensor = self.map_flat(combined_tensor)
+        combined_tensor = self.map_lin(combined_tensor)
 
-        mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
-        mixed_tensor = self.mixed_norm(mixed_tensor)
-        mixed_tensor = self.activation(mixed_tensor)
-        mixed_tensor = self.mixed_lin(mixed_tensor)
-        mixed_tensor = self.mixed_norm(mixed_tensor)
-        mixed_tensor = self.activation(mixed_tensor)
+        combined_tensor = self.mixed_lin(combined_tensor)
+
+        combined_tensor = self.mixed_norm(combined_tensor)
+        combined_tensor = self.activation(combined_tensor)
+        combined_tensor = self.mixed_lin(combined_tensor)
+        combined_tensor = self.mixed_norm(combined_tensor)
+        combined_tensor = self.activation(combined_tensor)
 
         #
         # Parameter and Sampling
-        mu = self.mu(mixed_tensor)
-        logvar = self.logvar(mixed_tensor)
-        # logvar = torch.clamp(logvar, min=0, max=10)
+        mu = self.mu(combined_tensor)
+        logvar = self.logvar(combined_tensor)
 
         z = self.reparameterize(mu, logvar)
         return z, mu, logvar
diff --git a/main.py b/main.py
index c20bc26..7e4e522 100644
--- a/main.py
+++ b/main.py
@@ -21,7 +21,7 @@ warnings.filterwarnings('ignore', category=UserWarning)
 
 _ROOT = Path(__file__).parent
 
-# Paramter Configuration
+# Parameter Configuration
 # =============================================================================
 # Argument Parser
 main_arg_parser = ArgumentParser(description="parser for fast-neural-style")
@@ -52,7 +52,7 @@ main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
 # Model
 main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
-main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
+main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
 main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
 main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
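
Reviewer notes on the two behavioural changes in this patch. Both sketches below are illustrative only and not part of the diff; the standalone names and values (dataset_length, batch_size, the 24x24 shapes) are hypothetical stand-ins for the hparams fields and dataset shapes the real code reads.

1) KLD weighting in training_step: multiplying the per-batch KL sum by dataset_length / batch_size is the usual minibatch estimator of the full-dataset ELBO; a batch of M samples drawn from N examples estimates the dataset-wide sum up to a factor of N / M. A minimal sketch:

    import torch
    import torch.nn.functional as F

    def vae_loss(recon, target, mu, logvar, dataset_length=10_000, batch_size=32):
        # Stand-in reconstruction term; the real element_wise_loss is computed
        # elsewhere in the model and is not shown in this diff.
        element_wise_loss = F.mse_loss(recon, target)
        # Analytic KL between N(mu, sigma^2) and N(0, 1), summed over the batch
        # (same formula as in training_step above).
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Rescale the batch KL sum towards full-dataset scale (N / M), mirroring
        # the `kld_loss *= ...` line introduced in this diff.
        kld_loss *= dataset_length / batch_size
        return (kld_loss + element_wise_loss) / 2

2) Label conditioning in encode(): instead of appending the scalar label to the flattened feature vector (the removed feature_mixed_dim = feature_dim + feature_dim + 1 path with its separate trajectory encoder), each label is broadcast into a constant plane and concatenated as an extra input channel next to the map and the trajectory. This is consistent with the map_shapes_max change, which now reports the two extra channels in 'separated_arrays' mode as well:

    # Hypothetical shapes: batch of 2, single-channel 24x24 map and trajectory.
    map_array = torch.rand(2, 1, 24, 24)
    trajectory = torch.rand(2, 1, 24, 24)
    label = torch.tensor([0., 1.])
    label_array = torch.cat([torch.full((1, 1, 24, 24), x.item()) for x in label], dim=0)
    combined = torch.cat((map_array, trajectory, label_array), dim=1)
    print(combined.shape)  # torch.Size([2, 3, 24, 24])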