From bb47e07566a4bd37ee1338f5f32e912f1363a656 Mon Sep 17 00:00:00 2001 From: Si11ium Date: Thu, 12 Mar 2020 18:32:23 +0100 Subject: [PATCH] Offline Datasets res net optionality --- datasets/paired_dataset.py | 2 +- datasets/trajectory_dataset.py | 134 +++++-- lib/models/generators/cnn.py | 226 +++++++----- lib/models/generators/recurrent.py | 348 ++++++++++++++++++ .../homotopy_classification/cnn_based.py | 2 +- lib/modules/utils.py | 2 +- lib/objects/map.py | 9 +- lib/preprocessing/generator.py | 2 + lib/utils/tools.py | 22 ++ lib/visualization/generator_eval.py | 25 +- main.py | 6 +- 11 files changed, 638 insertions(+), 140 deletions(-) create mode 100644 lib/utils/tools.py diff --git a/datasets/paired_dataset.py b/datasets/paired_dataset.py index 722276b..3d152f7 100644 --- a/datasets/paired_dataset.py +++ b/datasets/paired_dataset.py @@ -6,7 +6,7 @@ import torch from torch.utils.data import Dataset, ConcatDataset from datasets.utils import DatasetMapping -from lib.modules.model_parts import Generator +from lib.preprocessing.generator import Generator from lib.objects.map import Map diff --git a/datasets/trajectory_dataset.py b/datasets/trajectory_dataset.py index ff567b6..2b62a24 100644 --- a/datasets/trajectory_dataset.py +++ b/datasets/trajectory_dataset.py @@ -2,15 +2,50 @@ import shelve from pathlib import Path from typing import Union, List +import multiprocessing as mp + import torch from random import choice from torch.utils.data import ConcatDataset, Dataset import numpy as np +from tqdm import tqdm from lib.objects.map import Map import lib.variables as V from PIL import Image +from lib.utils.tools import write_to_shelve + + +class TrajDataShelve(Dataset): + + @property + def map_shape(self): + return self[0][0].shape + + def __init__(self, file_path, **kwargs): + super(TrajDataShelve, self).__init__() + self._mutex = mp.Lock() + self.file_path = str(file_path) + + + def __len__(self): + self._mutex.acquire() + with shelve.open(self.file_path) as d: + length = len(d) + self._mutex.release() + return length + + def seed(self): + pass + + def __getitem__(self, item): + self._mutex.acquire() + with shelve.open(self.file_path) as d: + sample = d[str(item)] + self._mutex.release() + return sample + class TrajDataset(Dataset): @@ -22,14 +57,15 @@ class TrajDataset(Dataset): length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False, **kwargs): super(TrajDataset, self).__init__() - assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route'] + assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map' + 'classifier_all_in_map'] self.normalized = normalized self.preserve_equal_samples = preserve_equal_samples self.mode = mode self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp' self.maps_root = maps_root self._len = length - self.last_label = -1 + self.last_label = V.ALTERNATIVE if 'hom' in self.mode else choice([-1, V.ALTERNATIVE, V.HOMOTOPIC]) self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size) @@ -39,6 +75,7 @@ class TrajDataset(Dataset): def __getitem__(self, item): if self.mode.lower() == 'just_route': + raise NotImplementedError trajectory = self.map.get_random_trajectory() trajectory_space = trajectory.draw_in_array(self.map.shape) label = choice([0, 1]) @@ -54,37 +91,41 @@ class TrajDataset(Dataset): else: break - self.last_label = label - if self.mode.lower() in ['all_in_map', 'separated_arrays']: + self.last_label = label if self.mode != ['generator_hom_all_in_map'] else V.ALTERNATIVE + if self.mode.lower() in ['classifier_all_in_map', 'generator_all_in_map']: map_array = self.map.as_array trajectory = trajectory.draw_in_array(self.map_shape) alternative = alternative.draw_in_array(self.map_shape) - if self.mode == 'separated_arrays': - if self.normalized: - map_array = map_array / V.WHITE - trajectory = trajectory / V.WHITE - alternative = alternative / V.WHITE - return (map_array, trajectory, label), alternative - else: + label_as_array = np.full_like(map_array, label) + if self.normalized: + map_array = map_array / V.WHITE + trajectory = trajectory / V.WHITE + alternative = alternative / V.WHITE + if self.mode == 'generator_all_in_map': + return np.concatenate((map_array, trajectory, label_as_array)), alternative + elif self.mode == 'classifier_all_in_map': return np.concatenate((map_array, trajectory, alternative)), label - elif self.mode == 'vectors': + elif self.mode == '_vectors': + raise NotImplementedError return trajectory.vertices, alternative.vertices, label, self.mapname - else: - raise ValueError + raise ValueError(f'Mode was: {self.mode}') + + def seed(self, seed): + self.map.seed(seed) class TrajData(object): @property def map_shapes(self): - return [dataset.map_shape for dataset in self._dataset.datasets] + return [dataset.map_shape for dataset in self._train_dataset.datasets] @property def map_shapes_max(self): shapes = self.map_shapes shape_list = list(map(max, zip(*shapes))) - if self.mode in ['separated_arrays', 'all_in_map']: + if '_all_in_map' in self.mode: shape_list[0] += 2 return shape_list @@ -92,36 +133,81 @@ class TrajData(object): def name(self): return self.__class__.__name__ - def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_): - + def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, preprocessed=False, **_): + self.preprocessed = preprocessed self.normalized = normalized self.mode = mode self.maps_root = Path(map_root) self.length = length - self._dataset = self._load_datasets() + self._test_dataset = self._load_datasets('train') + self._val_dataset = self._load_datasets('val') + self._train_dataset = self._load_datasets('test') + + def _load_datasets(self, dataset_type=''): - def _load_datasets(self): map_files = list(self.maps_root.glob('*.bmp')) equal_split = int(self.length // len(map_files)) or 1 # find max image size among available maps: max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files])))) + + if self.preprocessed: + preprocessed_map_files = list(self.maps_root.glob('*.pik')) + preprocessed_map_names = [p.name for p in preprocessed_map_files] + datasets = [] + for map_file in map_files: + new_pik_name = f'{dataset_type}_{str(map_file.name)[:-3]}.pik' + if dataset_type != 'train': + equal_split *= 0.01 + if not [f'{new_pik_name[:-3]}.bmp' in preprocessed_map_names]: + traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split, + mode=self.mode, embedding_size=max_map_size, normalized=self.normalized, + preserve_equal_samples=True) + self.dump_n(map_file.parent / new_pik_name, traj_dataset, n=equal_split) + + dataset = TrajDataShelve(map_file.parent / new_pik_name) + datasets.append(dataset) + return ConcatDataset(datasets) return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split, mode=self.mode, embedding_size=max_map_size, normalized=self.normalized, preserve_equal_samples=True) for map_file in map_files]) + def kill_em_all(self): + for pik_file in self.maps_root.glob('*.pik'): + pik_file.unlink() + print(pik_file.name, ' was deleted.') + print('Done.') + + def seed(self, seed): + for dataset in [x.datasets for x in [self._train_dataset, self._test_dataset, self.val_dataset]]: + dataset.seed(seed) + + def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000): + assert str(file_path).endswith('.pik') + processes = mp.cpu_count() - 1 + mutex = mp.Lock() + with mp.Pool(processes) as pool: + async_results = [pool.apply_async(traj_dataset.__getitem__, kwds=dict(item=i)) for i in range(n)] + + for result_obj in tqdm(async_results, total=n, desc=f'Generating {n} Samples'): + sample = result_obj.get() + mutex.acquire() + write_to_shelve(file_path, sample) + mutex.release() + print(f'{n} samples sucessfully dumped to "{file_path}"!') + @property def train_dataset(self): - return self._dataset + return self._train_dataset @property def val_dataset(self): - return self._dataset + return self._val_dataset @property def test_dataset(self): - return self._dataset + return self._test_dataset def get_datasets(self): - return self._dataset, self._dataset, self._dataset + return self._train_dataset, self._val_dataset, self._test_dataset diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py index 169aa60..ac55354 100644 --- a/lib/models/generators/cnn.py +++ b/lib/models/generators/cnn.py @@ -1,4 +1,5 @@ -from random import choice +from random import choices, seed +import numpy as np import torch from functools import reduce @@ -36,28 +37,36 @@ class CNNRouteGeneratorModel(LightningBaseModule): # kld_loss /= reduce(mul, self.in_shape) # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100 - loss = (kld_loss + element_wise_loss) / 2 + loss = kld_loss + element_wise_loss return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss)) def _test_val_step(self, batch_xy, batch_nb, *args): batch_x, _ = batch_xy - map_array, trajectory, label = batch_x + map_array = batch_x[:, 0].unsqueeze(1) + trajectory = batch_x[:, 1].unsqueeze(1) + labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values - generated_alternative, z, mu, logvar = self(batch_x) - - return dict(batch_nb=batch_nb, label=label, generated_alternative=generated_alternative, pred_label=-1) + _, mu, _ = self.encode(batch_x) + generated_alternative = self.generate(mu) + return dict(maps=map_array, trajectories=trajectory, batch_nb=batch_nb, labels=labels, + generated_alternative=generated_alternative, pred_label=-1) def _test_val_epoch_end(self, outputs, test=False): - maps, trajectories, labels, val_restul_dict = self.generate_random() + val_restul_dict = self.generate_random() from lib.visualization.generator_eval import GeneratorVisualizer - g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) + g = GeneratorVisualizer(**val_restul_dict) fig = g.draw() self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) plt.clf() return dict(epoch=self.current_epoch) + def on_epoch_start(self): + self.dataset.seed(self.logger.version) + # torch.random.manual_seed(self.logger.version) + # np.random.seed(self.logger.version) + def validation_step(self, *args): return self._test_val_step(*args) @@ -75,14 +84,18 @@ class CNNRouteGeneratorModel(LightningBaseModule): if not issubclassed: # Dataset - self.dataset = TrajData(self.hparams.data_param.map_root, mode='separated_arrays', + self.dataset = TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map', + preprocessed=self.hparams.data_param.use_preprocessed, length=self.hparams.data_param.dataset_length, normalized=True) self.criterion = nn.MSELoss() - # Additional Attributes + # Additional Attributes # + ####################################################### self.in_shape = self.dataset.map_shapes_max - # Todo: Better naming and size in Parameters - self.feature_dim = self.hparams.model_param.lat_dim * 10 + self.use_res_net = self.hparams.model_param.use_res_net + self.lat_dim = self.hparams.model_param.lat_dim + self.feature_dim = self.lat_dim * 10 + ######################################################## # NN Nodes ################################################### @@ -93,82 +106,100 @@ class CNNRouteGeneratorModel(LightningBaseModule): # # Map Encoder - self.map_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1, + self.enc_conv_0 = ConvModule(self.in_shape, conv_kernel=3, conv_stride=1, conv_padding=1, conv_filters=self.hparams.model_param.filters[0], use_norm=self.hparams.model_param.use_norm, use_bias=self.hparams.model_param.use_bias) - self.map_res_1 = ResidualModule(self.map_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1, + self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1, conv_padding=2, conv_filters=self.hparams.model_param.filters[0], use_norm=self.hparams.model_param.use_norm, use_bias=self.hparams.model_param.use_bias) - self.map_conv_1 = ConvModule(self.map_res_1.shape, conv_kernel=5, conv_stride=1, conv_padding=0, - conv_filters=self.hparams.model_param.filters[1], - use_norm=self.hparams.model_param.use_norm, - use_bias=self.hparams.model_param.use_bias) + self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0, + conv_filters=self.hparams.model_param.filters[1], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0, + conv_filters=self.hparams.model_param.filters[1], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) - self.map_res_2 = ResidualModule(self.map_conv_1.shape, ConvModule, 2, conv_kernel=5, conv_stride=1, + self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1, conv_padding=2, conv_filters=self.hparams.model_param.filters[1], use_norm=self.hparams.model_param.use_norm, use_bias=self.hparams.model_param.use_bias) - self.map_conv_2 = ConvModule(self.map_res_2.shape, conv_kernel=7, conv_stride=1, conv_padding=0, - conv_filters=self.hparams.model_param.filters[2], - use_norm=self.hparams.model_param.use_norm, - use_bias=self.hparams.model_param.use_bias) + self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0, + conv_filters=self.hparams.model_param.filters[2], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0, + conv_filters=self.hparams.model_param.filters[2], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) - self.map_res_3 = ResidualModule(self.map_conv_2.shape, ConvModule, 2, conv_kernel=7, conv_stride=1, + self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1, conv_padding=3, conv_filters=self.hparams.model_param.filters[2], use_norm=self.hparams.model_param.use_norm, use_bias=self.hparams.model_param.use_bias) - self.map_conv_3 = ConvModule(self.map_res_3.shape, conv_kernel=11, conv_stride=1, conv_padding=0, - conv_filters=self.hparams.model_param.filters[2], - use_norm=self.hparams.model_param.use_norm, - use_bias=self.hparams.model_param.use_bias) + self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0, + conv_filters=self.hparams.model_param.filters[2], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0, + conv_filters=self.hparams.model_param.filters[2], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) - self.map_flat = Flatten(self.map_conv_3.shape) - self.map_lin = nn.Linear(reduce(mul, self.map_conv_3.shape), self.feature_dim) + self.enc_flat = Flatten(self.enc_conv_3b.shape) + self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim) # # Mixed Encoder - self.mixed_lin = nn.Linear(self.feature_dim, self.feature_dim) - self.mixed_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x + self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim) + self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x # # Variational Bottleneck - self.mu = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim) - self.logvar = nn.Linear(self.feature_dim, self.hparams.model_param.lat_dim) + self.mu = nn.Linear(self.feature_dim, self.lat_dim) + self.logvar = nn.Linear(self.feature_dim, self.lat_dim) # # Alternative Generator - self.alt_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim) - # Todo Fix This Hack!!!! - reshape_shape = (1, self.map_conv_3.shape[1], self.map_conv_3.shape[2]) + self.gen_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim) - self.alt_lin_2 = nn.Linear(self.feature_dim, reduce(mul, reshape_shape)) + self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape) - self.reshape_to_map = Flatten(reduce(mul, reshape_shape), reshape_shape) + self.reshape_to_last_conv = Flatten(self.enc_flat.shape, self.enc_conv_3b.shape) - self.alt_deconv_1 = DeConvModule(reshape_shape, self.hparams.model_param.filters[2], - conv_padding=0, conv_kernel=13, conv_stride=1, - use_norm=self.hparams.model_param.use_norm) - self.alt_deconv_2 = DeConvModule(self.alt_deconv_1.shape, self.hparams.model_param.filters[1], - conv_padding=0, conv_kernel=7, conv_stride=1, - use_norm=self.hparams.model_param.use_norm) - self.alt_deconv_3 = DeConvModule(self.alt_deconv_2.shape, self.hparams.model_param.filters[0], - conv_padding=1, conv_kernel=5, conv_stride=1, - use_norm=self.hparams.model_param.use_norm) - self.alt_deconv_out = DeConvModule(self.alt_deconv_3.shape, 1, activation=None, - conv_padding=1, conv_kernel=3, conv_stride=1, + self.gen_deconv_1a = DeConvModule(self.enc_conv_3b.shape, self.hparams.model_param.filters[2], + conv_padding=0, conv_kernel=11, conv_stride=1, + use_norm=self.hparams.model_param.use_norm) + self.gen_deconv_1b = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[2], + conv_padding=0, conv_kernel=9, conv_stride=2, + use_norm=self.hparams.model_param.use_norm) + + self.gen_deconv_2a = DeConvModule(self.gen_deconv_1b.shape, self.hparams.model_param.filters[1], + conv_padding=0, conv_kernel=7, conv_stride=1, + use_norm=self.hparams.model_param.use_norm) + self.gen_deconv_2b = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[1], + conv_padding=0, conv_kernel=7, conv_stride=1, + use_norm=self.hparams.model_param.use_norm) + + self.gen_deconv_3a = DeConvModule(self.gen_deconv_2b.shape, self.hparams.model_param.filters[0], + conv_padding=1, conv_kernel=5, conv_stride=1, + use_norm=self.hparams.model_param.use_norm) + self.gen_deconv_3b = DeConvModule(self.gen_deconv_3a.shape, self.hparams.model_param.filters[0], + conv_padding=1, conv_kernel=4, conv_stride=1, + use_norm=self.hparams.model_param.use_norm) + + self.gen_deconv_out = DeConvModule(self.gen_deconv_3b.shape, 1, activation=None, + conv_padding=0, conv_kernel=3, conv_stride=1, use_norm=self.hparams.model_param.use_norm) def forward(self, batch_x): - # - # Sorting the Input - map_array, trajectory, label = batch_x - # # Encode - z, mu, logvar = self.encode(map_array, trajectory, label) + z, mu, logvar = self.encode(batch_x) # # Generate @@ -181,42 +212,26 @@ class CNNRouteGeneratorModel(LightningBaseModule): eps = torch.randn_like(std) return mu + eps * std - def generate(self, z): - alt_tensor = self.alt_lin_1(z) - alt_tensor = self.activation(alt_tensor) - alt_tensor = self.alt_lin_2(alt_tensor) - alt_tensor = self.activation(alt_tensor) - alt_tensor = self.reshape_to_map(alt_tensor) - alt_tensor = self.alt_deconv_1(alt_tensor) - alt_tensor = self.alt_deconv_2(alt_tensor) - alt_tensor = self.alt_deconv_3(alt_tensor) - alt_tensor = self.alt_deconv_out(alt_tensor) - # alt_tensor = self.activation(alt_tensor) - alt_tensor = self.sigmoid(alt_tensor) - return alt_tensor + def encode(self, batch_x): + combined_tensor = self.enc_conv_0(batch_x) + combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor + combined_tensor = self.enc_conv_1a(combined_tensor) + combined_tensor = self.enc_conv_1b(combined_tensor) + combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor + combined_tensor = self.enc_conv_2a(combined_tensor) + combined_tensor = self.enc_conv_2b(combined_tensor) + combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor + combined_tensor = self.enc_conv_3a(combined_tensor) + combined_tensor = self.enc_conv_3b(combined_tensor) - def encode(self, map_array, trajectory, label): - label_array = torch.cat([torch.full((1, 1, self.in_shape[1], self.in_shape[2]), x.item()) - for x in label], dim=0) - label_array = self._move_to_model_device(label_array) - combined_tensor = torch.cat((map_array, trajectory, label_array), dim=1) - combined_tensor = self.map_conv_0(combined_tensor) - combined_tensor = self.map_res_1(combined_tensor) - combined_tensor = self.map_conv_1(combined_tensor) - combined_tensor = self.map_res_2(combined_tensor) - combined_tensor = self.map_conv_2(combined_tensor) - combined_tensor = self.map_res_3(combined_tensor) - combined_tensor = self.map_conv_3(combined_tensor) + combined_tensor = self.enc_flat(combined_tensor) + combined_tensor = self.enc_lin_1(combined_tensor) + combined_tensor = self.enc_lin_2(combined_tensor) - combined_tensor = self.map_flat(combined_tensor) - combined_tensor = self.map_lin(combined_tensor) - - combined_tensor = self.mixed_lin(combined_tensor) - - combined_tensor = self.mixed_norm(combined_tensor) + combined_tensor = self.enc_norm(combined_tensor) combined_tensor = self.activation(combined_tensor) - combined_tensor = self.mixed_lin(combined_tensor) - combined_tensor = self.mixed_norm(combined_tensor) + combined_tensor = self.enc_lin_2(combined_tensor) + combined_tensor = self.enc_norm(combined_tensor) combined_tensor = self.activation(combined_tensor) # @@ -226,19 +241,31 @@ class CNNRouteGeneratorModel(LightningBaseModule): z = self.reparameterize(mu, logvar) return z, mu, logvar - def generate_random(self, n=6): - maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)] + def generate(self, z): + alt_tensor = self.gen_lin_1(z) + alt_tensor = self.activation(alt_tensor) + alt_tensor = self.gen_lin_2(alt_tensor) + alt_tensor = self.activation(alt_tensor) + alt_tensor = self.reshape_to_last_conv(alt_tensor) + alt_tensor = self.gen_deconv_1a(alt_tensor) + alt_tensor = self.gen_deconv_1b(alt_tensor) + alt_tensor = self.gen_deconv_2a(alt_tensor) + alt_tensor = self.gen_deconv_2b(alt_tensor) + alt_tensor = self.gen_deconv_3a(alt_tensor) + alt_tensor = self.gen_deconv_3b(alt_tensor) + alt_tensor = self.gen_deconv_out(alt_tensor) + # alt_tensor = self.activation(alt_tensor) + alt_tensor = self.sigmoid(alt_tensor) + return alt_tensor - trajectories = [x.get_random_trajectory() for x in maps] - trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories] - trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2 - trajectories = self._move_to_model_device(torch.stack(trajectories)) + def generate_random(self, n=12): - maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2 - maps = self._move_to_model_device(torch.stack(maps)) + samples, alternatives = zip(*[self.dataset.test_dataset[choice] + for choice in choices(range(self.dataset.length), k=n)]) + samples = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in samples])) + alternatives = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in alternatives])) - labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n)) - return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999) + return self._test_val_step((samples, alternatives), -9999) class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel): @@ -329,11 +356,12 @@ class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel): self._disc = disc_model def __init__(self, *params): + raise NotImplementedError super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True) self._disc = None self.criterion = nn.BCELoss() - self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', + self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True, length=self.hparams.data_param.dataset_length, normalized=True) diff --git a/lib/models/generators/recurrent.py b/lib/models/generators/recurrent.py index e69de29..08ae85d 100644 --- a/lib/models/generators/recurrent.py +++ b/lib/models/generators/recurrent.py @@ -0,0 +1,348 @@ +from random import choice + +import torch +from functools import reduce +from operator import mul + +from torch import nn +from torch.optim import Adam + +from datasets.trajectory_dataset import TrajData +from lib.evaluation.classification import ROCEvaluation +from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule +from lib.modules.utils import LightningBaseModule, Flatten + +import matplotlib.pyplot as plt + + +class CNNRouteGeneratorModel(LightningBaseModule): + + name = 'CNNRouteGenerator' + + def configure_optimizers(self): + return Adam(self.parameters(), lr=self.hparams.train_param.lr) + + def training_step(self, batch_xy, batch_nb, *args, **kwargs): + batch_x, alternative = batch_xy + generated_alternative, z, mu, logvar = self(batch_x) + element_wise_loss = self.criterion(generated_alternative, alternative) + # see Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + + kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + # Dimensional Resizing TODO: Does This make sense? Sanity Check it! + # kld_loss /= reduce(mul, self.in_shape) + # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100 + + loss = (kld_loss + element_wise_loss) / 2 + return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss)) + + def _test_val_step(self, batch_xy, batch_nb, *args): + batch_x, alternative = batch_xy + map_array = batch_x[0] + trajectory = batch_x[1] + label = batch_x[2].max() + + z, _, _ = self.encode(batch_x) + generated_alternative = self.generate(z) + + return dict(map_array=map_array, trajectory=trajectory, batch_nb=batch_nb, label=label, + generated_alternative=generated_alternative, pred_label=-1, alternative=alternative + ) + + def _test_val_epoch_end(self, outputs, test=False): + maps, trajectories, labels, val_restul_dict = self.generate_random() + + from lib.visualization.generator_eval import GeneratorVisualizer + g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) + fig = g.draw() + self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) + plt.clf() + + return dict(epoch=self.current_epoch) + + def validation_step(self, *args): + return self._test_val_step(*args) + + def validation_epoch_end(self, outputs: list): + return self._test_val_epoch_end(outputs) + + def test_step(self, *args): + return self._test_val_step(*args) + + def test_epoch_end(self, outputs): + return self._test_val_epoch_end(outputs, test=True) + + def __init__(self, *params, issubclassed=False): + super(CNNRouteGeneratorModel, self).__init__(*params) + + if not issubclassed: + # Dataset + self.dataset = TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map', + length=self.hparams.data_param.dataset_length, normalized=True) + self.criterion = nn.MSELoss() + + # Additional Attributes # + ####################################################### + self.map_shape = self.dataset.map_shapes_max + self.trajectory_features = 4 + self.res_net = self.hparams.model_param.use_res_net + self.lat_dim = self.hparams.model_param.lat_dim + self.feature_dim = self.lat_dim * 10 + ######################################################## + + # NN Nodes + ################################################### + # + # Utils + self.activation = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + # + # Map Encoder + self.enc_conv_0 = ConvModule(self.map_shape, conv_kernel=3, conv_stride=1, conv_padding=1, + conv_filters=self.hparams.model_param.filters[0], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + + self.enc_res_1 = ResidualModule(self.enc_conv_0.shape, ConvModule, 2, conv_kernel=5, conv_stride=1, + conv_padding=2, conv_filters=self.hparams.model_param.filters[0], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + self.enc_conv_1a = ConvModule(self.enc_res_1.shape, conv_kernel=3, conv_stride=1, conv_padding=0, + conv_filters=self.hparams.model_param.filters[1], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0, + conv_filters=self.hparams.model_param.filters[1], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + + self.enc_res_2 = ResidualModule(self.enc_conv_1b.shape, ConvModule, 2, conv_kernel=5, conv_stride=1, + conv_padding=2, conv_filters=self.hparams.model_param.filters[1], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + self.enc_conv_2a = ConvModule(self.enc_res_2.shape, conv_kernel=5, conv_stride=1, conv_padding=0, + conv_filters=self.hparams.model_param.filters[2], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + self.enc_conv_2b = ConvModule(self.enc_conv_2a.shape, conv_kernel=5, conv_stride=1, conv_padding=0, + conv_filters=self.hparams.model_param.filters[2], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + + self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1, + conv_padding=3, conv_filters=self.hparams.model_param.filters[2], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0, + conv_filters=self.hparams.model_param.filters[2], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0, + conv_filters=self.hparams.model_param.filters[2], + use_norm=self.hparams.model_param.use_norm, + use_bias=self.hparams.model_param.use_bias) + + # Trajectory Encoder + self.env_gru_1 = nn.GRU(input_size=self.trajectory_features, hidden_size=self.feature_dim, + num_layers=3, batch_first=True) + + self.enc_flat = Flatten(self.enc_conv_3b.shape) + self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim) + + # + # Mixed Encoder + self.enc_lin_2 = nn.Linear(self.feature_dim, self.feature_dim) + self.enc_norm = nn.BatchNorm1d(self.feature_dim) if self.hparams.model_param.use_norm else lambda x: x + + # + # Variational Bottleneck + self.mu = nn.Linear(self.feature_dim, self.lat_dim) + self.logvar = nn.Linear(self.feature_dim, self.lat_dim) + + # + # Alternative Generator + self.gen_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim) + + self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape) + + self.gen_gru_x = nn.GRU(None, None, batch_first=True) + + + + def forward(self, batch_x): + # + # Encode + z, mu, logvar = self.encode(batch_x) + + # + # Generate + alt_tensor = self.generate(z) + return alt_tensor, z, mu, logvar + + @staticmethod + def reparameterize(mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def encode(self, batch_x): + combined_tensor = self.enc_conv_0(batch_x) + combined_tensor = self.enc_res_1(combined_tensor) if self.use_res_net else combined_tensor + combined_tensor = self.enc_conv_1a(combined_tensor) + combined_tensor = self.enc_conv_1b(combined_tensor) + combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor + combined_tensor = self.enc_conv_2a(combined_tensor) + combined_tensor = self.enc_conv_2b(combined_tensor) + combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor + combined_tensor = self.enc_conv_3a(combined_tensor) + combined_tensor = self.enc_conv_3b(combined_tensor) + + combined_tensor = self.enc_flat(combined_tensor) + combined_tensor = self.enc_lin_1(combined_tensor) + combined_tensor = self.enc_lin_2(combined_tensor) + + combined_tensor = self.enc_norm(combined_tensor) + combined_tensor = self.activation(combined_tensor) + combined_tensor = self.enc_lin_2(combined_tensor) + combined_tensor = self.enc_norm(combined_tensor) + combined_tensor = self.activation(combined_tensor) + + # + # Parameter and Sampling + mu = self.mu(combined_tensor) + logvar = self.logvar(combined_tensor) + z = self.reparameterize(mu, logvar) + return z, mu, logvar + + def generate(self, z): + alt_tensor = self.gen_lin_1(z) + alt_tensor = self.activation(alt_tensor) + alt_tensor = self.gen_lin_2(alt_tensor) + alt_tensor = self.activation(alt_tensor) + alt_tensor = self.reshape_to_last_conv(alt_tensor) + alt_tensor = self.gen_deconv_1a(alt_tensor) + alt_tensor = self.gen_deconv_1b(alt_tensor) + alt_tensor = self.gen_deconv_2a(alt_tensor) + alt_tensor = self.gen_deconv_2b(alt_tensor) + alt_tensor = self.gen_deconv_3a(alt_tensor) + alt_tensor = self.gen_deconv_3b(alt_tensor) + alt_tensor = self.gen_deconv_out(alt_tensor) + # alt_tensor = self.activation(alt_tensor) + alt_tensor = self.sigmoid(alt_tensor) + return alt_tensor + + def generate_random(self, n=6): + maps = [self.map_storage[choice(self.map_storage.keys_list)] for _ in range(n)] + + trajectories = [x.get_random_trajectory() for x in maps] + trajectories = [x.draw_in_array(self.map_storage.max_map_size) for x in trajectories] + trajectories = [torch.as_tensor(x, dtype=torch.float32) for x in trajectories] * 2 + trajectories = self._move_to_model_device(torch.stack(trajectories)) + + maps = [torch.as_tensor(x.as_array, dtype=torch.float32) for x in maps] * 2 + maps = self._move_to_model_device(torch.stack(maps)) + + labels = self._move_to_model_device(torch.as_tensor([0] * n + [1] * n)) + return maps, trajectories, labels, self._test_val_step(((maps, trajectories, labels), None), -9999) + + +class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel): + + name = 'CNNRouteGeneratorDiscriminated' + + def training_step(self, batch_xy, batch_nb, *args, **kwargs): + batch_x, label = batch_xy + + generated_alternative, z, mu, logvar = self(batch_x) + map_array, trajectory = batch_x + + map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1) + pred_label = self.discriminator(map_stack) + discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1)) + + # see Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + # Dimensional Resizing + kld_loss /= reduce(mul, self.in_shape) + + loss = (kld_loss + discriminated_bce_loss) / 2 + return dict(loss=loss, log=dict(loss=loss, + discriminated_bce_loss=discriminated_bce_loss, + kld_loss=kld_loss) + ) + + def _test_val_step(self, batch_xy, batch_nb, *args): + batch_x, label = batch_xy + + generated_alternative, z, mu, logvar = self(batch_x) + map_array, trajectory = batch_x + + map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1) + pred_label = self.discriminator(map_stack) + + discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1)) + return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb, + pred_label=pred_label, label=label, generated_alternative=generated_alternative) + + def validation_step(self, *args): + return self._test_val_step(*args) + + def validation_epoch_end(self, outputs: list): + return self._test_val_epoch_end(outputs) + + def _test_val_epoch_end(self, outputs, test=False): + evaluation = ROCEvaluation(plot_roc=True) + pred_label = torch.cat([x['pred_label'] for x in outputs]) + labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1) + mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean() + + # Sci-py call ROC eval call is eval(true_label, prediction) + roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), ) + if test: + # self.logger.log_metrics(score_dict) + self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step) + plt.clf() + + maps, trajectories, labels, val_restul_dict = self.generate_random() + + from lib.visualization.generator_eval import GeneratorVisualizer + g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict) + fig = g.draw() + self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step) + plt.clf() + + return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch) + + def test_step(self, *args): + return self._test_val_step(*args) + + def test_epoch_end(self, outputs): + return self._test_val_epoch_end(outputs, test=True) + + @property + def discriminator(self): + if self._disc is None: + raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)') + return self._disc + + def set_discriminator(self, disc_model): + if self._disc is not None: + raise RuntimeError('Discriminator has already been set... What are trying to do?') + self._disc = disc_model + + def __init__(self, *params): + super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True) + + self._disc = None + + self.criterion = nn.BCELoss() + + self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', + length=self.hparams.data_param.dataset_length, normalized=True) diff --git a/lib/models/homotopy_classification/cnn_based.py b/lib/models/homotopy_classification/cnn_based.py index b662fd2..e3078a0 100644 --- a/lib/models/homotopy_classification/cnn_based.py +++ b/lib/models/homotopy_classification/cnn_based.py @@ -60,7 +60,7 @@ class ConvHomDetector(LightningBaseModule): super(ConvHomDetector, self).__init__(hparams) # Dataset - self.dataset = TrajData(self.hparams.data_param.map_root, mode='all_in_map', ) + self.dataset = TrajData(self.hparams.data_param.map_root, mode='classifier_all_in_map', ) # Additional Attributes self.map_shape = self.dataset.map_shapes_max diff --git a/lib/modules/utils.py b/lib/modules/utils.py index 1450b03..53d86dd 100644 --- a/lib/modules/utils.py +++ b/lib/modules/utils.py @@ -22,7 +22,7 @@ class Flatten(nn.Module): try: x = torch.randn(self.in_shape).unsqueeze(0) output = self(x) - return output.shape[1:] + return output.shape[1:] if len(output.shape[1:]) > 1 else output.shape[-1] except Exception as e: print(e) return -1 diff --git a/lib/objects/map.py b/lib/objects/map.py index 754e9ce..46a0423 100644 --- a/lib/objects/map.py +++ b/lib/objects/map.py @@ -1,10 +1,9 @@ -import shelve from collections import UserDict from pathlib import Path import copy from math import sqrt -from random import choice +from random import Random import numpy as np @@ -53,8 +52,12 @@ class Map(object): assert array_like_map_representation.ndim == 3 self.map_array: np.ndarray = array_like_map_representation self.name = name + self.prng = Random() pass + def seed(self, seed): + self.prng.seed(seed) + def __setattr__(self, key, value): super(Map, self).__setattr__(key, value) if key == 'map_array' and self.map_array is not None: @@ -102,7 +105,7 @@ class Map(object): return trajectory def get_valid_position(self): - valid_position = choice(list(self._G.nodes)) + valid_position = self.prng.choice(list(self._G.nodes)) return valid_position def get_trajectory_from_vertices(self, *args): diff --git a/lib/preprocessing/generator.py b/lib/preprocessing/generator.py index 6794ebf..9831a70 100644 --- a/lib/preprocessing/generator.py +++ b/lib/preprocessing/generator.py @@ -20,6 +20,8 @@ class Generator: self.data_root = Path(data_root) + + def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs): datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik' kwargs.update(n=m) diff --git a/lib/utils/tools.py b/lib/utils/tools.py new file mode 100644 index 0000000..6516bd8 --- /dev/null +++ b/lib/utils/tools.py @@ -0,0 +1,22 @@ +import pickle +import shelve +from pathlib import Path + + +def write_to_shelve(file_path, value): + check_path(file_path) + file_path.parent.mkdir(exist_ok=True, parents=True) + with shelve.open(str(file_path), protocol=pickle.HIGHEST_PROTOCOL) as f: + new_key = str(len(f)) + f[new_key] = value + + +def load_from_shelve(file_path, key): + check_path(file_path) + with shelve.open(str(file_path)) as d: + return d[key] + + +def check_path(file_path): + assert isinstance(file_path, Path) + assert str(file_path).endswith('.pik') \ No newline at end of file diff --git a/lib/visualization/generator_eval.py b/lib/visualization/generator_eval.py index b82513f..1241717 100644 --- a/lib/visualization/generator_eval.py +++ b/lib/visualization/generator_eval.py @@ -5,12 +5,13 @@ import lib.variables as V class GeneratorVisualizer(object): - def __init__(self, maps, trajectories, labels, val_result_dict): + def __init__(self, **kwargs): # val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative) - self.alternatives = val_result_dict['generated_alternative'] - self.labels = labels - self.trajectories = trajectories - self.maps = maps + self.alternatives = kwargs.get('generated_alternative') + self.labels = kwargs.get('labels') + self.trajectories = kwargs.get('trajectories') + self.maps = kwargs.get('maps') + self._map_width, self._map_height = self.maps[0].squeeze().shape self.column_dict_list = self._build_column_dict_list() self._cols = len(self.column_dict_list) @@ -24,10 +25,13 @@ class GeneratorVisualizer(object): for idx in range(self.alternatives.shape[0]): image = (self.alternatives[idx]).cpu().numpy().squeeze() label = self.labels[idx].item() + # Dirty and Quick hack incomming. if label == V.HOMOTOPIC: hom_alternatives.append(dict(image=image, label='Homotopic')) + non_hom_alternatives.append(None) else: non_hom_alternatives.append(dict(image=image, label='NonHomotopic')) + hom_alternatives.append(None) for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))): image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze() label = 'original' @@ -48,10 +52,13 @@ class GeneratorVisualizer(object): for idx in range(len(grid.axes_all)): row, col = divmod(idx, len(self.column_dict_list)) - current_image = self.column_dict_list[col][row]['image'] - current_label = self.column_dict_list[col][row]['label'] - grid[idx].imshow(current_image) - grid[idx].title.set_text(current_label) + if self.column_dict_list[col][row] is not None: + current_image = self.column_dict_list[col][row]['image'] + current_label = self.column_dict_list[col][row]['label'] + grid[idx].imshow(current_image) + grid[idx].title.set_text(current_label) + else: + continue fig.cbar_mode = 'single' fig.tight_layout() return fig diff --git a/main.py b/main.py index f5b5a23..84fcbde 100644 --- a/main.py +++ b/main.py @@ -37,6 +37,7 @@ main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, main_arg_parser.add_argument("--data_root", type=str, default='data', help="") main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="") main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="") +main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="") # Transformations @@ -55,9 +56,10 @@ main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerato main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="") main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="") main_arg_parser.add_argument("--model_classes", type=int, default=2, help="") -main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="") +main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="") main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="") main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="") +main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="") main_arg_parser.add_argument("--model_dropout", type=float, default=0.00, help="") # Project @@ -115,7 +117,7 @@ def run_lightning_loop(config_obj): # log_save_interval=(model.n_train_batches * 0.2), # TODO: Better Value / Setting checkpoint_callback=checkpoint_callback, logger=logger, - val_percent_check=0.05, + val_percent_check=0.025, fast_dev_run=config_obj.main.debug, early_stop_callback=None )