fig.clf() inserted; no dimensional resizing on KLD

Steffen Illium 2020-03-13 21:52:33 +01:00
parent bb47e07566
commit 2305c8e54a
33 changed files with 403 additions and 279 deletions

1
.gitignore vendored

@@ -3,6 +3,7 @@
# User-specific stuff
.idea/**
res/**
# CMake
cmake-build-*/

29
datasets/mnist.py Normal file

@@ -0,0 +1,29 @@
from torchvision.datasets import MNIST
import numpy as np
class MyMNIST(MNIST):
@property
def map_shapes_max(self):
return np.asarray(self.test_dataset[0][0]).shape
def __init__(self, *args, **kwargs):
# Forward the train flag so the train_dataset property actually yields the train split.
super(MyMNIST, self).__init__('res', train=kwargs.get('train', False), download=True)
def __getitem__(self, item):
image = super(MyMNIST, self).__getitem__(item)
return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
@property
def train_dataset(self):
return self.__class__('res', train=True, download=True)
@property
def test_dataset(self):
return self.__class__('res', train=False, download=True)
@property
def val_dataset(self):
return self.__class__('res', train=False, download=True)
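
A quick usage sketch of the new wrapper (hypothetical snippet, not part of the commit): the __getitem__ override converts each PIL image to a (1, 28, 28) float32 array, so samples feed a conv net without a transform pipeline.

from datasets.mnist import MyMNIST

data = MyMNIST()
x, y = data[0]
print(x.shape, x.dtype, y)    # (1, 28, 28) float32 <label>
print(data.map_shapes_max)    # (1, 28, 28), taken from the test split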

datasets/trajectory_dataset.py

@@ -1,6 +1,9 @@
import shelve
from collections import defaultdict
from pathlib import Path
from typing import Union, List
from typing import Union
from torchvision.transforms import Normalize
import multiprocessing as mp
@@ -24,16 +27,17 @@ class TrajDataShelve(Dataset):
return self[0][0].shape
def __init__(self, file_path, **kwargs):
assert Path(file_path).exists()
super(TrajDataShelve, self).__init__()
self._mutex = mp.Lock()
self.file_path = str(file_path)
def __len__(self):
self._mutex.acquire()
with shelve.open(self.file_path) as d:
length = len(d)
self._mutex.release()
d.close()
self._mutex.release()
return length
def seed(self):
@@ -43,12 +47,20 @@ class TrajDataShelve(Dataset):
self._mutex.acquire()
with shelve.open(self.file_path) as d:
sample = d[str(item)]
self._mutex.release()
d.close()
self._mutex.release()
return sample
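
The reordering above closes the shelf before releasing the lock. As a standalone pattern (a sketch, assuming a lock shared via fork; shelve.open used as a context manager already closes the shelf, so the extra close() in the diff is belt-and-braces):

import multiprocessing as mp
import shelve

_lock = mp.Lock()

def shelf_get(file_path: str, key: str):
    with _lock:                            # serialize shelf access
        with shelve.open(file_path) as d:  # shelf is closed on block exit
            return d[key]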
class TrajDataset(Dataset):
@property
def _last_label_init(self):
d = defaultdict(lambda: -1)
d['generator_hom_all_in_map'] = V.ALTERNATIVE
d['generator_alt_all_in_map'] = V.HOMOTOPIC
return d[self.mode]
@property
def map_shape(self):
return self.map.as_array.shape
@@ -57,17 +69,18 @@ class TrajDataset(Dataset):
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
**kwargs):
super(TrajDataset, self).__init__()
assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map'
'classifier_all_in_map']
self.normalized = normalized
assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map', 'generator_alt_all_in_map',
'ae_no_label_in_map',
'generator_alt_no_label_in_map', 'classifier_all_in_map', 'vae_no_label_in_map']
self.normalized = normalized
self.normalize = Normalize(0.5, 0.5) if normalized else lambda x: x
self.preserve_equal_samples = preserve_equal_samples
self.mode = mode
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
self.maps_root = maps_root
self._len = length
self.last_label = V.ALTERNATIVE if 'hom' in self.mode else choice([-1, V.ALTERNATIVE, V.HOMOTOPIC])
self.last_label = self._last_label_init
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
self.map = Map.from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
def __len__(self):
return self._len
@@ -82,6 +95,7 @@ class TrajDataset(Dataset):
map_array = torch.as_tensor(self.map.as_array).float()
return (map_array, trajectory_space), label
# Produce an alternative.
while True:
trajectory = self.map.get_random_trajectory()
alternative = self.map.generate_alternative(trajectory)
@@ -91,18 +105,19 @@ class TrajDataset(Dataset):
else:
break
self.last_label = label if self.mode != ['generator_hom_all_in_map'] else V.ALTERNATIVE
if self.mode.lower() in ['classifier_all_in_map', 'generator_all_in_map']:
self.last_label = label if self._last_label_init == V.ANY else self._last_label_init
if 'in_map' in self.mode.lower():
map_array = self.map.as_array
trajectory = trajectory.draw_in_array(self.map_shape)
alternative = alternative.draw_in_array(self.map_shape)
label_as_array = np.full_like(map_array, label)
if self.normalized:
map_array = map_array / V.WHITE
trajectory = trajectory / V.WHITE
alternative = alternative / V.WHITE
if self.mode == 'generator_all_in_map':
return np.concatenate((map_array, trajectory, label_as_array)), alternative
elif self.mode in ['vae_no_label_in_map', 'ae_no_label_in_map']:
return np.sum((map_array, trajectory, alternative), axis=0), 0
elif self.mode in ['generator_alt_no_label_in_map', 'generator_hom_no_label_in_map']:
return np.concatenate((map_array, trajectory)), alternative
elif self.mode == 'classifier_all_in_map':
return np.concatenate((map_array, trajectory, alternative)), label
@@ -119,13 +134,13 @@ class TrajDataset(Dataset):
class TrajData(object):
@property
def map_shapes(self):
return [dataset.map_shape for dataset in self._train_dataset.datasets]
return [dataset.map_shape for dataset in self.train_dataset.datasets]
@property
def map_shapes_max(self):
shapes = self.map_shapes
shape_list = list(map(max, zip(*shapes)))
if '_all_in_map' in self.mode:
if '_all_in_map' in self.mode and not self.preprocessed:
shape_list[0] += 2
return shape_list
@@ -139,14 +154,13 @@ class TrajData(object):
self.mode = mode
self.maps_root = Path(map_root)
self.length = length
self._test_dataset = self._load_datasets('train')
self._val_dataset = self._load_datasets('val')
self._train_dataset = self._load_datasets('test')
self.test_dataset = self._load_datasets('test')
self.val_dataset = self._load_datasets('val')
self.train_dataset = self._load_datasets('train')
def _load_datasets(self, dataset_type=''):
map_files = list(self.maps_root.glob('*.bmp'))
equal_split = int(self.length // len(map_files)) or 1
# find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
@@ -156,10 +170,11 @@ class TrajData(object):
preprocessed_map_names = [p.name for p in preprocessed_map_files]
datasets = []
for map_file in map_files:
new_pik_name = f'{dataset_type}_{str(map_file.name)[:-3]}.pik'
equal_split = int(self.length // len(map_files)) or 5
new_pik_name = f'{self.mode}_{map_file.name[:-4]}_{dataset_type}.pik'
if dataset_type != 'train':
equal_split *= 0.01
if not [f'{new_pik_name[:-3]}.bmp' in preprocessed_map_names]:
equal_split = max(int(equal_split * 0.01), 10)
if not new_pik_name in preprocessed_map_names:
traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
preserve_equal_samples=True)
@@ -168,6 +183,9 @@ class TrajData(object):
dataset = TrajDataShelve(map_file.parent / new_pik_name)
datasets.append(dataset)
return ConcatDataset(datasets)
# Set the equal split so that all maps are visited with the same frequency
equal_split = int(self.length // len(map_files)) or 5
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
preserve_equal_samples=True)
@@ -185,29 +203,14 @@ class TrajData(object):
def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
assert str(file_path).endswith('.pik')
processes = mp.cpu_count() - 1
mutex = mp.Lock()
with mp.Pool(processes) as pool:
async_results = [pool.apply_async(traj_dataset.__getitem__, kwds=dict(item=i)) for i in range(n)]
for i in tqdm(range(n), total=n, desc=f'Generating {n} Samples'):
sample = traj_dataset[i]
mutex.acquire()
write_to_shelve(file_path, sample)
mutex.release()
for result_obj in tqdm(async_results, total=n, desc=f'Generating {n} Samples'):
sample = result_obj.get()
mutex.acquire()
write_to_shelve(file_path, sample)
mutex.release()
print(f'{n} samples sucessfully dumped to "{file_path}"!')
@property
def train_dataset(self):
return self._train_dataset
@property
def val_dataset(self):
return self._val_dataset
@property
def test_dataset(self):
return self._test_dataset
print(f'{n} samples successfully dumped to "{file_path}"!')
def get_datasets(self):
return self.train_dataset, self.val_dataset, self.test_dataset
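
The reworked dump_n fans sample generation out to a worker pool and keeps all shelve writes in the parent process. A reduced sketch of that shape (write_fn is a hypothetical stand-in for write_to_shelve; the dataset must be picklable):

import multiprocessing as mp
from tqdm import tqdm

def dump_async(dataset, write_fn, n=1000):
    # Workers compute samples in parallel; only the parent writes,
    # so the shelf itself needs no lock here.
    with mp.Pool(mp.cpu_count() - 1) as pool:
        results = [pool.apply_async(dataset.__getitem__, (i,)) for i in range(n)]
        for res in tqdm(results, total=n, desc=f'Generating {n} Samples'):
            write_fn(res.get())  # .get() blocks until that sample is ready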

lib/models/generators/cnn.py

@@ -1,19 +1,22 @@
from random import choices, seed
import numpy as np
import torch
from functools import reduce
from operator import mul
from random import choices, choice
import torch
from torch import nn
from torch.optim import Adam
from torchvision.datasets import MNIST
from datasets.mnist import MyMNIST
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
import lib.variables as V
from lib.visualization.generator_eval import GeneratorVisualizer
class CNNRouteGeneratorModel(LightningBaseModule):
@@ -24,48 +27,71 @@ class CNNRouteGeneratorModel(LightningBaseModule):
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, alternative = batch_xy
batch_x, target = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
element_wise_loss = self.criterion(generated_alternative, alternative)
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
target = batch_x if 'ae' in self.hparams.data_param.mode else target
element_wise_loss = self.criterion(generated_alternative, target)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
# kld_loss /= reduce(mul, self.in_shape)
# kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
if 'vae' in self.hparams.data_param.mode:
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
# kld_loss /= reduce(mul, self.in_shape)
# kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
loss = kld_loss + element_wise_loss
loss = kld_loss + element_wise_loss
else:
loss = element_wise_loss
kld_loss = 0
return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
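
With the dimensional resizing left commented out (the "no resize on KLD" of the commit title), the loss is the reconstruction term plus the raw KLD whenever a 'vae' mode is active. Isolated as a sketch (helper name hypothetical):

import torch
import torch.nn.functional as F

def vae_loss(recon, target, mu, logvar, use_kld=True):
    elementwise = F.mse_loss(recon, target)  # matches self.criterion = nn.MSELoss()
    if not use_kld:
        return elementwise
    # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior,
    # Kingma & Welling (2014), Appendix B.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return elementwise + kld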
def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, _ = batch_xy
map_array = batch_x[:, 0].unsqueeze(1)
trajectory = batch_x[:, 1].unsqueeze(1)
labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
if 'vae' in self.hparams.data_param.mode:
z, mu, logvar = self.encode(batch_x)
else:
z = self.encode(batch_x)
mu, logvar = z, z
_, mu, _ = self.encode(batch_x)
generated_alternative = self.generate(mu)
return dict(maps=map_array, trajectories=trajectory, batch_nb=batch_nb, labels=labels,
generated_alternative=generated_alternative, pred_label=-1)
return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative, z=z, mu=mu, logvar=logvar)
if 'hom' in self.hparams.data_param.mode:
labels = torch.full((batch_x.shape[0], 1), V.HOMOTOPIC)
elif 'alt' in self.hparams.data_param.mode:
labels = torch.full((batch_x.shape[0], 1), V.ALTERNATIVE)
elif 'vae' in self.hparams.data_param.mode:
labels = torch.full((batch_x.shape[0], 1), V.ANY)
elif 'ae' in self.hparams.data_param.mode:
labels = torch.full((batch_x.shape[0], 1), V.ANY)
else:
labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
return_dict.update(labels=self._move_to_model_device(labels))
return return_dict
def _test_val_epoch_end(self, outputs, test=False):
val_restul_dict = self.generate_random()
plt.close('all')
from lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(**val_restul_dict)
fig = g.draw()
g = GeneratorVisualizer(choice(outputs))
fig = g.draw_io_bundle()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
plt.clf()
fig = g.draw_latent()
self.logger.log_image(f'{self.name}_Latent', fig, step=self.global_step)
plt.clf()
return dict(epoch=self.current_epoch)
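
This is the pattern the commit title points at: each figure handed to the logger is followed by plt.clf() so figure state does not leak into the next draw. Isolated as a sketch (log_image is the repo's custom Logger method; assumed to serialize the figure immediately):

import matplotlib.pyplot as plt

def log_figure(logger, name, fig, step):
    # Log first, then clear: clf() only drops matplotlib state,
    # not the already-logged image.
    logger.log_image(name, fig, step=step)
    plt.clf()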
def on_epoch_start(self):
self.dataset.seed(self.logger.version)
# self.dataset.seed(self.logger.version)
# torch.random.manual_seed(self.logger.version)
# np.random.seed(self.logger.version)
pass
def validation_step(self, *args):
return self._test_val_step(*args)
@@ -82,19 +108,23 @@ class CNNRouteGeneratorModel(LightningBaseModule):
def __init__(self, *params, issubclassed=False):
super(CNNRouteGeneratorModel, self).__init__(*params)
if not issubclassed:
if False:
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map',
self.dataset = TrajData(self.hparams.data_param.map_root,
mode=self.hparams.data_param.mode,
preprocessed=self.hparams.data_param.use_preprocessed,
length=self.hparams.data_param.dataset_length, normalized=True)
self.criterion = nn.MSELoss()
self.criterion = nn.MSELoss()
self.dataset = MyMNIST()
# Additional Attributes #
#######################################################
self.in_shape = self.dataset.map_shapes_max
self.use_res_net = self.hparams.model_param.use_res_net
self.lat_dim = self.hparams.model_param.lat_dim
self.feature_dim = self.lat_dim * 10
self.feature_dim = self.lat_dim
self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]
########################################################
# NN Nodes
@@ -119,7 +149,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0,
self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
@@ -137,20 +167,8 @@ class CNNRouteGeneratorModel(LightningBaseModule):
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_flat = Flatten(self.enc_conv_3b.shape)
last_conv_shape = self.enc_conv_2b.shape
self.enc_flat = Flatten(last_conv_shape)
self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
#
@@ -160,46 +178,43 @@ class CNNRouteGeneratorModel(LightningBaseModule):
#
# Variational Bottleneck
self.mu = nn.Linear(self.feature_dim, self.lat_dim)
self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
if 'vae' in self.hparams.data_param.mode:
self.mu = nn.Linear(self.feature_dim, self.lat_dim)
self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
#
# Linear Bottleneck
else:
self.z = nn.Linear(self.feature_dim, self.lat_dim)
#
# Alternative Generator
self.gen_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)
self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
# self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
self.reshape_to_last_conv = Flatten(self.enc_flat.shape, self.enc_conv_3b.shape)
self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
self.gen_deconv_1a = DeConvModule(self.enc_conv_3b.shape, self.hparams.model_param.filters[2],
conv_padding=0, conv_kernel=11, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_1b = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[2],
conv_padding=0, conv_kernel=9, conv_stride=2,
self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
conv_padding=1, conv_kernel=9, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_2a = DeConvModule(self.gen_deconv_1b.shape, self.hparams.model_param.filters[1],
conv_padding=0, conv_kernel=7, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_2b = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[1],
conv_padding=0, conv_kernel=7, conv_stride=1,
self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
conv_padding=1, conv_kernel=7, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_3a = DeConvModule(self.gen_deconv_2b.shape, self.hparams.model_param.filters[0],
conv_padding=1, conv_kernel=5, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_3b = DeConvModule(self.gen_deconv_3a.shape, self.hparams.model_param.filters[0],
conv_padding=1, conv_kernel=4, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_out = DeConvModule(self.gen_deconv_3b.shape, 1, activation=None,
self.gen_deconv_out = DeConvModule(self.gen_deconv_2a.shape, self.out_channels, activation=None,
conv_padding=0, conv_kernel=3, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
def forward(self, batch_x):
#
# Encode
z, mu, logvar = self.encode(batch_x)
if 'vae' in self.hparams.data_param.mode:
z, mu, logvar = self.encode(batch_x)
else:
z = self.encode(batch_x)
mu, logvar = z, z
#
# Generate
@@ -220,148 +235,46 @@ class CNNRouteGeneratorModel(LightningBaseModule):
combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_2a(combined_tensor)
combined_tensor = self.enc_conv_2b(combined_tensor)
combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_3a(combined_tensor)
combined_tensor = self.enc_conv_3b(combined_tensor)
# combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
# combined_tensor = self.enc_conv_3a(combined_tensor)
# combined_tensor = self.enc_conv_3b(combined_tensor)
combined_tensor = self.enc_flat(combined_tensor)
combined_tensor = self.enc_lin_1(combined_tensor)
combined_tensor = self.enc_lin_2(combined_tensor)
combined_tensor = self.enc_norm(combined_tensor)
combined_tensor = self.activation(combined_tensor)
combined_tensor = self.enc_lin_2(combined_tensor)
combined_tensor = self.enc_norm(combined_tensor)
combined_tensor = self.activation(combined_tensor)
#
# Variational
# Parameter and Sampling
mu = self.mu(combined_tensor)
logvar = self.logvar(combined_tensor)
z = self.reparameterize(mu, logvar)
return z, mu, logvar
if 'vae' in self.hparams.data_param.mode:
mu = self.mu(combined_tensor)
logvar = self.logvar(combined_tensor)
z = self.reparameterize(mu, logvar)
return z, mu, logvar
else:
#
# Linear Bottleneck
z = self.z(combined_tensor)
return z
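
reparameterize itself is not part of this diff; the standard trick it presumably implements, for reference (sketch):

import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # z = mu + sigma * eps with eps ~ N(0, I); keeps the sample
    # differentiable with respect to mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std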
def generate(self, z):
alt_tensor = self.gen_lin_1(z)
alt_tensor = self.activation(alt_tensor)
alt_tensor = self.gen_lin_2(alt_tensor)
alt_tensor = self.activation(alt_tensor)
# alt_tensor = self.gen_lin_2(alt_tensor)
# alt_tensor = self.activation(alt_tensor)
alt_tensor = self.reshape_to_last_conv(alt_tensor)
alt_tensor = self.gen_deconv_1a(alt_tensor)
alt_tensor = self.gen_deconv_1b(alt_tensor)
alt_tensor = self.gen_deconv_2a(alt_tensor)
alt_tensor = self.gen_deconv_2b(alt_tensor)
alt_tensor = self.gen_deconv_3a(alt_tensor)
alt_tensor = self.gen_deconv_3b(alt_tensor)
# alt_tensor = self.gen_deconv_3a(alt_tensor)
# alt_tensor = self.gen_deconv_3b(alt_tensor)
alt_tensor = self.gen_deconv_out(alt_tensor)
# alt_tensor = self.activation(alt_tensor)
alt_tensor = self.sigmoid(alt_tensor)
# alt_tensor = self.sigmoid(alt_tensor)
return alt_tensor
def generate_random(self, n=12):
samples, alternatives = zip(*[self.dataset.test_dataset[choice]
for choice in choices(range(self.dataset.length), k=n)])
samples = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in samples]))
alternatives = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in alternatives]))
return self._test_val_step((samples, alternatives), -9999)
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
name = 'CNNRouteGeneratorDiscriminated'
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing
kld_loss /= reduce(mul, self.in_shape)
loss = (kld_loss + discriminated_bce_loss) / 2
return dict(loss=loss, log=dict(loss=loss,
discriminated_bce_loss=discriminated_bce_loss,
kld_loss=kld_loss)
)
def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
def validation_step(self, *args):
return self._test_val_step(*args)
def validation_epoch_end(self, outputs: list):
return self._test_val_epoch_end(outputs)
def _test_val_epoch_end(self, outputs, test=False):
evaluation = ROCEvaluation(plot_roc=True)
pred_label = torch.cat([x['pred_label'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
# SciPy-style call order: evaluation(true_label, prediction)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
if test:
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
plt.clf()
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
plt.clf()
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
def test_step(self, *args):
return self._test_val_step(*args)
def test_epoch_end(self, outputs):
return self._test_val_epoch_end(outputs, test=True)
@property
def discriminator(self):
if self._disc is None:
raise RuntimeError('Set the discriminator first: "set_discriminator(disc_model)"')
return self._disc
def set_discriminator(self, disc_model):
if self._disc is not None:
raise RuntimeError('Discriminator has already been set... What are you trying to do?')
self._disc = disc_model
def __init__(self, *params):
raise NotImplementedError
super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
self._disc = None
self.criterion = nn.BCELoss()
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
length=self.hparams.data_param.dataset_length, normalized=True)

lib/models/generators/cnn_discriminated.py Normal file

@@ -0,0 +1,116 @@
from random import choices, seed
import numpy as np
import torch
from functools import reduce
from operator import mul
from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
name = 'CNNRouteGeneratorDiscriminated'
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing
kld_loss /= reduce(mul, self.in_shape)
loss = (kld_loss + discriminated_bce_loss) / 2
return dict(loss=loss, log=dict(loss=loss,
discriminated_bce_loss=discriminated_bce_loss,
kld_loss=kld_loss)
)
def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
def validation_step(self, *args):
return self._test_val_step(*args)
def validation_epoch_end(self, outputs: list):
return self._test_val_epoch_end(outputs)
def _test_val_epoch_end(self, outputs, test=False):
evaluation = ROCEvaluation(plot_roc=True)
pred_label = torch.cat([x['pred_label'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
# SciPy-style call order: evaluation(true_label, prediction)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
if test:
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
plt.clf()
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
plt.clf()
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
def test_step(self, *args):
return self._test_val_step(*args)
def test_epoch_end(self, outputs):
return self._test_val_epoch_end(outputs, test=True)
@property
def discriminator(self):
if self._disc is None:
raise RuntimeError('Set the discriminator first: "set_discriminator(disc_model)"')
return self._disc
def set_discriminator(self, disc_model):
if self._disc is not None:
raise RuntimeError('Discriminator has already been set... What are you trying to do?')
self._disc = disc_model
def __init__(self, *params):
raise NotImplementedError
super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
self._disc = None
self.criterion = nn.BCELoss()
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
length=self.hparams.data_param.dataset_length, normalized=True)

View File

@@ -189,5 +189,5 @@ class MapStorage(UserDict):
)
for map_file in map_files:
current_map = Map().from_image(map_file, embedding_size=self.max_map_size)
current_map = Map.from_image(map_file, embedding_size=self.max_map_size)
self.__setitem__(map_file.name, current_map)

View File

@@ -5,7 +5,9 @@ from collections import defaultdict
from configparser import ConfigParser
from pathlib import Path
from lib.models.generators.cnn import CNNRouteGeneratorModel, CNNRouteGeneratorDiscriminated
from lib.models.generators.cnn import CNNRouteGeneratorModel
from lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated
from lib.models.homotopy_classification.cnn_based import ConvHomDetector
from lib.utils.model_io import ModelParameters
from lib.utils.transforms import AsArray

View File

@@ -37,7 +37,7 @@ class Logger(LightningLoggerBase):
@property
def outpath(self):
# ToDo: Add further path modification such as dataset config etc.
return Path(self.config.train.outpath)
return Path(self.config.train.outpath) / self.config.data.mode
def __init__(self, config: Config):
"""

View File

@@ -9,6 +9,7 @@ def write_to_shelve(file_path, value):
with shelve.open(str(file_path), protocol=pickle.HIGHEST_PROTOCOL) as f:
new_key = str(len(f))
f[new_key] = value
f.close()
def load_from_shelve(file_path, key):
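
write_to_shelve keys each value by the current shelf length, which is what lets TrajDataShelve address samples as d[str(item)]. A round-trip sketch (hypothetical path; the added f.close() above is redundant inside the with block, since the context manager already closes the shelf):

import shelve

with shelve.open('demo_samples.pik') as f:   # hypothetical file
    for sample in ('a', 'b', 'c'):
        f[str(len(f))] = sample              # append under keys '0', '1', ...
with shelve.open('demo_samples.pik') as f:
    print(f['1'])                            # -> 'b'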

View File

@@ -1,9 +1,15 @@
from pathlib import Path
_ROOT = Path('..')
# Labels for classes
HOMOTOPIC = 1
ALTERNATIVE = 0
ANY = -1
# Colors for img files
WHITE = 255
BLACK = 0
DPI = 100
# Variables for plotting
PADDING = 0.25
DPI = 50

View File

@@ -1,53 +1,106 @@
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.cm as cmaps
from mpl_toolkits.axisartist.axes_grid import ImageGrid
from sklearn.cluster import Birch, DBSCAN, KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import lib.variables as V
import numpy as np
class GeneratorVisualizer(object):
def __init__(self, **kwargs):
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
self.alternatives = kwargs.get('generated_alternative')
self.labels = kwargs.get('labels')
self.trajectories = kwargs.get('trajectories')
self.maps = kwargs.get('maps')
def __init__(self, outputs, k=8):
d = defaultdict(list)
for key in outputs.keys():
try:
d[key] = outputs[key][:k].cpu().numpy()
except AttributeError:
d[key] = outputs[key][:k]
except TypeError:
self.batch_nb = outputs[key]
for key in d.keys():
self.__setattr__(key, d[key])
self._map_width, self._map_height = self.maps[0].squeeze().shape
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
self._map_width, self._map_height = self.input.shape[1], self.input.shape[2]
self.column_dict_list = self._build_column_dict_list()
self._cols = len(self.column_dict_list)
self._rows = len(self.column_dict_list[0])
self.colormap = cmaps.tab20
def _build_column_dict_list(self):
trajectories = []
non_hom_alternatives = []
hom_alternatives = []
alternatives = []
for idx in range(self.alternatives.shape[0]):
image = (self.alternatives[idx]).cpu().numpy().squeeze()
label = self.labels[idx].item()
# Dirty and quick hack incoming.
if label == V.HOMOTOPIC:
hom_alternatives.append(dict(image=image, label='Homotopic'))
non_hom_alternatives.append(None)
else:
non_hom_alternatives.append(dict(image=image, label='NonHomotopic'))
hom_alternatives.append(None)
for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))):
image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze()
for idx in range(self.output.shape[0]):
image = (self.output[idx]).squeeze()
label = 'Homotopic' if self.labels[idx].item() == V.HOMOTOPIC else 'Alternative'
alternatives.append(dict(image=image, label=label))
for idx in range(len(alternatives)):
image = (self.input[idx]).squeeze()
label = 'original'
trajectories.append(dict(image=image, label=label))
return trajectories, hom_alternatives, non_hom_alternatives
return trajectories, alternatives
def draw(self):
padding = 0.25
additional_size = self._cols * padding + 3 * padding
width = (self._map_width * self._cols) / V.DPI + additional_size
height = (self._map_height * self._rows) / V.DPI + additional_size
@staticmethod
def cluster_data(data):
cluster = Birch()
labels = cluster.fit_predict(data)
return labels
def draw_latent(self):
plt.close('all')
clusterer = KMeans(10)
try:
labels = clusterer.fit_predict(self.logvar)
except ValueError:
fig = plt.figure()
return fig
if self.z.shape[-1] > 2:
fig, axs = plt.subplots(ncols=2, nrows=1)
transformers = [TSNE(2), PCA(2)]
for idx, transformer in enumerate(transformers):
transformed = transformer.fit_transform(self.z)
colored = self.colormap(labels)
ax = axs[idx]
ax.scatter(x=transformed[:, 0], y=transformed[:, 1], c=colored)
ax.set_title(transformer.__class__.__name__)
ax.set_xlim(np.min(transformed[:, 0])*1.1, np.max(transformed[:, 0]*1.1))
ax.set_ylim(np.min(transformed[:, 1]*1.1), np.max(transformed[:, 1]*1.1))
elif self.z.shape[-1] == 2:
fig, axs = plt.subplots()
# TODO: Build transformation for lat_dim_size >= 3
print('All predictions successfully gathered and shaped')
axs.set_xlim(np.min(self.z[:, 0]), np.max(self.z[:, 0]))
axs.set_ylim(np.min(self.z[:, 1]), np.max(self.z[:, 1]))
# ToDo: Insert Normalization
colored = self.colormap(labels)
plt.scatter(self.z[:, 0], self.z[:, 1], c=colored)
else:
raise NotImplementedError("Latent Dimensions can not be one-dimensional (yet).")
return fig
def draw_io_bundle(self):
width, height = self._cols * 5, self._rows * 5
additional_size = self._cols * V.PADDING + 3 * V.PADDING
# width = (self._map_width * self._cols) / V.DPI + additional_size
# height = (self._map_height * self._rows) / V.DPI + additional_size
fig = plt.figure(figsize=(width, height), dpi=V.DPI)
grid = ImageGrid(fig, 111, # similar to subplot(111)
nrows_ncols=(self._rows, self._cols),
axes_pad=padding, # pad between axes in inch.
axes_pad=V.PADDING, # pad between axes in inch.
)
for idx in range(len(grid.axes_all)):

16
main.py

@@ -33,12 +33,13 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
# Data Parameters
main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, help="")
main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_mode", type=str, default='ae_no_label_in_map', help="")
# Transformations
main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
@@ -46,7 +47,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
# Transformations
main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=20, help="")
main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
@@ -54,9 +55,9 @@ main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0
# Model
main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 32]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="")
@@ -101,7 +102,7 @@ def run_lightning_loop(config_obj):
model.init_weights(torch.nn.init.xavier_normal_)
if model.name == 'CNNRouteGeneratorDiscriminated':
# ToDo: Make this dependent on the used seed
path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0')
path = logger.outpath / 'classifier_cnn' / 'version_0'
disc_model = SavedLightningModels.load_checkpoint(path).restore()
model.set_discriminator(disc_model)
@@ -111,13 +112,12 @@ def run_lightning_loop(config_obj):
show_progress_bar=True,
weights_save_path=logger.log_dir,
gpus=[0] if torch.cuda.is_available() else None,
check_val_every_n_epoch=1,
num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
check_val_every_n_epoch=10,
# num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
# row_log_interval=(model.n_train_batches * 0.1), # TODO: Better Value / Setting
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback,
logger=logger,
val_percent_check=0.025,
fast_dev_run=config_obj.main.debug,
early_stop_callback=None
)
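
--model_filters arrives as the literal string "[16, 32, 32]"; somewhere downstream it has to be parsed into a list (how the repo does this is not shown in the diff; ast.literal_eval is the safe route):

import ast

raw = "[16, 32, 32]"               # value of --model_filters
filters = ast.literal_eval(raw)    # -> [16, 32, 32]; evaluates literals only
assert all(isinstance(f, int) for f in filters)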

21 binary image files changed (Before sizes 831 B to 29 KiB); contents not shown.