Offline datasets, optional ResNet blocks

This commit is contained in:
Si11ium 2020-03-12 18:32:23 +01:00
parent 2f99341cc3
commit bb47e07566
11 changed files with 638 additions and 140 deletions

View File

@ -6,7 +6,7 @@ import torch
from torch.utils.data import Dataset, ConcatDataset
from datasets.utils import DatasetMapping
from lib.modules.model_parts import Generator
from lib.preprocessing.generator import Generator
from lib.objects.map import Map

View File

@ -2,15 +2,50 @@ import shelve
from pathlib import Path
from typing import Union, List
import multiprocessing as mp
import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
import numpy as np
from tqdm import tqdm
from lib.objects.map import Map
import lib.variables as V
from PIL import Image
from lib.utils.tools import write_to_shelve
class TrajDataShelve(Dataset):
@property
def map_shape(self):
return self[0][0].shape
def __init__(self, file_path, **kwargs):
super(TrajDataShelve, self).__init__()
self._mutex = mp.Lock()
self.file_path = str(file_path)
def __len__(self):
self._mutex.acquire()
with shelve.open(self.file_path) as d:
length = len(d)
self._mutex.release()
return length
def seed(self, seed):
pass
def __getitem__(self, item):
self._mutex.acquire()
with shelve.open(self.file_path) as d:
sample = d[str(item)]
self._mutex.release()
return sample
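A minimal read sketch for TrajDataShelve, assuming a shelve file previously filled by TrajData.dump_n (defined further down); the file name is hypothetical. Note that the mp.Lock only guards access within one process; DataLoader workers forked later each hold their own copy of it.

    ds = TrajDataShelve('maps/train_map1.pik')
    print(len(ds))                      # number of stored samples
    batch_x, alternative = ds[0]        # samples live under keys '0', '1', ...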
class TrajDataset(Dataset):
@ -22,14 +57,15 @@ class TrajDataset(Dataset):
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
**kwargs):
super(TrajDataset, self).__init__()
assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map',
'classifier_all_in_map']
self.normalized = normalized
self.preserve_equal_samples = preserve_equal_samples
self.mode = mode
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
self.maps_root = maps_root
self._len = length
self.last_label = -1
self.last_label = V.ALTERNATIVE if 'hom' in self.mode else choice([-1, V.ALTERNATIVE, V.HOMOTOPIC])
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
@ -39,6 +75,7 @@ class TrajDataset(Dataset):
def __getitem__(self, item):
if self.mode.lower() == 'just_route':
raise NotImplementedError
trajectory = self.map.get_random_trajectory()
trajectory_space = trajectory.draw_in_array(self.map.shape)
label = choice([0, 1])
@ -54,37 +91,41 @@ class TrajDataset(Dataset):
else:
break
self.last_label = label
if self.mode.lower() in ['all_in_map', 'separated_arrays']:
self.last_label = label if self.mode != 'generator_hom_all_in_map' else V.ALTERNATIVE
if self.mode.lower() in ['classifier_all_in_map', 'generator_all_in_map']:
map_array = self.map.as_array
trajectory = trajectory.draw_in_array(self.map_shape)
alternative = alternative.draw_in_array(self.map_shape)
if self.mode == 'separated_arrays':
if self.normalized:
map_array = map_array / V.WHITE
trajectory = trajectory / V.WHITE
alternative = alternative / V.WHITE
return (map_array, trajectory, label), alternative
else:
label_as_array = np.full_like(map_array, label)
if self.normalized:
map_array = map_array / V.WHITE
trajectory = trajectory / V.WHITE
alternative = alternative / V.WHITE
if self.mode == 'generator_all_in_map':
return np.concatenate((map_array, trajectory, label_as_array)), alternative
elif self.mode == 'classifier_all_in_map':
return np.concatenate((map_array, trajectory, alternative)), label
elif self.mode == 'vectors':
elif self.mode == '_vectors':
raise NotImplementedError
return trajectory.vertices, alternative.vertices, label, self.mapname
else:
raise ValueError
raise ValueError(f'Mode was: {self.mode}')
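For the '*_all_in_map' modes each sample is a channel stack built with np.concatenate; a shape sketch under assumed 64x64 maps:

    import numpy as np

    map_array = np.zeros((1, 64, 64))
    trajectory = np.ones_like(map_array)
    label_as_array = np.full_like(map_array, 1)   # the label broadcast to a plane
    sample = np.concatenate((map_array, trajectory, label_as_array))
    print(sample.shape)                           # (3, 64, 64)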
def seed(self, seed):
self.map.seed(seed)
class TrajData(object):
@property
def map_shapes(self):
return [dataset.map_shape for dataset in self._dataset.datasets]
return [dataset.map_shape for dataset in self._train_dataset.datasets]
@property
def map_shapes_max(self):
shapes = self.map_shapes
shape_list = list(map(max, zip(*shapes)))
if self.mode in ['separated_arrays', 'all_in_map']:
if '_all_in_map' in self.mode:
shape_list[0] += 2
return shape_list
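A small worked example of map_shapes_max, assuming two maps of different sizes:

    shapes = [(1, 50, 60), (1, 70, 40)]           # per-map (channels, height, width)
    shape_list = list(map(max, zip(*shapes)))     # [1, 70, 60]
    shape_list[0] += 2                            # '*_all_in_map' stacks two extra channels
    print(shape_list)                             # [3, 70, 60]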
@ -92,36 +133,81 @@ class TrajData(object):
def name(self):
return self.__class__.__name__
def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_):
def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, preprocessed=False, **_):
self.preprocessed = preprocessed
self.normalized = normalized
self.mode = mode
self.maps_root = Path(map_root)
self.length = length
self._dataset = self._load_datasets()
self._train_dataset = self._load_datasets('train')
self._val_dataset = self._load_datasets('val')
self._test_dataset = self._load_datasets('test')
def _load_datasets(self, dataset_type=''):
def _load_datasets(self):
map_files = list(self.maps_root.glob('*.bmp'))
equal_split = int(self.length // len(map_files)) or 1
# find max image size among available maps:
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
if self.preprocessed:
preprocessed_map_files = list(self.maps_root.glob('*.pik'))
preprocessed_map_names = [p.name for p in preprocessed_map_files]
datasets = []
for map_file in map_files:
new_pik_name = f'{dataset_type}_{map_file.stem}.pik'
# validation/test shelves only need a fraction of the training samples
per_map_len = equal_split if dataset_type == 'train' else max(int(equal_split * 0.01), 1)
if new_pik_name not in preprocessed_map_names:
traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=per_map_len,
mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
preserve_equal_samples=True)
self.dump_n(map_file.parent / new_pik_name, traj_dataset, n=per_map_len)
dataset = TrajDataShelve(map_file.parent / new_pik_name)
datasets.append(dataset)
return ConcatDataset(datasets)
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
preserve_equal_samples=True)
for map_file in map_files])
def kill_em_all(self):
for pik_file in self.maps_root.glob('*.pik'):
pik_file.unlink()
print(f'{pik_file.name} was deleted.')
print('Done.')
def seed(self, seed):
for concat_dataset in [self._train_dataset, self._val_dataset, self._test_dataset]:
for dataset in concat_dataset.datasets:
dataset.seed(seed)
def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
assert str(file_path).endswith('.pik')
processes = mp.cpu_count() - 1
mutex = mp.Lock()
with mp.Pool(processes) as pool:
async_results = [pool.apply_async(traj_dataset.__getitem__, kwds=dict(item=i)) for i in range(n)]
for result_obj in tqdm(async_results, total=n, desc=f'Generating {n} Samples'):
sample = result_obj.get()
mutex.acquire()  # defensive only; writes already happen sequentially in the parent process
write_to_shelve(file_path, sample)
mutex.release()
print(f'{n} samples successfully dumped to "{file_path}"!')
@property
def train_dataset(self):
return self._dataset
return self._train_dataset
@property
def val_dataset(self):
return self._dataset
return self._val_dataset
@property
def test_dataset(self):
return self._dataset
return self._test_dataset
def get_datasets(self):
return self._dataset, self._dataset, self._dataset
return self._train_dataset, self._val_dataset, self._test_dataset
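A hedged usage sketch of the preprocessed path; the map root matches the CLI default further down, the length is an assumption:

    data = TrajData('res/shapes', length=1000, mode='generator_all_in_map',
                    normalized=True, preprocessed=True)
    train, val, test = data.get_datasets()
    batch_x, alternative = train[0]     # served from the cached .pik shelves
    data.kill_em_all()                  # delete the cached shelve files again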

View File

@ -1,4 +1,5 @@
from random import choice
from random import choices, seed
import numpy as np
import torch
from functools import reduce
@ -36,28 +37,36 @@ class CNNRouteGeneratorModel(LightningBaseModule):
# kld_loss /= reduce(mul, self.in_shape)
# kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
loss = (kld_loss + element_wise_loss) / 2
loss = kld_loss + element_wise_loss
return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
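The closed-form KL term in the loss above can be cross-checked against torch.distributions; a minimal sketch:

    import torch
    from torch.distributions import Normal, kl_divergence

    mu, logvar = torch.randn(4, 8), torch.randn(4, 8)
    closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    reference = kl_divergence(Normal(mu, (0.5 * logvar).exp()), Normal(0.0, 1.0)).sum()
    assert torch.allclose(closed_form, reference, atol=1e-4)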
def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, _ = batch_xy
map_array, trajectory, label = batch_x
map_array = batch_x[:, 0].unsqueeze(1)
trajectory = batch_x[:, 1].unsqueeze(1)
labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
generated_alternative, z, mu, logvar = self(batch_x)
return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1)
_, mu, _ = self.encode(batch_x)
generated_alternative = self.generate(mu)
return dict(maps=map_array, trajectories=trajectory, batch_nb=batch_nb, labels=labels,
generated_alternative=generated_alternative, pred_label=-1)
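The label travels as a constant image plane (np.full_like in the dataset) and is recovered above by taking the spatial maximum; the recovery in miniature:

    import torch

    label_planes = torch.zeros(2, 1, 4, 4)    # batch of two label planes
    label_planes[1] = 1.0
    labels = label_planes.max(dim=-1).values.max(-1).values
    print(labels.squeeze())                   # tensor([0., 1.])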
def _test_val_epoch_end(self, outputs, test=False):
maps, trajectories, labels, val_result_dict = self.generate_random()
val_result_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
g = GeneratorVisualizer(**val_result_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
plt.clf()
return dict(epoch=self.current_epoch)
def on_epoch_start(self):
self.dataset.seed(self.logger.version)
# torch.random.manual_seed(self.logger.version)
# np.random.seed(self.logger.version)
def validation_step(self, *args):
return self._test_val_step(*args)
@ -75,14 +84,18 @@ class CNNRouteGeneratorModel(LightningBaseModule):
if not issubclassed:
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays',
self.dataset = TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map',
preprocessed=self.hparams.data_param.use_preprocessed,
length=self.hparams.data_param.dataset_length, normalized=True)
self.criterion = nn.MSELoss()
# Additional Attributes
# Additional Attributes #
#######################################################
self.in_shape = self.dataset.map_shapes_max
# Todo: Better naming and size in Parameters
self.feature_dim = self.hparams.model_param.lat_dim * 10
self.use_res_net = self.hparams.model_param.use_res_net
self.lat_dim = self.hparams.model_param.lat_dim
self.feature_dim = self.lat_dim * 10
########################################################
# NN Nodes
###################################################
@ -93,82 +106,100 @@ class CNNRouteGeneratorModel(LightningBaseModule):
#
# Map Encoder
self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
conv_filters=self.hparams.model_param.filters[0],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=11, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.map_flat = Flatten(self.map_conv_3.shape)
self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim)
self.enc_flat = Flatten(self.enc_conv_3b.shape)
self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
#
# Mixed Encoder
self.mixed_lin = nn.Linear(self.feature_dim, self.feature_dim)
self.mixed_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
#
# Variational Bottleneck
self.mu = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)
self.logvar = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim)
self.mu = nn.Linear(self.feature_dim, self.lat_dim)
self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
#
# Alternative Generator
self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
# Todo Fix This Hack!!!!
reshape_shape = (1, self.map_conv_3.shape[1], self.map_conv_3.shape[2])
self.gen_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, reshape_shape))
self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
self.reshape_to_map = Flatten(reduce(mul, reshape_shape), reshape_shape)
self.reshape_to_last_conv = Flatten(self.enc_flat.shape, self.enc_conv_3b.shape)
self.alt_deconv_1 = DeConvModule(reshape_shape, self.hparams.model_param.filters[2],
conv_padding=0, conv_kernel=13, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1],
conv_padding=0, conv_kernel=7, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0],
conv_padding=1, conv_kernel=5, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None,
conv_padding=1, conv_kernel=3, conv_stride=1,
self.gen_deconv_1a = DeConvModule(self.enc_conv_3b.shape, self.hparams.model_param.filters[2],
conv_padding=0, conv_kernel=11, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_1b = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[2],
conv_padding=0, conv_kernel=9, conv_stride=2,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_2a = DeConvModule(self.gen_deconv_1b.shape, self.hparams.model_param.filters[1],
conv_padding=0, conv_kernel=7, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_2b = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[1],
conv_padding=0, conv_kernel=7, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_3a = DeConvModule(self.gen_deconv_2b.shape, self.hparams.model_param.filters[0],
conv_padding=1, conv_kernel=5, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_3b = DeConvModule(self.gen_deconv_3a.shape, self.hparams.model_param.filters[0],
conv_padding=1, conv_kernel=4, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
self.gen_deconv_out = DeConvModule(self.gen_deconv_3b.shape, 1, activation=None,
conv_padding=0, conv_kernel=3, conv_stride=1,
use_norm=self.hparams.model_param.use_norm)
def forward(self, batch_x):
#
# Sorting the Input
map_array, trajectory, label = batch_x
#
# Encode
z, mu, logvar = self.encode(map_array, trajectory, label)
z, mu, logvar = self.encode(batch_x)
#
# Generate
@ -181,42 +212,26 @@ class CNNRouteGeneratorModel(LightningBaseModule):
eps = torch.randn_like(std)
return mu + eps * std
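reparameterize keeps the sampling step differentiable; a minimal sketch showing gradients reaching mu and logvar:

    import torch

    mu = torch.zeros(3, requires_grad=True)
    logvar = torch.zeros(3, requires_grad=True)
    std = torch.exp(0.5 * logvar)
    z = mu + torch.randn_like(std) * std    # noise enters as a constant factor
    z.sum().backward()
    print(mu.grad, logvar.grad)             # both populated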
def generate(self, z):
alt_tensor = self.alt_lin_1(z)
alt_tensor = self.activation(alt_tensor)
alt_tensor = self.alt_lin_2(alt_tensor)
alt_tensor = self.activation(alt_tensor)
alt_tensor = self.reshape_to_map(alt_tensor)
alt_tensor = self.alt_deconv_1(alt_tensor)
alt_tensor = self.alt_deconv_2(alt_tensor)
alt_tensor = self.alt_deconv_3(alt_tensor)
alt_tensor = self.alt_deconv_out(alt_tensor)
# alt_tensor = self.activation(alt_tensor)
alt_tensor = self.sigmoid(alt_tensor)
return alt_tensor
def encode(self, batch_x):
combined_tensor = self.enc_conv_0(batch_x)
combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_1a(combined_tensor)
combined_tensor = self.enc_conv_1b(combined_tensor)
combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_2a(combined_tensor)
combined_tensor = self.enc_conv_2b(combined_tensor)
combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_3a(combined_tensor)
combined_tensor = self.enc_conv_3b(combined_tensor)
def encode(self, map_array, trajectory, label):
label_array = torch.cat([torch.full((1, 1, self.in_shape[1], self.in_shape[2]), x.item())
for x in label], dim=0)
label_array = self._move_to_model_device(label_array)
combined_tensor = torch.cat((map_array, trajectory, label_array), dim=1)
combined_tensor = self.map_conv_0(combined_tensor)
combined_tensor = self.map_res_1(combined_tensor)
combined_tensor = self.map_conv_1(combined_tensor)
combined_tensor = self.map_res_2(combined_tensor)
combined_tensor = self.map_conv_2(combined_tensor)
combined_tensor = self.map_res_3(combined_tensor)
combined_tensor = self.map_conv_3(combined_tensor)
combined_tensor = self.enc_flat(combined_tensor)
combined_tensor = self.enc_lin_1(combined_tensor)
combined_tensor = self.enc_lin_2(combined_tensor)
combined_tensor = self.map_flat(combined_tensor)
combined_tensor = self.map_lin(combined_tensor)
combined_tensor = self.mixed_lin(combined_tensor)
combined_tensor = self.mixed_norm(combined_tensor)
combined_tensor = self.enc_norm(combined_tensor)
combined_tensor = self.activation(combined_tensor)
combined_tensor = self.mixed_lin(combined_tensor)
combined_tensor = self.mixed_norm(combined_tensor)
combined_tensor = self.enc_lin_2(combined_tensor)
combined_tensor = self.enc_norm(combined_tensor)
combined_tensor = self.activation(combined_tensor)
#
@ -226,19 +241,31 @@ class CNNRouteGeneratorModel(LightningBaseModule):
z = self.reparameterize(mu, logvar)
return z, mu, logvar
def generate_random(self, n=6):
maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]
def generate(self, z):
alt_tensor = self.gen_lin_1(z)
alt_tensor = self.activation(alt_tensor)
alt_tensor = self.gen_lin_2(alt_tensor)
alt_tensor = self.activation(alt_tensor)
alt_tensor = self.reshape_to_last_conv(alt_tensor)
alt_tensor = self.gen_deconv_1a(alt_tensor)
alt_tensor = self.gen_deconv_1b(alt_tensor)
alt_tensor = self.gen_deconv_2a(alt_tensor)
alt_tensor = self.gen_deconv_2b(alt_tensor)
alt_tensor = self.gen_deconv_3a(alt_tensor)
alt_tensor = self.gen_deconv_3b(alt_tensor)
alt_tensor = self.gen_deconv_out(alt_tensor)
# alt_tensor = self.activation(alt_tensor)
alt_tensor = self.sigmoid(alt_tensor)
return alt_tensor
trajectories = [x.get_random_trajectory() for x in maps]
trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
trajectories = self._move_to_model_device(torch.stack(trajectories))
def generate_random(self, n=12):
maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
maps = self._move_to_model_device(torch.stack(maps))
samples, alternatives = zip(*[self.dataset.test_dataset[idx]
for idx in choices(range(len(self.dataset.test_dataset)), k=n)])
samples = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in samples]))
alternatives = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in alternatives]))
labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)
return self._test_val_step((samples, alternatives), -9999)
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
@ -329,11 +356,12 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
self._disc = disc_model
def __init__(self, *params):
raise NotImplementedError
super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
self._disc = None
self.criterion = nn.BCELoss()
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
length=self.hparams.data_param.dataset_length, normalized=True)

View File

@ -0,0 +1,348 @@
from random import choice
import torch
from functools import reduce
from operator import mul
from torch import nn
from torch.optim import Adam
from datasets.trajectory_dataset import TrajData
from lib.evaluation.classification import ROCEvaluation
from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
from lib.modules.utils import LightningBaseModule, Flatten
import matplotlib.pyplot as plt
class CNNRouteGeneratorModel(LightningBaseModule):
name = 'CNNRouteGenerator'
def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.train_param.lr)
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, alternative = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
element_wise_loss = self.criterion(generated_alternative, alternative)
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing TODO: Does This make sense? Sanity Check it!
# kld_loss /= reduce(mul, self.in_shape)
# kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
loss = (kld_loss + element_wise_loss) / 2
return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, alternative = batch_xy
map_array = batch_x[0]
trajectory = batch_x[1]
label = batch_x[2].max()
z, _, _ = self.encode(batch_x)
generated_alternative = self.generate(z)
return dict(map_array=map_array, trajectory=trajectory, batch_nb=batch_nb, label=label,
generated_alternative=generated_alternative, pred_label=-1, alternative=alternative
)
def _test_val_epoch_end(self, outputs, test=False):
maps, trajectories, labels, val_result_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
plt.clf()
return dict(epoch=self.current_epoch)
def validation_step(self, *args):
return self._test_val_step(*args)
def validation_epoch_end(self, outputs: list):
return self._test_val_epoch_end(outputs)
def test_step(self, *args):
return self._test_val_step(*args)
def test_epoch_end(self, outputs):
return self._test_val_epoch_end(outputs, test=True)
def __init__(self, *params, issubclassed=False):
super(CNNRouteGeneratorModel, self).__init__(*params)
if not issubclassed:
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map',
length=self.hparams.data_param.dataset_length, normalized=True)
self.criterion = nn.MSELoss()
# Additional Attributes #
#######################################################
self.map_shape = self.dataset.map_shapes_max
self.trajectory_features = 4
self.use_res_net = self.hparams.model_param.use_res_net
self.lat_dim = self.hparams.model_param.lat_dim
self.feature_dim = self.lat_dim * 10
########################################################
# NN Nodes
###################################################
#
# Utils
self.activation = nn.ReLU()
self.sigmoid = nn.Sigmoid()
#
# Map Encoder
self.enc_conv_0 = ConvModule(self.map_shape, conv_kernel=3, conv_stride=1, conv_padding=1,
conv_filters=self.hparams.model_param.filters[0],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
conv_padding=2, conv_filters=self.hparams.model_param.filters[0],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0,
conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1,
conv_padding=2, conv_filters=self.hparams.model_param.filters[1],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
conv_filters=self.hparams.model_param.filters[2],
use_norm=self.hparams.model_param.use_norm,
use_bias=self.hparams.model_param.use_bias)
# Trajectory Encoder
self.enc_gru_1 = nn.GRU(input_size=self.trajectory_features, hidden_size=self.feature_dim,
num_layers=3, batch_first=True)
self.enc_flat = Flatten(self.enc_conv_3b.shape)
self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
#
# Mixed Encoder
self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim)
self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x
#
# Variational Bottleneck
self.mu = nn.Linear(self.feature_dim, self.lat_dim)
self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
#
# Alternative Generator
self.gen_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
# TODO: placeholder; the decoder GRU's input/hidden sizes are not yet defined,
# and nn.GRU(None, None) would raise at construction time
# self.gen_gru_x = nn.GRU(input_size=..., hidden_size=..., batch_first=True)
def forward(self, batch_x):
#
# Encode
z, mu, logvar = self.encode(batch_x)
#
# Generate
alt_tensor = self.generate(z)
return alt_tensor, z, mu, logvar
@staticmethod
def reparameterize(mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def encode(self, batch_x):
combined_tensor = self.enc_conv_0(batch_x)
combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_1a(combined_tensor)
combined_tensor = self.enc_conv_1b(combined_tensor)
combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_2a(combined_tensor)
combined_tensor = self.enc_conv_2b(combined_tensor)
combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
combined_tensor = self.enc_conv_3a(combined_tensor)
combined_tensor = self.enc_conv_3b(combined_tensor)
combined_tensor = self.enc_flat(combined_tensor)
combined_tensor = self.enc_lin_1(combined_tensor)
combined_tensor = self.enc_lin_2(combined_tensor)
combined_tensor = self.enc_norm(combined_tensor)
combined_tensor = self.activation(combined_tensor)
combined_tensor = self.enc_lin_2(combined_tensor)
combined_tensor = self.enc_norm(combined_tensor)
combined_tensor = self.activation(combined_tensor)
#
# Parameter and Sampling
mu = self.mu(combined_tensor)
logvar = self.logvar(combined_tensor)
z = self.reparameterize(mu, logvar)
return z, mu, logvar
def generate(self, z):
# TODO: this body still references the deconv stack of the convolutional variant
# (reshape_to_last_conv, gen_deconv_*), which this class does not define yet
alt_tensor = self.gen_lin_1(z)
alt_tensor = self.activation(alt_tensor)
alt_tensor = self.gen_lin_2(alt_tensor)
alt_tensor = self.activation(alt_tensor)
alt_tensor = self.reshape_to_last_conv(alt_tensor)
alt_tensor = self.gen_deconv_1a(alt_tensor)
alt_tensor = self.gen_deconv_1b(alt_tensor)
alt_tensor = self.gen_deconv_2a(alt_tensor)
alt_tensor = self.gen_deconv_2b(alt_tensor)
alt_tensor = self.gen_deconv_3a(alt_tensor)
alt_tensor = self.gen_deconv_3b(alt_tensor)
alt_tensor = self.gen_deconv_out(alt_tensor)
# alt_tensor = self.activation(alt_tensor)
alt_tensor = self.sigmoid(alt_tensor)
return alt_tensor
def generate_random(self, n=6):
maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)]
trajectories = [x.get_random_trajectory() for x in maps]
trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories]
trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2
trajectories = self._move_to_model_device(torch.stack(trajectories))
maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2
maps = self._move_to_model_device(torch.stack(maps))
labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n))
return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999)
class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
name = 'CNNRouteGeneratorDiscriminated'
def training_step(self, batch_xy, batch_nb, *args, **kwargs):
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Dimensional Resizing
kld_loss /= reduce(mul, self.in_shape)
loss = (kld_loss + discriminated_bce_loss) / 2
return dict(loss=loss, log=dict(loss=loss,
discriminated_bce_loss=discriminated_bce_loss,
kld_loss=kld_loss)
)
def _test_val_step(self, batch_xy, batch_nb, *args):
batch_x, label = batch_xy
generated_alternative, z, mu, logvar = self(batch_x)
map_array, trajectory = batch_x
map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
pred_label = self.discriminator(map_stack)
discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
pred_label=pred_label, label=label, generated_alternative=generated_alternative)
def validation_step(self, *args):
return self._test_val_step(*args)
def validation_epoch_end(self, outputs: list):
return self._test_val_epoch_end(outputs)
def _test_val_epoch_end(self, outputs, test=False):
evaluation = ROCEvaluation(plot_roc=True)
pred_label = torch.cat([x['pred_label'] for x in outputs])
labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
# scikit-learn ROC convention: evaluation(true_labels, predictions)
roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
if test:
# self.logger.log_metrics(score_dict)
self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
plt.clf()
maps, trajectories, labels, val_restul_dict = self.generate_random()
from lib.visualization.generator_eval import GeneratorVisualizer
g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
fig = g.draw()
self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
plt.clf()
return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
def test_step(self, *args):
return self._test_val_step(*args)
def test_epoch_end(self, outputs):
return self._test_val_epoch_end(outputs, test=True)
@property
def discriminator(self):
if self._disc is None:
raise RuntimeError('Set the discriminator first: "set_discriminator(disc_model)"')
return self._disc
def set_discriminator(self, disc_model):
if self._disc is not None:
raise RuntimeError('Discriminator has already been set... What are you trying to do?')
self._disc = disc_model
def __init__(self, *params):
super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
self._disc = None
self.criterion = nn.BCELoss()
self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route',
length=self.hparams.data_param.dataset_length, normalized=True)

View File

@ -60,7 +60,7 @@ class ConvHomDetector(LightningBaseModule):
super(ConvHomDetector, self).__init__(hparams)
# Dataset
self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map', )
self.dataset = TrajData(self.hparams.data_param.map_root, mode='classifier_all_in_map', )
# Additional Attributes
self.map_shape = self.dataset.map_shapes_max

View File

@ -22,7 +22,7 @@ class Flatten(nn.Module):
try:
x = torch.randn(self.in_shape).unsqueeze(0)
output = self(x)
return output.shape[1:]
return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1]
except Exception as e:
print(e)
return -1
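The try block above infers a module's output shape by pushing a dummy tensor through it; the same idea in isolation, with illustrative sizes:

    import torch
    from torch import nn

    conv = nn.Conv2d(1, 16, kernel_size=3, stride=2)
    with torch.no_grad():
        out = conv(torch.randn(1, 1, 64, 64))   # dummy forward pass
    print(out.shape[1:])                        # torch.Size([16, 31, 31])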

View File

@ -1,10 +1,9 @@
import shelve
from collections import UserDict
from pathlib import Path
import copy
from math import sqrt
from random import choice
from random import Random
import numpy as np
@ -53,8 +52,12 @@ class Map(object):
assert array_like_map_representation.ndim == 3
self.map_array: np.ndarray = array_like_map_representation
self.name = name
self.prng = Random()
def seed(self, seed):
self.prng.seed(seed)
def __setattr__(self, key, value):
super(Map, self).__setattr__(key, value)
if key == 'map_array' and self.map_array is not None:
@ -102,7 +105,7 @@ class Map(object):
return trajectory
def get_valid_position(self):
valid_position = choice(list(self._G.nodes))
valid_position = self.prng.choice(list(self._G.nodes))
return valid_position
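Swapping the module-level choice for a per-instance Random makes trajectory sampling seedable per map without touching global state; the isolation in miniature:

    from random import Random, random

    prng_a, prng_b = Random(), Random()
    prng_a.seed(42)
    prng_b.seed(42)
    random()                                                      # the global RNG advances...
    assert prng_a.choice(range(100)) == prng_b.choice(range(100))  # ...the seeded pair stays in lockstep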
def get_trajectory_from_vertices(self, *args):

View File

@ -20,6 +20,8 @@ class Generator:
self.data_root = Path(data_root)
def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik'
kwargs.update(n=m)

lib/utils/tools.py (new file, 22 lines)
View File

@ -0,0 +1,22 @@
import pickle
import shelve
from pathlib import Path
def write_to_shelve(file_path, value):
check_path(file_path)
file_path.parent.mkdir(exist_ok=True, parents=True)
with shelve.open(str(file_path), protocol=pickle.HIGHEST_PROTOCOL) as f:
new_key = str(len(f))
f[new_key] = value
def load_from_shelve(file_path, key):
check_path(file_path)
with shelve.open(str(file_path)) as d:
return d[key]
def check_path(file_path):
assert isinstance(file_path, Path)
assert str(file_path).endswith('.pik')
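write_to_shelve appends under the next integer key (str(len(f))), so repeated calls build an index 0..n-1 that load_from_shelve reads back; a short sketch with a hypothetical path:

    from pathlib import Path

    db = Path('data/samples.pik')            # hypothetical target file
    for i in range(3):
        write_to_shelve(db, dict(sample_id=i))
    print(load_from_shelve(db, '1'))         # {'sample_id': 1}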

View File

@ -5,12 +5,13 @@ import lib.variables as V
class GeneratorVisualizer(object):
def __init__(self, maps, trajectories, labels, val_result_dict):
def __init__(self, **kwargs):
# val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
self.alternatives = val_result_dict['generated_alternative']
self.labels = labels
self.trajectories = trajectories
self.maps = maps
self.alternatives = kwargs.get('generated_alternative')
self.labels = kwargs.get('labels')
self.trajectories = kwargs.get('trajectories')
self.maps = kwargs.get('maps')
self._map_width, self._map_height = self.maps[0].squeeze().shape
self.column_dict_list = self._build_column_dict_list()
self._cols = len(self.column_dict_list)
@ -24,10 +25,13 @@ class GeneratorVisualizer(object):
for idx in range(self.alternatives.shape[0]):
image = (self.alternatives[idx]).cpu().numpy().squeeze()
label = self.labels[idx].item()
# Dirty and quick hack incoming.
if label == V.HOMOTOPIC:
hom_alternatives.append(dict(image=image, label='Homotopic'))
non_hom_alternatives.append(None)
else:
non_hom_alternatives.append(dict(image=image, label='NonHomotopic'))
hom_alternatives.append(None)
for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))):
image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze()
label = 'original'
@ -48,10 +52,13 @@ class GeneratorVisualizer(object):
for idx in range(len(grid.axes_all)):
row, col = divmod(idx, len(self.column_dict_list))
current_image = self.column_dict_list[col][row]['image']
current_label = self.column_dict_list[col][row]['label']
grid[idx].imshow(current_image)
grid[idx].title.set_text(current_label)
if self.column_dict_list[col][row] is not None:
current_image = self.column_dict_list[col][row]['image']
current_label = self.column_dict_list[col][row]['label']
grid[idx].imshow(current_image)
grid[idx].title.set_text(current_label)
fig.cbar_mode = 'single'
fig.tight_layout()
return fig

View File

@ -37,6 +37,7 @@ main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000,
main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
# Transformations
@ -55,9 +56,10 @@ main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerato
main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="")
main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="")
main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="")
# Project
@ -115,7 +117,7 @@ def run_lightning_loop(config_obj):
# log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting
checkpoint_callback=checkpoint_callback,
logger=logger,
val_percent_check=0.05,
val_percent_check=0.025,
fast_dev_run=config_obj.main.debug,
early_stop_callback=None
)