Variational Generator
This commit is contained in:
@ -25,59 +25,33 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
batch_x, label = batch_xy
|
||||
|
||||
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
|
||||
map_array, trajectory = batch_x
|
||||
|
||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||
pred_label = self.discriminator(map_stack)
|
||||
|
||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||
|
||||
batch_x, alternative = batch_xy
|
||||
generated_alternative, z, mu, logvar = self(batch_x)
|
||||
mse_loss = self.criterion(generated_alternative, alternative)
|
||||
# see Appendix B from VAE paper:
|
||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing
|
||||
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
|
||||
kld_loss /= reduce(mul, self.in_shape)
|
||||
|
||||
loss = (kld_loss + discriminated_bce_loss) / 2
|
||||
return dict(loss=loss, log=dict(loss=loss,
|
||||
discriminated_bce_loss=discriminated_bce_loss,
|
||||
kld_loss=kld_loss)
|
||||
)
|
||||
loss = (kld_loss + mse_loss) / 2
|
||||
return dict(loss=loss, log=dict(loss=loss, mse_loss=mse_loss, kld_loss=kld_loss))
|
||||
|
||||
def _test_val_step(self, batch_xy, batch_nb, *args):
|
||||
batch_x, label = batch_xy
|
||||
batch_x, alternative = batch_xy
|
||||
map_array, trajectory, label = batch_x
|
||||
|
||||
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
|
||||
map_array, trajectory = batch_x
|
||||
generated_alternative, z, mu, logvar = self(batch_x)
|
||||
|
||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||
pred_label = self.discriminator(map_stack)
|
||||
|
||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
|
||||
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
|
||||
|
||||
def validation_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
|
||||
def validation_epoch_end(self, outputs: list):
|
||||
return self._test_val_epoch_end(outputs)
|
||||
return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1)
|
||||
|
||||
def _test_val_epoch_end(self, outputs, test=False):
|
||||
evaluation = ROCEvaluation(plot_roc=True)
|
||||
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
|
||||
|
||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||
|
||||
if test:
|
||||
# self.logger.log_metrics(score_dict)
|
||||
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
|
||||
plt.clf()
|
||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||
@ -87,7 +61,13 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
fig = g.draw()
|
||||
self.logger.log_image(f'{self.name}_Output', fig)
|
||||
|
||||
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
|
||||
return dict(epoch=self.current_epoch)
|
||||
|
||||
def validation_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
|
||||
def validation_epoch_end(self, outputs: list):
|
||||
return self._test_val_epoch_end(outputs)
|
||||
|
||||
def test_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
@ -95,31 +75,20 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
def test_epoch_end(self, outputs):
|
||||
return self._test_val_epoch_end(outputs, test=True)
|
||||
|
||||
|
||||
@property
|
||||
def discriminator(self):
|
||||
if self._disc is None:
|
||||
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
|
||||
return self._disc
|
||||
|
||||
def set_discriminator(self, disc_model):
|
||||
if self._disc is not None:
|
||||
raise RuntimeError('Discriminator has already been set... What are trying to do?')
|
||||
self._disc = disc_model
|
||||
|
||||
def __init__(self, *params):
|
||||
def __init__(self, *params, issubclassed=False):
|
||||
super(CNNRouteGeneratorModel, self).__init__(*params)
|
||||
|
||||
# Dataset
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
|
||||
length=self.hparams.data_param.dataset_length)
|
||||
if not issubclassed:
|
||||
# Dataset
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays',
|
||||
length=self.hparams.data_param.dataset_length)
|
||||
self.criterion = nn.MSELoss()
|
||||
|
||||
# Additional Attributes
|
||||
self.in_shape = self.dataset.map_shapes_max
|
||||
# Todo: Better naming and size in Parameters
|
||||
self.feature_dim = 10
|
||||
self.lat_dim = self.feature_dim + self.feature_dim + 1
|
||||
self._disc = None
|
||||
|
||||
# NN Nodes
|
||||
###################################################
|
||||
@ -127,7 +96,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
# Utils
|
||||
self.relu = nn.ReLU()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.criterion = nn.MSELoss()
|
||||
|
||||
#
|
||||
# Map Encoder
|
||||
@ -222,7 +190,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
alt_tensor = self.alt_deconv_2(alt_tensor)
|
||||
alt_tensor = self.alt_deconv_3(alt_tensor)
|
||||
alt_tensor = self.alt_deconv_out(alt_tensor)
|
||||
alt_tensor = self.sigmoid(alt_tensor)
|
||||
# alt_tensor = self.sigmoid(alt_tensor)
|
||||
return alt_tensor
|
||||
|
||||
def encode(self, map_array, trajectory, label):
|
||||
@ -266,4 +234,100 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
||||
maps = self._move_to_model_device(torch.stack(maps))
|
||||
|
||||
labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
|
||||
return maps, trajectories, labels, self._test_val_step(([maps, trajectories], labels), -9999)
|
||||
return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)
|
||||
|
||||
|
||||
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
|
||||
|
||||
name = 'CNNRouteGeneratorDiscriminated'
|
||||
|
||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||
batch_x, label = batch_xy
|
||||
|
||||
generated_alternative, z, mu, logvar = self(batch_x)
|
||||
map_array, trajectory = batch_x
|
||||
|
||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||
pred_label = self.discriminator(map_stack)
|
||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||
|
||||
# see Appendix B from VAE paper:
|
||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||
# https://arxiv.org/abs/1312.6114
|
||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||
# Dimensional Resizing
|
||||
kld_loss /= reduce(mul, self.in_shape)
|
||||
|
||||
loss = (kld_loss + discriminated_bce_loss) / 2
|
||||
return dict(loss=loss, log=dict(loss=loss,
|
||||
discriminated_bce_loss=discriminated_bce_loss,
|
||||
kld_loss=kld_loss)
|
||||
)
|
||||
|
||||
def _test_val_step(self, batch_xy, batch_nb, *args):
|
||||
batch_x, label = batch_xy
|
||||
|
||||
generated_alternative, z, mu, logvar = self(batch_x)
|
||||
map_array, trajectory = batch_x
|
||||
|
||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||
pred_label = self.discriminator(map_stack)
|
||||
|
||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
|
||||
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
|
||||
|
||||
def validation_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
|
||||
def validation_epoch_end(self, outputs: list):
|
||||
return self._test_val_epoch_end(outputs)
|
||||
|
||||
def _test_val_epoch_end(self, outputs, test=False):
|
||||
evaluation = ROCEvaluation(plot_roc=True)
|
||||
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
|
||||
|
||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
||||
if test:
|
||||
# self.logger.log_metrics(score_dict)
|
||||
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
|
||||
plt.clf()
|
||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||
|
||||
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||
fig = g.draw()
|
||||
self.logger.log_image(f'{self.name}_Output', fig)
|
||||
|
||||
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
|
||||
|
||||
def test_step(self, *args):
|
||||
return self._test_val_step(*args)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
return self._test_val_epoch_end(outputs, test=True)
|
||||
|
||||
@property
|
||||
def discriminator(self):
|
||||
if self._disc is None:
|
||||
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
|
||||
return self._disc
|
||||
|
||||
def set_discriminator(self, disc_model):
|
||||
if self._disc is not None:
|
||||
raise RuntimeError('Discriminator has already been set... What are trying to do?')
|
||||
self._disc = disc_model
|
||||
|
||||
def __init__(self, *params):
|
||||
super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
|
||||
|
||||
self._disc = None
|
||||
|
||||
self.criterion = nn.BCELoss()
|
||||
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
|
||||
length=self.hparams.data_param.dataset_length)
|
||||
|
@ -32,24 +32,35 @@ class ConvHomDetector(LightningBaseModule):
|
||||
pred_y = self(batch_x)
|
||||
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
|
||||
|
||||
def validation_step(self, batch_xy, batch_nb, **kwargs):
|
||||
batch_x, batch_y = batch_xy
|
||||
pred_y = self(batch_x)
|
||||
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
evaluation = ROCEvaluation(plot_roc=True)
|
||||
return self._val_test_end(outputs)
|
||||
|
||||
def validation_epoch_end(self, outputs: list):
|
||||
return self._val_test_end(outputs)
|
||||
|
||||
def _val_test_end(self, outputs, test=True):
|
||||
evaluation = ROCEvaluation(plot_roc=True if test else False)
|
||||
predictions = torch.cat([x['prediction'] for x in outputs])
|
||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||
|
||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
|
||||
score_dict = dict(roc_auc=roc_auc)
|
||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy())
|
||||
# self.logger.log_metrics(score_dict)
|
||||
self.logger.log_image(f'{self.name}', plt.gcf())
|
||||
if test:
|
||||
self.logger.log_image(f'{self.name}', plt.gcf())
|
||||
|
||||
return dict(log=score_dict)
|
||||
return dict(score=roc_auc, log=dict(roc_auc=roc_auc))
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(ConvHomDetector, self).__init__(hparams)
|
||||
|
||||
# Dataset
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map')
|
||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map', )
|
||||
|
||||
# Additional Attributes
|
||||
self.map_shape = self.dataset.map_shapes_max
|
||||
@ -59,6 +70,7 @@ class ConvHomDetector(LightningBaseModule):
|
||||
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
||||
self.criterion = nn.BCELoss()
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
# NN Nodes
|
||||
# ============================
|
||||
@ -100,6 +112,7 @@ class ConvHomDetector(LightningBaseModule):
|
||||
tensor = self.map_conv_3(tensor)
|
||||
tensor = self.flatten(tensor)
|
||||
tensor = self.linear(tensor)
|
||||
tensor = self.relu(tensor)
|
||||
tensor = self.classifier(tensor)
|
||||
tensor = self.sigmoid(tensor)
|
||||
return tensor
|
||||
|
Reference in New Issue
Block a user