Refactoring
models/generators/__init__.py (new file, empty)

models/generators/cnn.py (new file, 244 lines)
@@ -0,0 +1,244 @@
from functools import reduce
from operator import mul

from random import choice

import torch

from torch import nn
from torch.optim import Adam

from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from lib.modules.blocks import ConvModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten

import matplotlib.pyplot as plt
import lib.variables as V
from lib.visualization.generator_eval import GeneratorVisualizer


class CNNRouteGeneratorModel(LightningBaseModule):

    # Debugging aid: anomaly detection slows execution noticeably; disable for production runs.
    torch.autograd.set_detect_anomaly(True)
    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, _ = batch_xy
        reconstruction, z, mu, logvar = self(batch_x)

        recon_loss = self.criterion(reconstruction, batch_x)

        kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        loss = recon_loss + kldivergence
        return dict(loss=loss, log=dict(reconstruction_loss=recon_loss, loss=loss, kld_loss=kldivergence))

    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, _ = batch_xy

        mu, logvar = self.encoder(batch_x)
        z = self.reparameterize(mu, logvar)

        # Decode from mu (the posterior mean) for a deterministic evaluation output;
        # z is still returned for the latent-space plots.
        reconstruction = self.decoder(mu)
        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=reconstruction, z=z, mu=mu, logvar=logvar)

        labels = torch.full((batch_x.shape[0], 1), V.ANY)

        return_dict.update(labels=self._move_to_model_device(labels))
        return return_dict

    def _test_val_epoch_end(self, outputs, test=False):
        plt.close('all')

        g = GeneratorVisualizer(choice(outputs))
        fig = g.draw_io_bundle()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        plt.clf()

        fig = g.draw_latent()
        self.logger.log_image(f'{self.name}_Latent', fig, step=self.global_step)
        plt.clf()

        return dict(epoch=self.current_epoch)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    def __init__(self, *params, issubclassed=False):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        # Dataset
        self.dataset = TrajData(self.hparams.data_param.map_root,
                                mode=self.hparams.data_param.mode,
                                preprocessed=self.hparams.data_param.use_preprocessed,
                                length=self.hparams.data_param.dataset_length)

        self.criterion = nn.BCELoss(reduction='sum')

        # Additional Attributes
        ###################################################
        self.in_shape = self.dataset.map_shapes_max
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]

        # NN Nodes
        ###################################################
        self.encoder = Encoder(self.in_shape, self.hparams)
        self.decoder = Decoder(self.out_channels, self.encoder.last_conv_shape, self.hparams)

    def forward(self, batch_x):
        # Encode
        mu, logvar = self.encoder(batch_x)

        # Bottleneck
        z = self.reparameterize(mu, logvar)

        # Decode
        reconstruction = self.decoder(z)
        return reconstruction, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        # sigma = exp(0.5 * logvar); the committed `0.5 * torch.exp(logvar)` computed
        # 0.5 * sigma^2 instead of the standard deviation.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        z = mu + std * eps
        return z


class Encoder(nn.Module):
    def __init__(self, in_shape, hparams):
        super(Encoder, self).__init__()
        # Params
        ###################################################
        self.hparams = hparams

        # Additional Attributes
        ###################################################
        self.in_shape = in_shape
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim * 10

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = self.hparams.activation()

        #
        # Encoder
        self.conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                 conv_filters=self.hparams.model_param.filters[0],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.conv_1 = ConvModule(self.conv_0.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                 conv_filters=self.hparams.model_param.filters[0],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.conv_2 = ConvModule(self.conv_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                 conv_filters=self.hparams.model_param.filters[1],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.conv_3 = ConvModule(self.conv_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                 conv_filters=self.hparams.model_param.filters[2],
                                 use_norm=self.hparams.model_param.use_norm,
                                 use_bias=self.hparams.model_param.use_bias)

        self.last_conv_shape = self.conv_3.shape
        self.flat = Flatten(in_shape=self.last_conv_shape)
        self.lin = nn.Linear(self.flat.shape, self.feature_dim)

        #
        # Variational Bottleneck
        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)

    def forward(self, batch_x):
        tensor = self.conv_0(batch_x)
        tensor = self.conv_1(tensor)
        tensor = self.conv_2(tensor)
        tensor = self.conv_3(tensor)

        tensor = self.flat(tensor)
        tensor = self.lin(tensor)
        tensor = self.activation(tensor)

        #
        # Variational
        # Parameters for sampling
        mu = self.mu(tensor)
        logvar = self.logvar(tensor)
        return mu, logvar


class Decoder(nn.Module):

    def __init__(self, out_channels, last_conv_shape, hparams):
        super(Decoder, self).__init__()
        # Params
        ###################################################
        self.hparams = hparams

        # Additional Attributes
        ###################################################
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim
        self.out_channels = out_channels

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = self.hparams.activation()

        #
        # Alternative Generator
        self.lin = nn.Linear(self.lat_dim, reduce(mul, last_conv_shape))

        self.reshape = Flatten(in_shape=reduce(mul, last_conv_shape), to=last_conv_shape)

        self.deconv_1 = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
                                     conv_padding=0, conv_kernel=7, conv_stride=1,
                                     use_norm=self.hparams.model_param.use_norm)

        self.deconv_2 = DeConvModule(self.deconv_1.shape, self.hparams.model_param.filters[1],
                                     conv_padding=1, conv_kernel=5, conv_stride=1,
                                     use_norm=self.hparams.model_param.use_norm)

        self.deconv_3 = DeConvModule(self.deconv_2.shape, self.hparams.model_param.filters[0],
                                     conv_padding=0, conv_kernel=3, conv_stride=1,
                                     use_norm=self.hparams.model_param.use_norm)

        self.deconv_out = DeConvModule(self.deconv_3.shape, self.out_channels, activation=nn.Sigmoid,
                                       conv_padding=0, conv_kernel=3, conv_stride=1,
                                       use_norm=self.hparams.model_param.use_norm)

    def forward(self, z):
        tensor = self.lin(z)
        tensor = self.activation(tensor)
        tensor = self.reshape(tensor)

        tensor = self.deconv_1(tensor)
        tensor = self.deconv_2(tensor)
        tensor = self.deconv_3(tensor)
        reconstruction = self.deconv_out(tensor)
        return reconstruction
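
A minimal training sketch, not part of the commit: the model follows the usual PyTorch Lightning contract, so fitting it presumably comes down to building an hparams namespace and handing the module to a Trainer. The hparams layout below (`train_param.lr`, `model_param.*`, `data_param.*`, `activation`) is read off the attribute accesses above; the exact constructor signature and dataloader wiring live in `LightningBaseModule` and are assumptions here.

from argparse import Namespace
from torch import nn
from pytorch_lightning import Trainer

# Hypothetical hparams mirroring the attribute accesses in CNNRouteGeneratorModel.
hparams = Namespace(
    activation=nn.ReLU,
    train_param=Namespace(lr=1e-3, batch_size=32),
    model_param=Namespace(use_res_net=False, lat_dim=16, filters=[16, 32, 64],
                          use_norm=True, use_bias=True),
    data_param=Namespace(map_root='data/maps', mode='generator', use_preprocessed=True,
                         dataset_length=10000),
)

model = CNNRouteGeneratorModel(hparams)  # assumes the base class accepts the namespace as-is
Trainer(max_epochs=10).fit(model)        # dataloaders assumed to come from LightningBaseModule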

models/generators/cnn_discriminated.py (new file, 116 lines)
@@ -0,0 +1,116 @@
from random import choices, seed
import numpy as np

import torch
from functools import reduce
from operator import mul

from torch import nn
from torch.optim import Adam

from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten

import matplotlib.pyplot as plt


class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):

    name = 'CNNRouteGeneratorDiscriminated'

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, label = batch_xy

        generated_alternative, z, mu, logvar = self(batch_x)
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)
        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))

        # See Appendix B of the VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional resizing: normalize the summed KLD by the number of input elements.
        kld_loss /= reduce(mul, self.in_shape)

        loss = (kld_loss + discriminated_bce_loss) / 2
        return dict(loss=loss, log=dict(loss=loss,
                                        discriminated_bce_loss=discriminated_bce_loss,
                                        kld_loss=kld_loss)
                    )

    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, label = batch_xy

        generated_alternative, z, mu, logvar = self(batch_x)
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)

        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def _test_val_epoch_end(self, outputs, test=False):
        evaluation = ROCEvaluation(plot_roc=True)
        pred_label = torch.cat([x['pred_label'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()

        # The ROC evaluation follows the scikit-learn convention: evaluation(true_labels, predictions)
        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy())
        if test:
            # self.logger.log_metrics(score_dict)
            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
            plt.clf()

        maps, trajectories, labels, val_result_dict = self.generate_random()

        from lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        plt.clf()

        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    @property
    def discriminator(self):
        if self._disc is None:
            raise RuntimeError('Set the discriminator first: set_discriminator(disc_model)')
        return self._disc

    def set_discriminator(self, disc_model):
        if self._disc is not None:
            raise RuntimeError('Discriminator has already been set... What are you trying to do?')
        self._disc = disc_model

    def __init__(self, *params):
        # WIP: construction is deliberately disabled; everything below the raise is
        # unreachable until this discriminated variant is finished.
        raise NotImplementedError
        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)

        self._disc = None

        self.criterion = nn.BCELoss()

        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
                                length=self.hparams.data_param.dataset_length, normalized=True)
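
How the discriminated variant is meant to be wired can be read off `set_discriminator` and the `discriminator` property: the discriminator is injected exactly once after construction, before any step runs. A sketch with a hypothetical `disc_model` (any module whose forward maps the stacked map/trajectory/alternative tensor to a sigmoid probability); note that `__init__` currently raises `NotImplementedError`, so this only works once that WIP guard is removed.

model = CNNRouteGeneratorDiscriminated(hparams)   # WIP: raises NotImplementedError as committed
model.set_discriminator(disc_model)               # inject exactly once; a second call raises
pred = model.discriminator(map_stack)             # the property now resolves to disc_model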

models/generators/full.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from lib.modules.losses import BinaryHomotopicLoss
from lib.modules.utils import LightningBaseModule
from lib.objects.map import Map
from lib.objects.trajectory import Trajectory

import torch.nn as nn


class LinearRouteGeneratorModel(LightningBaseModule):

    def test_epoch_end(self, outputs):
        pass

    name = 'LinearRouteGenerator'

    def configure_optimizers(self):
        pass

    def validation_step(self, *args, **kwargs):
        pass

    def validation_end(self, outputs):
        pass

    # NOTE: dead code. The second training_step definition below rebinds the name, so
    # this draft is never called; it also computes a loss but does not return it
    # (see the short illustration after this file).
    def training_step(self, batch, batch_nb, *args, **kwargs):
        # Type annotations
        traj_x: Trajectory
        traj_o: Trajectory
        label_x: int
        map_name: str
        map_x: Map
        # Batch unpacking
        traj_x, traj_o, label_x, map_name = batch
        map_x = self.map_storage[map_name]
        pred_y = self(map_x, traj_x, label_x)

        loss = self.loss(traj_x, pred_y)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, batch_y = batch_xy
        pred_y = self(batch_x)
        loss = self.criterion(pred_y, batch_y.unsqueeze(-1).float())

        return dict(loss=loss, log=dict(loss=loss))

    def test_step(self, *args, **kwargs):
        pass

    def __init__(self, *params):
        super(LinearRouteGeneratorModel, self).__init__(*params)

        self.criterion = BinaryHomotopicLoss(self.map_storage)

    def forward(self, map_x, traj_x, label_x):
        pass
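
Why only the second `training_step` in `full.py` survives: a class body executes top to bottom, so a later `def` simply rebinds the name and the earlier definition becomes unreachable. A two-line illustration:

class A:
    def f(self): return 1
    def f(self): return 2   # rebinds f; the first definition is dead

assert A().f() == 2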

models/generators/recurrent.py (new file, 348 lines)
@@ -0,0 +1,348 @@
from random import choice

import torch
from functools import reduce
from operator import mul

from torch import nn
from torch.optim import Adam

from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten

import matplotlib.pyplot as plt


class CNNRouteGeneratorModel(LightningBaseModule):

    # NOTE: class and model name are kept from the CNN module for now; this file carries
    # the GRU-based variant of the route generator, still mid-refactor.
    name = 'CNNRouteGenerator'

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.train_param.lr)

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, alternative = batch_xy
        generated_alternative, z, mu, logvar = self(batch_x)
        element_wise_loss = self.criterion(generated_alternative, alternative)
        # See Appendix B of the VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional resizing TODO: Does this make sense? Sanity check it!
        # kld_loss /= reduce(mul, self.in_shape)
        # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100

        loss = (kld_loss + element_wise_loss) / 2
        return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))

    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, alternative = batch_xy
        map_array = batch_x[0]
        trajectory = batch_x[1]
        label = batch_x[2].max()

        z, _, _ = self.encode(batch_x)
        generated_alternative = self.generate(z)

        return dict(map_array=map_array, trajectory=trajectory, batch_nb=batch_nb, label=label,
                    generated_alternative=generated_alternative, pred_label=-1, alternative=alternative
                    )

    def _test_val_epoch_end(self, outputs, test=False):
        maps, trajectories, labels, val_result_dict = self.generate_random()

        from lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        plt.clf()

        return dict(epoch=self.current_epoch)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    def __init__(self, *params, issubclassed=False):
        super(CNNRouteGeneratorModel, self).__init__(*params)

        if not issubclassed:
            # Dataset
            self.dataset = TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map',
                                    length=self.hparams.data_param.dataset_length, normalized=True)
            self.criterion = nn.MSELoss()

        # Additional Attributes #
        #######################################################
        self.map_shape = self.dataset.map_shapes_max
        self.trajectory_features = 4
        # Renamed from the committed `self.res_net`: encode() reads `self.use_res_net`.
        self.use_res_net = self.hparams.model_param.use_res_net
        self.lat_dim = self.hparams.model_param.lat_dim
        self.feature_dim = self.lat_dim * 10
        ########################################################

        # NN Nodes
        ###################################################
        #
        # Utils
        self.activation = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

        #
        # Map Encoder
        self.enc_conv_0 = ConvModule(self.map_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
                                     conv_filters=self.hparams.model_param.filters[0],
                                     use_norm=self.hparams.model_param.use_norm,
                                     use_bias=self.hparams.model_param.use_bias)

        self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[1],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[1],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
                                        conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
                                        conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
                                        use_norm=self.hparams.model_param.use_norm,
                                        use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)
        self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
                                      conv_filters=self.hparams.model_param.filters[2],
                                      use_norm=self.hparams.model_param.use_norm,
                                      use_bias=self.hparams.model_param.use_bias)

        # Trajectory Encoder
        self.env_gru_1 = nn.GRU(input_size=self.trajectory_features, hidden_size=self.feature_dim,
                                num_layers=3, batch_first=True)

        self.enc_flat = Flatten(self.enc_conv_3b.shape)
        self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)

        #
        # Mixed Encoder
        self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
        self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x

        #
        # Variational Bottleneck
        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)

        #
        # Alternative Generator
        self.gen_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)

        self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)

        # WIP: the recurrent generator head is unfinished. The committed line was
        # `nn.GRU(None, None, batch_first=True)`, which cannot construct; the sizes
        # below are assumed placeholders until the real dimensions are settled.
        self.gen_gru_x = nn.GRU(input_size=self.feature_dim, hidden_size=self.feature_dim,
                                batch_first=True)

    def forward(self, batch_x):
        #
        # Encode
        z, mu, logvar = self.encode(batch_x)

        #
        # Generate
        alt_tensor = self.generate(z)
        return alt_tensor, z, mu, logvar

    @staticmethod
    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, batch_x):
        combined_tensor = self.enc_conv_0(batch_x)
        combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor
        combined_tensor = self.enc_conv_1a(combined_tensor)
        combined_tensor = self.enc_conv_1b(combined_tensor)
        combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
        combined_tensor = self.enc_conv_2a(combined_tensor)
        combined_tensor = self.enc_conv_2b(combined_tensor)
        combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
        combined_tensor = self.enc_conv_3a(combined_tensor)
        combined_tensor = self.enc_conv_3b(combined_tensor)

        combined_tensor = self.enc_flat(combined_tensor)
        combined_tensor = self.enc_lin_1(combined_tensor)
        # NOTE: enc_lin_2 is applied twice in this method, so the same weights serve
        # both mixed-encoder passes.
        combined_tensor = self.enc_lin_2(combined_tensor)

        combined_tensor = self.enc_norm(combined_tensor)
        combined_tensor = self.activation(combined_tensor)
        combined_tensor = self.enc_lin_2(combined_tensor)
        combined_tensor = self.enc_norm(combined_tensor)
        combined_tensor = self.activation(combined_tensor)

        #
        # Parameters and sampling
        mu = self.mu(combined_tensor)
        logvar = self.logvar(combined_tensor)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def generate(self, z):
        # NOTE: `reshape_to_last_conv` and the `gen_deconv_*` blocks are referenced here
        # but not yet defined in __init__; this generator path is still WIP in the refactor.
        alt_tensor = self.gen_lin_1(z)
        alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.gen_lin_2(alt_tensor)
        alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.reshape_to_last_conv(alt_tensor)
        alt_tensor = self.gen_deconv_1a(alt_tensor)
        alt_tensor = self.gen_deconv_1b(alt_tensor)
        alt_tensor = self.gen_deconv_2a(alt_tensor)
        alt_tensor = self.gen_deconv_2b(alt_tensor)
        alt_tensor = self.gen_deconv_3a(alt_tensor)
        alt_tensor = self.gen_deconv_3b(alt_tensor)
        alt_tensor = self.gen_deconv_out(alt_tensor)
        # alt_tensor = self.activation(alt_tensor)
        alt_tensor = self.sigmoid(alt_tensor)
        return alt_tensor

    def generate_random(self, n=6):
        maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]

        trajectories = [x.get_random_trajectory() for x in maps]
        trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
        # Each sample list is duplicated so the stack matches the 2 * n labels built below.
        trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
        trajectories = self._move_to_model_device(torch.stack(trajectories))

        maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
        maps = self._move_to_model_device(torch.stack(maps))

        labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
        return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)


class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):

    name = 'CNNRouteGeneratorDiscriminated'

    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
        batch_x, label = batch_xy

        generated_alternative, z, mu, logvar = self(batch_x)
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)
        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))

        # See Appendix B of the VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Dimensional resizing. NOTE: `self.in_shape` is not set anywhere in this module
        # (the parent defines `map_shape`); left as committed from the WIP refactor.
        kld_loss /= reduce(mul, self.in_shape)

        loss = (kld_loss + discriminated_bce_loss) / 2
        return dict(loss=loss, log=dict(loss=loss,
                                        discriminated_bce_loss=discriminated_bce_loss,
                                        kld_loss=kld_loss)
                    )

    def _test_val_step(self, batch_xy, batch_nb, *args):
        batch_x, label = batch_xy

        generated_alternative, z, mu, logvar = self(batch_x)
        map_array, trajectory = batch_x

        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
        pred_label = self.discriminator(map_stack)

        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)

    def validation_step(self, *args):
        return self._test_val_step(*args)

    def validation_epoch_end(self, outputs: list):
        return self._test_val_epoch_end(outputs)

    def _test_val_epoch_end(self, outputs, test=False):
        evaluation = ROCEvaluation(plot_roc=True)
        pred_label = torch.cat([x['pred_label'] for x in outputs])
        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()

        # The ROC evaluation follows the scikit-learn convention: evaluation(true_labels, predictions)
        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy())
        if test:
            # self.logger.log_metrics(score_dict)
            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
            plt.clf()

        maps, trajectories, labels, val_result_dict = self.generate_random()

        from lib.visualization.generator_eval import GeneratorVisualizer
        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
        fig = g.draw()
        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
        plt.clf()

        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)

    def test_step(self, *args):
        return self._test_val_step(*args)

    def test_epoch_end(self, outputs):
        return self._test_val_epoch_end(outputs, test=True)

    @property
    def discriminator(self):
        if self._disc is None:
            raise RuntimeError('Set the discriminator first: set_discriminator(disc_model)')
        return self._disc

    def set_discriminator(self, disc_model):
        if self._disc is not None:
            raise RuntimeError('Discriminator has already been set... What are you trying to do?')
        self._disc = disc_model

    def __init__(self, *params):
        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)

        self._disc = None

        self.criterion = nn.BCELoss()

        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
                                length=self.hparams.data_param.dataset_length, normalized=True)
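
On the "Dimensional Resizing" TODO that recurs above: dividing the summed KLD by the number of input elements puts it on the same per-element scale as a mean-reduced reconstruction loss, which is the usual motivation for that line. A self-contained sketch, with shapes made up for illustration:

import torch
from functools import reduce
from operator import mul

in_shape = (1, 64, 64)                                  # hypothetical (C, H, W) map shape
mu, logvar = torch.zeros(8, 16), torch.zeros(8, 16)     # batch of 8, lat_dim = 16

kld_sum = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
kld_scaled = kld_sum / reduce(mul, in_shape)            # per-input-element scale

# Sanity check: a standard-normal posterior (mu = 0, logvar = 0) gives exactly zero KLD,
# scaled or not, while a sum-reduced BCE/MSE grows with C * H * W.
assert kld_sum.item() == 0.0 == kld_scaled.item()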