About to lose my mind
@@ -27,20 +27,21 @@ class CNNRouteGeneratorModel(LightningBaseModule):
     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
         batch_x, alternative = batch_xy
         generated_alternative, z, mu, logvar = self(batch_x)
-        mse_loss = self.criterion(generated_alternative, alternative)
+        element_wise_loss = self.criterion(generated_alternative, alternative)
         # see Appendix B from VAE paper:
         # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
         # https://arxiv.org/abs/1312.6114
         # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

         kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
         # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
-        kld_loss /= reduce(mul, self.in_shape)
+        # kld_loss /= reduce(mul, self.in_shape)
+
-        loss = (kld_loss + mse_loss) / 2
-        return dict(loss=loss, log=dict(loss=loss, mse_loss=mse_loss, kld_loss=kld_loss))
+        loss = (kld_loss + element_wise_loss) / 2
+        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))

     def _test_val_step(self, batch_xy, batch_nb, *args):
-        batch_x, alternative = batch_xy
+        batch_x, _ = batch_xy
         map_array, trajectory, label = batch_x

         generated_alternative, z, mu, logvar = self(batch_x)
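
For reference, the loss assembled in training_step above is the usual VAE objective: an element-wise reconstruction term plus the KL divergence of N(mu, sigma^2) from N(0, I). A minimal, self-contained sketch in plain PyTorch (function and argument names are invented for illustration; the model itself uses self.criterion and self.in_shape, and this commit disables the rescaling step):

import torch
from functools import reduce
from operator import mul

def vae_loss(generated, target, mu, logvar, in_shape=None):
    # element-wise reconstruction error (MSE, matching nn.MSELoss above)
    element_wise_loss = torch.nn.functional.mse_loss(generated, target)
    # KLD(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    if in_shape is not None:
        # optional rescaling by the number of input elements, the step the TODO above questions
        kld_loss = kld_loss / reduce(mul, in_shape)
    return (kld_loss + element_wise_loss) / 2
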
@@ -48,18 +49,12 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1)

     def _test_val_epoch_end(self, outputs, test=False):
-
-        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
-
-        if test:
-            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
-            plt.clf()
         maps, trajectories, labels, val_restul_dict = self.generate_random()

         from lib.visualization.generator_eval import GeneratorVisualizer
         g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
         fig = g.draw()
-        self.logger.log_image(f'{self.name}_Output', fig)
+        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

         return dict(epoch=self.current_epoch)

@@ -81,69 +76,88 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         if not issubclassed:
             # Dataset
             self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays',
-                                    length=self.hparams.data_param.dataset_length)
+                                    length=self.hparams.data_param.dataset_length, normalized=True)
             self.criterion = nn.MSELoss()

         # Additional Attributes
         self.in_shape = self.dataset.map_shapes_max
         # Todo: Better naming and size in Parameters
-        self.feature_dim = 10
-        self.lat_dim = self.feature_dim + self.feature_dim + 1
+        self.feature_dim = self.hparams.model_param.lat_dim * 10
+        self.feature_mixed_dim = self.feature_dim + self.feature_dim + 1

        # NN Nodes
        ###################################################
        #
        # Utils
-        self.relu = nn.ReLU()
+        self.activation = nn.ReLU()
         self.sigmoid = nn.Sigmoid()

         #
         # Map Encoder
         self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
-                                     conv_filters=self.hparams.model_param.filters[0])
+                                     conv_filters=self.hparams.model_param.filters[0],
+                                     use_norm=self.hparams.model_param.use_norm,
+                                     use_bias=self.hparams.model_param.use_bias)

         self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[0])
+                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[0],
+                                        use_norm=self.hparams.model_param.use_norm,
+                                        use_bias=self.hparams.model_param.use_bias)
         self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                     conv_filters=self.hparams.model_param.filters[1])
+                                     conv_filters=self.hparams.model_param.filters[1],
+                                     use_norm=self.hparams.model_param.use_norm,
+                                     use_bias=self.hparams.model_param.use_bias)

         self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[1])
+                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[1],
+                                        use_norm=self.hparams.model_param.use_norm,
+                                        use_bias=self.hparams.model_param.use_bias)
         self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                     conv_filters=self.hparams.model_param.filters[2])
+                                     conv_filters=self.hparams.model_param.filters[2],
+                                     use_norm=self.hparams.model_param.use_norm,
+                                     use_bias=self.hparams.model_param.use_bias)

         self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=3, conv_stride=1,
-                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[2])
+                                        conv_padding=1, conv_filters=self.hparams.model_param.filters[2],
+                                        use_norm=self.hparams.model_param.use_norm,
+                                        use_bias=self.hparams.model_param.use_bias)
         self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
-                                     conv_filters=self.hparams.model_param.filters[2]*2)
+                                     conv_filters=self.hparams.model_param.filters[2]*2,
+                                     use_norm=self.hparams.model_param.use_norm,
+                                     use_bias=self.hparams.model_param.use_bias)

         self.map_flat = Flatten(self.map_conv_3.shape)

         self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)

         #
         # Trajectory Encoder
         self.traj_conv_1 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0])
+                                      conv_filters=self.hparams.model_param.filters[0],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)

         self.traj_conv_2 = ConvModule(self.traj_conv_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0])
+                                      conv_filters=self.hparams.model_param.filters[0],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)

         self.traj_conv_3 = ConvModule(self.traj_conv_2.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[0])
+                                      conv_filters=self.hparams.model_param.filters[0],
+                                      use_norm=self.hparams.model_param.use_norm,
+                                      use_bias=self.hparams.model_param.use_bias)

         self.traj_flat = Flatten(self.traj_conv_3.shape)

         self.traj_lin = nn.Linear(reduce(mul, self.traj_conv_3.shape), self.feature_dim)

         #
         # Mixed Encoder
-        self.mixed_lin = nn.Linear(self.lat_dim, self.lat_dim)
+        self.mixed_lin = nn.Linear(self.feature_mixed_dim, self.feature_mixed_dim)
+        self.mixed_norm = nn.BatchNorm1d(self.feature_mixed_dim) if self.hparams.model_param.use_norm else lambda x: x

         #
         # Variational Bottleneck
-        self.mu = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
-        self.logvar = nn.Linear(self.lat_dim, self.hparams.model_param.lat_dim)
+        self.mu = nn.Linear(self.feature_mixed_dim, self.hparams.model_param.lat_dim)
+        self.logvar = nn.Linear(self.feature_mixed_dim, self.hparams.model_param.lat_dim)

         #
         # Alternative Generator
@@ -153,13 +167,17 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         self.reshape_to_map = Flatten(reduce(mul, self.traj_conv_3.shape), self.traj_conv_3.shape)

         self.alt_deconv_1 = DeConvModule(self.traj_conv_3.shape, self.hparams.model_param.filters[2],
-                                         conv_padding=0, conv_kernel=5, conv_stride=1)
+                                         conv_padding=0, conv_kernel=5, conv_stride=1,
+                                         use_norm=self.hparams.model_param.use_norm)
         self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
-                                         conv_padding=0, conv_kernel=3, conv_stride=1)
+                                         conv_padding=0, conv_kernel=3, conv_stride=1,
+                                         use_norm=self.hparams.model_param.use_norm)
         self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
-                                         conv_padding=1, conv_kernel=3, conv_stride=1)
+                                         conv_padding=1, conv_kernel=3, conv_stride=1,
+                                         use_norm=self.hparams.model_param.use_norm)
         self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
-                                           conv_padding=1, conv_kernel=3, conv_stride=1)
+                                           conv_padding=1, conv_kernel=3, conv_stride=1,
+                                           use_norm=self.hparams.model_param.use_norm)

     def forward(self, batch_x):
         #
@@ -173,7 +191,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         #
         # Generate
         alt_tensor = self.generate(z)

         return alt_tensor, z, mu, logvar

     @staticmethod
@@ -184,13 +201,16 @@ class CNNRouteGeneratorModel(LightningBaseModule):

     def generate(self, z):
         alt_tensor = self.alt_lin_1(z)
+        alt_tensor = self.activation(alt_tensor)
         alt_tensor = self.alt_lin_2(alt_tensor)
+        alt_tensor = self.activation(alt_tensor)
         alt_tensor = self.reshape_to_map(alt_tensor)
         alt_tensor = self.alt_deconv_1(alt_tensor)
         alt_tensor = self.alt_deconv_2(alt_tensor)
         alt_tensor = self.alt_deconv_3(alt_tensor)
         alt_tensor = self.alt_deconv_out(alt_tensor)
-        # alt_tensor = self.sigmoid(alt_tensor)
+        # alt_tensor = self.activation(alt_tensor)
+        alt_tensor = self.sigmoid(alt_tensor)
         return alt_tensor

     def encode(self, map_array, trajectory, label):
@@ -211,23 +231,26 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         traj_tensor = self.traj_lin(traj_tensor)

         mixed_tensor = torch.cat((map_tensor, traj_tensor, label.float().unsqueeze(-1)), dim=1)
-        mixed_tensor = self.relu(mixed_tensor)
+        mixed_tensor = self.mixed_norm(mixed_tensor)
+        mixed_tensor = self.activation(mixed_tensor)
         mixed_tensor = self.mixed_lin(mixed_tensor)
-        mixed_tensor = self.relu(mixed_tensor)
+        mixed_tensor = self.mixed_norm(mixed_tensor)
+        mixed_tensor = self.activation(mixed_tensor)

         #
         # Parameter and Sampling
         mu = self.mu(mixed_tensor)
         logvar = self.logvar(mixed_tensor)
+        # logvar = torch.clamp(logvar, min=0, max=10)
         z = self.reparameterize(mu, logvar)
         return z, mu, logvar

     def generate_random(self, n=6):
         maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]

-        trajectories = [x.get_random_trajectory() for x in maps] * 2
+        trajectories = [x.get_random_trajectory() for x in maps]
         trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
-        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories]
+        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
         trajectories = self._move_to_model_device(torch.stack(trajectories))

         maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
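
encode() above draws z with self.reparameterize(mu, logvar); the body of that @staticmethod is not part of this diff, but the standard reparameterization trick it presumably implements looks like the following sketch (not the repository's exact code):

import torch

def reparameterize(mu, logvar):
    # logvar is log(sigma^2), so sigma = exp(0.5 * logvar)
    std = torch.exp(0.5 * logvar)
    # sample eps ~ N(0, I); mu + eps * std stays differentiable w.r.t. mu and logvar
    eps = torch.randn_like(std)
    return mu + eps * std

Clamping logvar (the commented-out torch.clamp above) is a common way to keep exp(logvar) from overflowing early in training.
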
@@ -294,14 +317,14 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
         roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
         if test:
             # self.logger.log_metrics(score_dict)
-            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
+            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
             plt.clf()
         maps, trajectories, labels, val_restul_dict = self.generate_random()

         from lib.visualization.generator_eval import GeneratorVisualizer
         g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
         fig = g.draw()
-        self.logger.log_image(f'{self.name}_Output', fig)
+        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)

         return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

@@ -330,4 +353,4 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
         self.criterion = nn.BCELoss()

         self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
-                                length=self.hparams.data_param.dataset_length)
+                                length=self.hparams.data_param.dataset_length, normalized=True)