Variational Generator
This commit is contained in:
@@ -5,8 +5,10 @@ from typing import Union, List
|
|||||||
import torch
|
import torch
|
||||||
from random import choice
|
from random import choice
|
||||||
from torch.utils.data import ConcatDataset, Dataset
|
from torch.utils.data import ConcatDataset, Dataset
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
|
import lib.variables as V
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
@@ -36,13 +38,10 @@ class TrajDataset(Dataset):
|
|||||||
|
|
||||||
if self.mode.lower() == 'just_route':
|
if self.mode.lower() == 'just_route':
|
||||||
trajectory = self.map.get_random_trajectory()
|
trajectory = self.map.get_random_trajectory()
|
||||||
|
trajectory_space = trajectory.draw_in_array(self.map.shape)
|
||||||
label = choice([0, 1])
|
label = choice([0, 1])
|
||||||
blank_trajectory_space = torch.zeros(self.map.shape)
|
|
||||||
for index in trajectory.vertices:
|
|
||||||
blank_trajectory_space[index] = 1
|
|
||||||
|
|
||||||
map_array = torch.as_tensor(self.map.as_array).float()
|
map_array = torch.as_tensor(self.map.as_array).float()
|
||||||
return (map_array, blank_trajectory_space), label
|
return (map_array, trajectory_space), label
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
trajectory = self.map.get_random_trajectory()
|
trajectory = self.map.get_random_trajectory()
|
||||||
@@ -55,13 +54,13 @@ class TrajDataset(Dataset):
|
|||||||
|
|
||||||
self.last_label = label
|
self.last_label = label
|
||||||
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
|
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
|
||||||
map_array = torch.as_tensor(self.map.as_array).float()
|
map_array = self.map.as_array
|
||||||
|
trajectory = trajectory.draw_in_array(self.map_shape)
|
||||||
|
alternative = alternative.draw_in_array(self.map_shape)
|
||||||
if self.mode == 'separated_arrays':
|
if self.mode == 'separated_arrays':
|
||||||
return (map_array, torch.as_tensor(trajectory.draw_in_array(self.map_shape)).float(), int(label)), \
|
return (map_array, trajectory, label), alternative
|
||||||
torch.as_tensor(alternative.draw_in_array(self.map_shape)).float()
|
|
||||||
else:
|
else:
|
||||||
return torch.cat((map_array, torch.as_tensor(trajectory.draw_in_array(self.map_shape)).float(),
|
return np.concatenate((map_array, trajectory, alternative)), label
|
||||||
torch.as_tensor(alternative.draw_in_array(self.map_shape)).float())), int(label)
|
|
||||||
|
|
||||||
elif self.mode == 'vectors':
|
elif self.mode == 'vectors':
|
||||||
return trajectory.vertices, alternative.vertices, label, self.mapname
|
return trajectory.vertices, alternative.vertices, label, self.mapname
|
||||||
|
|||||||
@@ -25,59 +25,33 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
|
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
|
||||||
|
|
||||||
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||||
batch_x, label = batch_xy
|
batch_x, alternative = batch_xy
|
||||||
|
generated_alternative, z, mu, logvar = self(batch_x)
|
||||||
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
|
mse_loss = self.criterion(generated_alternative, alternative)
|
||||||
map_array, trajectory = batch_x
|
|
||||||
|
|
||||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
|
||||||
pred_label = self.discriminator(map_stack)
|
|
||||||
|
|
||||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
|
||||||
|
|
||||||
# see Appendix B from VAE paper:
|
# see Appendix B from VAE paper:
|
||||||
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||||
# https://arxiv.org/abs/1312.6114
|
# https://arxiv.org/abs/1312.6114
|
||||||
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||||
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||||
# Dimensional Resizing
|
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
|
||||||
kld_loss /= reduce(mul, self.in_shape)
|
kld_loss /= reduce(mul, self.in_shape)
|
||||||
|
|
||||||
loss = (kld_loss + discriminated_bce_loss) / 2
|
loss = (kld_loss + mse_loss) / 2
|
||||||
return dict(loss=loss, log=dict(loss=loss,
|
return dict(loss=loss, log=dict(loss=loss, mse_loss=mse_loss, kld_loss=kld_loss))
|
||||||
discriminated_bce_loss=discriminated_bce_loss,
|
|
||||||
kld_loss=kld_loss)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _test_val_step(self, batch_xy, batch_nb, *args):
|
def _test_val_step(self, batch_xy, batch_nb, *args):
|
||||||
batch_x, label = batch_xy
|
batch_x, alternative = batch_xy
|
||||||
|
map_array, trajectory, label = batch_x
|
||||||
|
|
||||||
generated_alternative, z, mu, logvar = self(batch_x + [label, ])
|
generated_alternative, z, mu, logvar = self(batch_x)
|
||||||
map_array, trajectory = batch_x
|
|
||||||
|
|
||||||
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1)
|
||||||
pred_label = self.discriminator(map_stack)
|
|
||||||
|
|
||||||
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
|
||||||
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
|
|
||||||
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
|
|
||||||
|
|
||||||
def validation_step(self, *args):
|
|
||||||
return self._test_val_step(*args)
|
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs: list):
|
|
||||||
return self._test_val_epoch_end(outputs)
|
|
||||||
|
|
||||||
def _test_val_epoch_end(self, outputs, test=False):
|
def _test_val_epoch_end(self, outputs, test=False):
|
||||||
evaluation = ROCEvaluation(plot_roc=True)
|
|
||||||
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
|
||||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
|
||||||
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
|
|
||||||
|
|
||||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
|
||||||
if test:
|
if test:
|
||||||
# self.logger.log_metrics(score_dict)
|
|
||||||
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
|
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
|
||||||
plt.clf()
|
plt.clf()
|
||||||
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||||
@@ -87,7 +61,13 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
fig = g.draw()
|
fig = g.draw()
|
||||||
self.logger.log_image(f'{self.name}_Output', fig)
|
self.logger.log_image(f'{self.name}_Output', fig)
|
||||||
|
|
||||||
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
|
return dict(epoch=self.current_epoch)
|
||||||
|
|
||||||
|
def validation_step(self, *args):
|
||||||
|
return self._test_val_step(*args)
|
||||||
|
|
||||||
|
def validation_epoch_end(self, outputs: list):
|
||||||
|
return self._test_val_epoch_end(outputs)
|
||||||
|
|
||||||
def test_step(self, *args):
|
def test_step(self, *args):
|
||||||
return self._test_val_step(*args)
|
return self._test_val_step(*args)
|
||||||
@@ -95,31 +75,20 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
def test_epoch_end(self, outputs):
|
def test_epoch_end(self, outputs):
|
||||||
return self._test_val_epoch_end(outputs, test=True)
|
return self._test_val_epoch_end(outputs, test=True)
|
||||||
|
|
||||||
|
def __init__(self, *params, issubclassed=False):
|
||||||
@property
|
|
||||||
def discriminator(self):
|
|
||||||
if self._disc is None:
|
|
||||||
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
|
|
||||||
return self._disc
|
|
||||||
|
|
||||||
def set_discriminator(self, disc_model):
|
|
||||||
if self._disc is not None:
|
|
||||||
raise RuntimeError('Discriminator has already been set... What are trying to do?')
|
|
||||||
self._disc = disc_model
|
|
||||||
|
|
||||||
def __init__(self, *params):
|
|
||||||
super(CNNRouteGeneratorModel, self).__init__(*params)
|
super(CNNRouteGeneratorModel, self).__init__(*params)
|
||||||
|
|
||||||
# Dataset
|
if not issubclassed:
|
||||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
|
# Dataset
|
||||||
length=self.hparams.data_param.dataset_length)
|
self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays',
|
||||||
|
length=self.hparams.data_param.dataset_length)
|
||||||
|
self.criterion = nn.MSELoss()
|
||||||
|
|
||||||
# Additional Attributes
|
# Additional Attributes
|
||||||
self.in_shape = self.dataset.map_shapes_max
|
self.in_shape = self.dataset.map_shapes_max
|
||||||
# Todo: Better naming and size in Parameters
|
# Todo: Better naming and size in Parameters
|
||||||
self.feature_dim = 10
|
self.feature_dim = 10
|
||||||
self.lat_dim = self.feature_dim + self.feature_dim + 1
|
self.lat_dim = self.feature_dim + self.feature_dim + 1
|
||||||
self._disc = None
|
|
||||||
|
|
||||||
# NN Nodes
|
# NN Nodes
|
||||||
###################################################
|
###################################################
|
||||||
@@ -127,7 +96,6 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
# Utils
|
# Utils
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
self.criterion = nn.MSELoss()
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Map Encoder
|
# Map Encoder
|
||||||
@@ -222,7 +190,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
alt_tensor = self.alt_deconv_2(alt_tensor)
|
alt_tensor = self.alt_deconv_2(alt_tensor)
|
||||||
alt_tensor = self.alt_deconv_3(alt_tensor)
|
alt_tensor = self.alt_deconv_3(alt_tensor)
|
||||||
alt_tensor = self.alt_deconv_out(alt_tensor)
|
alt_tensor = self.alt_deconv_out(alt_tensor)
|
||||||
alt_tensor = self.sigmoid(alt_tensor)
|
# alt_tensor = self.sigmoid(alt_tensor)
|
||||||
return alt_tensor
|
return alt_tensor
|
||||||
|
|
||||||
def encode(self, map_array, trajectory, label):
|
def encode(self, map_array, trajectory, label):
|
||||||
@@ -266,4 +234,100 @@ class CNNRouteGeneratorModel(LightningBaseModule):
|
|||||||
maps = self._move_to_model_device(torch.stack(maps))
|
maps = self._move_to_model_device(torch.stack(maps))
|
||||||
|
|
||||||
labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
|
labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
|
||||||
return maps, trajectories, labels, self._test_val_step(([maps, trajectories], labels), -9999)
|
return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)
|
||||||
|
|
||||||
|
|
||||||
|
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
|
||||||
|
|
||||||
|
name = 'CNNRouteGeneratorDiscriminated'
|
||||||
|
|
||||||
|
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
|
||||||
|
batch_x, label = batch_xy
|
||||||
|
|
||||||
|
generated_alternative, z, mu, logvar = self(batch_x)
|
||||||
|
map_array, trajectory = batch_x
|
||||||
|
|
||||||
|
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||||
|
pred_label = self.discriminator(map_stack)
|
||||||
|
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||||
|
|
||||||
|
# see Appendix B from VAE paper:
|
||||||
|
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
|
||||||
|
# https://arxiv.org/abs/1312.6114
|
||||||
|
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
|
||||||
|
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
||||||
|
# Dimensional Resizing
|
||||||
|
kld_loss /= reduce(mul, self.in_shape)
|
||||||
|
|
||||||
|
loss = (kld_loss + discriminated_bce_loss) / 2
|
||||||
|
return dict(loss=loss, log=dict(loss=loss,
|
||||||
|
discriminated_bce_loss=discriminated_bce_loss,
|
||||||
|
kld_loss=kld_loss)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _test_val_step(self, batch_xy, batch_nb, *args):
|
||||||
|
batch_x, label = batch_xy
|
||||||
|
|
||||||
|
generated_alternative, z, mu, logvar = self(batch_x)
|
||||||
|
map_array, trajectory = batch_x
|
||||||
|
|
||||||
|
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
|
||||||
|
pred_label = self.discriminator(map_stack)
|
||||||
|
|
||||||
|
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
|
||||||
|
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
|
||||||
|
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
|
||||||
|
|
||||||
|
def validation_step(self, *args):
|
||||||
|
return self._test_val_step(*args)
|
||||||
|
|
||||||
|
def validation_epoch_end(self, outputs: list):
|
||||||
|
return self._test_val_epoch_end(outputs)
|
||||||
|
|
||||||
|
def _test_val_epoch_end(self, outputs, test=False):
|
||||||
|
evaluation = ROCEvaluation(plot_roc=True)
|
||||||
|
pred_label = torch.cat([x['pred_label'] for x in outputs])
|
||||||
|
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||||
|
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
|
||||||
|
|
||||||
|
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||||
|
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
|
||||||
|
if test:
|
||||||
|
# self.logger.log_metrics(score_dict)
|
||||||
|
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf())
|
||||||
|
plt.clf()
|
||||||
|
maps, trajectories, labels, val_restul_dict = self.generate_random()
|
||||||
|
|
||||||
|
from lib.visualization.generator_eval import GeneratorVisualizer
|
||||||
|
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
|
||||||
|
fig = g.draw()
|
||||||
|
self.logger.log_image(f'{self.name}_Output', fig)
|
||||||
|
|
||||||
|
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
|
||||||
|
|
||||||
|
def test_step(self, *args):
|
||||||
|
return self._test_val_step(*args)
|
||||||
|
|
||||||
|
def test_epoch_end(self, outputs):
|
||||||
|
return self._test_val_epoch_end(outputs, test=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def discriminator(self):
|
||||||
|
if self._disc is None:
|
||||||
|
raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
|
||||||
|
return self._disc
|
||||||
|
|
||||||
|
def set_discriminator(self, disc_model):
|
||||||
|
if self._disc is not None:
|
||||||
|
raise RuntimeError('Discriminator has already been set... What are trying to do?')
|
||||||
|
self._disc = disc_model
|
||||||
|
|
||||||
|
def __init__(self, *params):
|
||||||
|
super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
|
||||||
|
|
||||||
|
self._disc = None
|
||||||
|
|
||||||
|
self.criterion = nn.BCELoss()
|
||||||
|
|
||||||
|
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
|
||||||
|
length=self.hparams.data_param.dataset_length)
|
||||||
|
|||||||
@@ -32,24 +32,35 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
pred_y = self(batch_x)
|
pred_y = self(batch_x)
|
||||||
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
|
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
|
||||||
|
|
||||||
|
def validation_step(self, batch_xy, batch_nb, **kwargs):
|
||||||
|
batch_x, batch_y = batch_xy
|
||||||
|
pred_y = self(batch_x)
|
||||||
|
return dict(prediction=pred_y, label=batch_y, batch_nb=batch_nb)
|
||||||
|
|
||||||
def test_epoch_end(self, outputs):
|
def test_epoch_end(self, outputs):
|
||||||
evaluation = ROCEvaluation(plot_roc=True)
|
return self._val_test_end(outputs)
|
||||||
|
|
||||||
|
def validation_epoch_end(self, outputs: list):
|
||||||
|
return self._val_test_end(outputs)
|
||||||
|
|
||||||
|
def _val_test_end(self, outputs, test=True):
|
||||||
|
evaluation = ROCEvaluation(plot_roc=True if test else False)
|
||||||
predictions = torch.cat([x['prediction'] for x in outputs])
|
predictions = torch.cat([x['prediction'] for x in outputs])
|
||||||
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
|
||||||
|
|
||||||
# Sci-py call ROC eval call is eval(true_label, prediction)
|
# Sci-py call ROC eval call is eval(true_label, prediction)
|
||||||
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy(), )
|
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), predictions.cpu().numpy())
|
||||||
score_dict = dict(roc_auc=roc_auc)
|
|
||||||
# self.logger.log_metrics(score_dict)
|
# self.logger.log_metrics(score_dict)
|
||||||
self.logger.log_image(f'{self.name}', plt.gcf())
|
if test:
|
||||||
|
self.logger.log_image(f'{self.name}', plt.gcf())
|
||||||
|
|
||||||
return dict(log=score_dict)
|
return dict(score=roc_auc, log=dict(roc_auc=roc_auc))
|
||||||
|
|
||||||
def __init__(self, hparams):
|
def __init__(self, hparams):
|
||||||
super(ConvHomDetector, self).__init__(hparams)
|
super(ConvHomDetector, self).__init__(hparams)
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map')
|
self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map', )
|
||||||
|
|
||||||
# Additional Attributes
|
# Additional Attributes
|
||||||
self.map_shape = self.dataset.map_shapes_max
|
self.map_shape = self.dataset.map_shapes_max
|
||||||
@@ -59,6 +70,7 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
assert len(self.in_shape) == 3, f'Image or map shape has to have 3 dims, but had: {len(self.in_shape)}'
|
||||||
self.criterion = nn.BCELoss()
|
self.criterion = nn.BCELoss()
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
# NN Nodes
|
# NN Nodes
|
||||||
# ============================
|
# ============================
|
||||||
@@ -100,6 +112,7 @@ class ConvHomDetector(LightningBaseModule):
|
|||||||
tensor = self.map_conv_3(tensor)
|
tensor = self.map_conv_3(tensor)
|
||||||
tensor = self.flatten(tensor)
|
tensor = self.flatten(tensor)
|
||||||
tensor = self.linear(tensor)
|
tensor = self.linear(tensor)
|
||||||
|
tensor = self.relu(tensor)
|
||||||
tensor = self.classifier(tensor)
|
tensor = self.classifier(tensor)
|
||||||
tensor = self.sigmoid(tensor)
|
tensor = self.sigmoid(tensor)
|
||||||
return tensor
|
return tensor
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class ConvModule(nn.Module):
|
|||||||
output = self(x)
|
output = self(x)
|
||||||
return output.shape[1:]
|
return output.shape[1:]
|
||||||
|
|
||||||
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=True,
|
def __init__(self, in_shape, activation: nn.Module = nn.ELU, pooling_size=None, use_bias=True, use_norm=False,
|
||||||
dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
|
dropout: Union[int, float] = 0, conv_class=nn.Conv2d,
|
||||||
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
|
conv_filters=64, conv_kernel=5, conv_stride=1, conv_padding=0):
|
||||||
super(ConvModule, self).__init__()
|
super(ConvModule, self).__init__()
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ class LightningBaseModule(pl.LightningModule, ABC):
|
|||||||
|
|
||||||
# Validation Dataloader
|
# Validation Dataloader
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
return DataLoader(dataset=self.dataset.val_dataset, shuffle=False,
|
return DataLoader(dataset=self.dataset.val_dataset, shuffle=True,
|
||||||
batch_size=self.hparams.train_param.batch_size,
|
batch_size=self.hparams.train_param.batch_size,
|
||||||
num_workers=self.hparams.data_param.worker)
|
num_workers=self.hparams.data_param.worker)
|
||||||
|
|
||||||
|
|||||||
@@ -18,10 +18,6 @@ import lib.variables as V
|
|||||||
|
|
||||||
class Map(object):
|
class Map(object):
|
||||||
|
|
||||||
# This setting is for Img mode "L" aka GreyScale Image; values: 0-255
|
|
||||||
white = 255
|
|
||||||
black = 0
|
|
||||||
|
|
||||||
def __copy__(self):
|
def __copy__(self):
|
||||||
return copy.deepcopy(self)
|
return copy.deepcopy(self)
|
||||||
|
|
||||||
@@ -51,6 +47,7 @@ class Map(object):
|
|||||||
|
|
||||||
def __init__(self, name='', array_like_map_representation=None):
|
def __init__(self, name='', array_like_map_representation=None):
|
||||||
if array_like_map_representation is not None:
|
if array_like_map_representation is not None:
|
||||||
|
array_like_map_representation = array_like_map_representation.astype(np.float32)
|
||||||
if array_like_map_representation.ndim == 2:
|
if array_like_map_representation.ndim == 2:
|
||||||
array_like_map_representation = np.expand_dims(array_like_map_representation, axis=0)
|
array_like_map_representation = np.expand_dims(array_like_map_representation, axis=0)
|
||||||
assert array_like_map_representation.ndim == 3
|
assert array_like_map_representation.ndim == 3
|
||||||
@@ -70,7 +67,7 @@ class Map(object):
|
|||||||
|
|
||||||
# Check pixels for their color (determine if walkable)
|
# Check pixels for their color (determine if walkable)
|
||||||
for idx, value in np.ndenumerate(self.map_array):
|
for idx, value in np.ndenumerate(self.map_array):
|
||||||
if value != self.black:
|
if value != V.BLACK:
|
||||||
# IF walkable, add node
|
# IF walkable, add node
|
||||||
graph.add_node(idx, count=0)
|
graph.add_node(idx, count=0)
|
||||||
# Fully connect to all surrounding neighbors
|
# Fully connect to all surrounding neighbors
|
||||||
@@ -91,10 +88,9 @@ class Map(object):
|
|||||||
if image.mode != 'L':
|
if image.mode != 'L':
|
||||||
image = image.convert('L')
|
image = image.convert('L')
|
||||||
map_array = np.expand_dims(np.array(image), axis=0)
|
map_array = np.expand_dims(np.array(image), axis=0)
|
||||||
map_array = np.where(np.asarray(map_array) == cls.white, 1, 0)
|
|
||||||
if embedding_size:
|
if embedding_size:
|
||||||
assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
|
assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
|
||||||
embedding = np.zeros(embedding_size)
|
embedding = np.full(embedding_size, V.BLACK)
|
||||||
embedding[:map_array.shape[0], :map_array.shape[1], :map_array.shape[2]] = map_array
|
embedding[:map_array.shape[0], :map_array.shape[1], :map_array.shape[2]] = map_array
|
||||||
map_array = embedding
|
map_array = embedding
|
||||||
|
|
||||||
@@ -146,12 +142,15 @@ class Map(object):
|
|||||||
polyline = trajectory.xy_vertices
|
polyline = trajectory.xy_vertices
|
||||||
polyline.extend(reversed(other_trajectory.xy_vertices))
|
polyline.extend(reversed(other_trajectory.xy_vertices))
|
||||||
|
|
||||||
img = Image.new('L', (self.height, self.width), 0)
|
img = Image.new('L', (self.height, self.width), color=V.WHITE)
|
||||||
draw = ImageDraw.Draw(img)
|
draw = ImageDraw.Draw(img)
|
||||||
|
|
||||||
draw.polygon(polyline, outline=1, fill=1)
|
draw.polygon(polyline, outline=V.BLACK, fill=V.BLACK)
|
||||||
|
|
||||||
a = (np.asarray(img) * np.where(self.as_2d_array == self.black, 1, 0)).sum()
|
binary_img = np.where(np.asarray(img).squeeze() == V.BLACK, 1, 0)
|
||||||
|
binary_map = np.where(self.as_2d_array == V.BLACK, 1, 0)
|
||||||
|
|
||||||
|
a = (binary_img * binary_map).sum()
|
||||||
|
|
||||||
if a:
|
if a:
|
||||||
return V.ALTERNATIVE # Non-Homotoph
|
return V.ALTERNATIVE # Non-Homotoph
|
||||||
|
|||||||
@@ -42,9 +42,9 @@ class Trajectory(object):
|
|||||||
return list(zip(self._vertices[:-1], self._vertices[1:]))
|
return list(zip(self._vertices[:-1], self._vertices[1:]))
|
||||||
|
|
||||||
def draw_in_array(self, shape):
|
def draw_in_array(self, shape):
|
||||||
trajectory_space = np.zeros(shape)
|
trajectory_space = np.zeros(shape).astype(np.float32)
|
||||||
for index in self.vertices:
|
for index in self.vertices:
|
||||||
trajectory_space[index] = 1
|
trajectory_space[index] = V.WHITE
|
||||||
return trajectory_space
|
return trajectory_space
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from collections import defaultdict
|
|||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lib.models.generators.cnn import CNNRouteGeneratorModel
|
from lib.models.generators.cnn import CNNRouteGeneratorModel, CNNRouteGeneratorDiscriminated
|
||||||
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
|
||||||
from lib.utils.model_io import ModelParameters
|
from lib.utils.model_io import ModelParameters
|
||||||
|
|
||||||
@@ -28,7 +28,10 @@ class Config(ConfigParser):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def model_class(self):
|
def model_class(self):
|
||||||
model_dict = dict(classifier_cnn=ConvHomDetector, generator_cnn=CNNRouteGeneratorModel)
|
model_dict = dict(ConvHomDetector=ConvHomDetector,
|
||||||
|
CNNRouteGenerator=CNNRouteGeneratorModel,
|
||||||
|
CNNRouteGeneratorDiscriminated=CNNRouteGeneratorDiscriminated
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
return model_dict[self.get('model', 'type')]
|
return model_dict[self.get('model', 'type')]
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
|
|||||||
@@ -3,3 +3,5 @@ _ROOT = Path('..')
|
|||||||
|
|
||||||
HOMOTOPIC = 1
|
HOMOTOPIC = 1
|
||||||
ALTERNATIVE = 0
|
ALTERNATIVE = 0
|
||||||
|
WHITE = 255
|
||||||
|
BLACK = 0
|
||||||
|
|||||||
10
main.py
10
main.py
@@ -48,7 +48,7 @@ main_arg_parser.add_argument("--train_batch_size", type=int, default=256, help="
|
|||||||
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
main_arg_parser.add_argument("--model_type", type=str, default="generator_cnn", help="")
|
main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
|
||||||
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
main_arg_parser.add_argument("--model_activation", type=str, default="relu", help="")
|
||||||
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
main_arg_parser.add_argument("--model_filters", type=str, default="[32, 16, 4]", help="")
|
||||||
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
|
||||||
@@ -94,9 +94,9 @@ def run_lightning_loop(config_obj):
|
|||||||
# Init
|
# Init
|
||||||
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
|
model: LightningBaseModule = config_obj.model_class(config_obj.model_paramters)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
if model.name == 'CNNRouteGenerator':
|
if model.name == 'CNNRouteGeneratorDiscriminated':
|
||||||
# ToDo: Make this dependent on the used seed
|
# ToDo: Make this dependent on the used seed
|
||||||
path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0')
|
path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'trained')
|
||||||
disc_model = SavedLightningModels.load_checkpoint(path).restore()
|
disc_model = SavedLightningModels.load_checkpoint(path).restore()
|
||||||
model.set_discriminator(disc_model)
|
model.set_discriminator(disc_model)
|
||||||
|
|
||||||
@@ -112,7 +112,9 @@ def run_lightning_loop(config_obj):
|
|||||||
checkpoint_callback=checkpoint_callback,
|
checkpoint_callback=checkpoint_callback,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
fast_dev_run=config_obj.main.debug,
|
fast_dev_run=config_obj.main.debug,
|
||||||
early_stop_callback=None
|
early_stop_callback=None,
|
||||||
|
val_percent_check=0.10,
|
||||||
|
num_sanity_val_steps=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Train It
|
# Train It
|
||||||
|
|||||||
Reference in New Issue
Block a user