diff --git a/.gitignore b/.gitignore
index 04b95f7..473201b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@
 # User-specific stuff
 .idea/**
+res/**
 
 # CMake
 cmake-build-*/
diff --git a/datasets/mnist.py b/datasets/mnist.py
new file mode 100644
index 0000000..66fe06f
--- /dev/null
+++ b/datasets/mnist.py
@@ -0,0 +1,28 @@
+from torchvision.datasets import MNIST
+import numpy as np
+
+
+class MyMNIST(MNIST):
+
+    @property
+    def map_shapes_max(self):
+        return np.asarray(self.test_dataset[0][0]).shape
+
+    def __init__(self, *args, **kwargs):
+        super(MyMNIST, self).__init__('res', train=kwargs.get('train', False), download=True)
+
+    def __getitem__(self, item):
+        image = super(MyMNIST, self).__getitem__(item)
+        return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
+
+    @property
+    def train_dataset(self):
+        return self.__class__('res', train=True, download=True)
+
+    @property
+    def test_dataset(self):
+        return self.__class__('res', train=False, download=True)
+
+    @property
+    def val_dataset(self):
+        return self.__class__('res', train=False, download=True)
diff --git a/datasets/trajectory_dataset.py b/datasets/trajectory_dataset.py
index 2b62a24..f48113c 100644
--- a/datasets/trajectory_dataset.py
+++ b/datasets/trajectory_dataset.py
@@ -1,6 +1,9 @@
 import shelve
+from collections import defaultdict
 from pathlib import Path
-from typing import Union, List
+from typing import Union
+
+from torchvision.transforms import Normalize
 
 import multiprocessing as mp
@@ -24,16 +27,17 @@ class TrajDataShelve(Dataset):
         return self[0][0].shape
 
     def __init__(self, file_path, **kwargs):
+        assert Path(file_path).exists()
         super(TrajDataShelve, self).__init__()
         self._mutex = mp.Lock()
         self.file_path = str(file_path)
 
-
     def __len__(self):
         self._mutex.acquire()
         with shelve.open(self.file_path) as d:
             length = len(d)
-        self._mutex.release()
+            d.close()
+            self._mutex.release()
         return length
 
     def seed(self):
@@ -43,12 +47,20 @@
         self._mutex.acquire()
         with shelve.open(self.file_path) as d:
             sample = d[str(item)]
-        self._mutex.release()
+            d.close()
+            self._mutex.release()
         return sample
 
 
 class TrajDataset(Dataset):
 
+    @property
+    def _last_label_init(self):
+        d = defaultdict(lambda: V.ANY)
+        d['generator_hom_all_in_map'] = V.ALTERNATIVE
+        d['generator_alt_all_in_map'] = V.HOMOTOPIC
+        return d[self.mode]
+
     @property
     def map_shape(self):
         return self.map.as_array.shape
@@ -57,17 +69,18 @@
                  length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
                  **kwargs):
         super(TrajDataset, self).__init__()
-        assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map'
-                                'classifier_all_in_map']
-        self.normalized = normalized
+        assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map', 'generator_alt_all_in_map',
+                                'ae_no_label_in_map',
+                                'generator_alt_no_label_in_map', 'classifier_all_in_map', 'vae_no_label_in_map']
+        self.normalize = Normalize(0.5, 0.5) if normalized else lambda x: x
         self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
         self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
         self.maps_root = maps_root
         self._len = length
-        self.last_label = V.ALTERNATIVE if 'hom' in self.mode else choice([-1, V.ALTERNATIVE, V.HOMOTOPIC])
+        self.last_label = self._last_label_init
 
-        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
+        self.map = Map.from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
 
     def __len__(self):
         return self._len
 
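A note on the `Normalize(0.5, 0.5)` swap above: the transform computes `(x - mean) / std` per channel, so it maps inputs already scaled to `[0, 1]` onto `[-1, 1]` (it replaces the manual `/ V.WHITE` scaling that is removed from `__getitem__` further down). A minimal sketch of the behavior, outside this repo's code:

```python
import torch
from torchvision.transforms import Normalize

norm = Normalize(0.5, 0.5)           # (x - 0.5) / 0.5
x = torch.rand(1, 64, 64)            # a one-channel map scaled to [0, 1]
y = norm(x)
assert y.min() >= -1 and y.max() <= 1
```
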
@@ -82,6 +95,7 @@ class TrajDataset(Dataset):
             map_array = torch.as_tensor(self.map.as_array).float()
             return (map_array, trajectory_space), label
 
+        # Produce an alternative.
         while True:
             trajectory = self.map.get_random_trajectory()
             alternative = self.map.generate_alternative(trajectory)
@@ -91,18 +105,19 @@
             else:
                 break
 
-        self.last_label = label if self.mode != ['generator_hom_all_in_map'] else V.ALTERNATIVE
-        if self.mode.lower() in ['classifier_all_in_map', 'generator_all_in_map']:
+        self.last_label = label if self._last_label_init == V.ANY else self._last_label_init
+        if 'in_map' in self.mode.lower():
             map_array = self.map.as_array
             trajectory = trajectory.draw_in_array(self.map_shape)
             alternative = alternative.draw_in_array(self.map_shape)
             label_as_array = np.full_like(map_array, label)
-            if self.normalized:
-                map_array = map_array / V.WHITE
-                trajectory = trajectory / V.WHITE
-                alternative = alternative / V.WHITE
 
+            if self.mode == 'generator_all_in_map':
                 return np.concatenate((map_array, trajectory, label_as_array)), alternative
+            elif self.mode in ['vae_no_label_in_map', 'ae_no_label_in_map']:
+                return np.sum((map_array, trajectory, alternative), axis=0), 0
+            elif self.mode in ['generator_alt_no_label_in_map', 'generator_hom_no_label_in_map']:
+                return np.concatenate((map_array, trajectory)), alternative
             elif self.mode == 'classifier_all_in_map':
                 return np.concatenate((map_array, trajectory, alternative)), label
@@ -119,13 +134,13 @@
 
 class TrajData(object):
     @property
     def map_shapes(self):
-        return [dataset.map_shape for dataset in self._train_dataset.datasets]
+        return [dataset.map_shape for dataset in self.train_dataset.datasets]
 
     @property
     def map_shapes_max(self):
         shapes = self.map_shapes
         shape_list = list(map(max, zip(*shapes)))
-        if '_all_in_map' in self.mode:
+        if '_all_in_map' in self.mode and not self.preprocessed:
             shape_list[0] += 2
         return shape_list
@@ -139,14 +154,13 @@
         self.mode = mode
         self.maps_root = Path(map_root)
         self.length = length
-        self._test_dataset = self._load_datasets('train')
-        self._val_dataset = self._load_datasets('val')
-        self._train_dataset = self._load_datasets('test')
+        self.test_dataset = self._load_datasets('test')
+        self.val_dataset = self._load_datasets('val')
+        self.train_dataset = self._load_datasets('train')
 
     def _load_datasets(self, dataset_type=''):
         map_files = list(self.maps_root.glob('*.bmp'))
-        equal_split = int(self.length // len(map_files)) or 1
 
         # find max image size among available maps:
         max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
@@ -156,10 +170,11 @@
             preprocessed_map_names = [p.name for p in preprocessed_map_files]
             datasets = []
             for map_file in map_files:
-                new_pik_name = f'{dataset_type}_{str(map_file.name)[:-3]}.pik'
+                equal_split = int(self.length // len(map_files)) or 5
+                new_pik_name = f'{self.mode}_{map_file.name[:-4]}_{dataset_type}.pik'
                 if dataset_type != 'train':
-                    equal_split *= 0.01
-                if not [f'{new_pik_name[:-3]}.bmp' in preprocessed_map_names]:
+                    equal_split = max(int(equal_split * 0.01), 10)
+                if new_pik_name not in preprocessed_map_names:
                     traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                                mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                                preserve_equal_samples=True)
@@ -168,6 +183,9 @@
                 dataset = TrajDataShelve(map_file.parent / new_pik_name)
                 datasets.append(dataset)
             return ConcatDataset(datasets)
+
+        # Set the equal split so that all maps are visited with the same frequency
+        equal_split = int(self.length // len(map_files)) or 5
         return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                           mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                           preserve_equal_samples=True)
@@ -185,29 +203,14 @@ class TrajData(object):
 
     def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
         assert str(file_path).endswith('.pik')
-        processes = mp.cpu_count() - 1
         mutex = mp.Lock()
-        with mp.Pool(processes) as pool:
-            async_results = [pool.apply_async(traj_dataset.__getitem__, kwds=dict(item=i)) for i in range(n)]
+        for i in tqdm(range(n), total=n, desc=f'Generating {n} Samples'):
+            sample = traj_dataset[i]
+            mutex.acquire()
+            write_to_shelve(file_path, sample)
+            mutex.release()
 
-            for result_obj in tqdm(async_results, total=n, desc=f'Generating {n} Samples'):
-                sample = result_obj.get()
-                mutex.acquire()
-                write_to_shelve(file_path, sample)
-                mutex.release()
-        print(f'{n} samples sucessfully dumped to "{file_path}"!')
-
-    @property
-    def train_dataset(self):
-        return self._train_dataset
-
-    @property
-    def val_dataset(self):
-        return self._val_dataset
-
-    @property
-    def test_dataset(self):
-        return self._test_dataset
+        print(f'{n} samples successfully dumped to "{file_path}"!')
 
     def get_datasets(self):
-        return self._train_dataset, self._val_dataset, self._test_dataset
+        return self.train_dataset, self.val_dataset, self.test_dataset
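`dump_n` and `TrajDataShelve` together form a write-once/read-many cache: samples are appended under string keys `"0"`, `"1"`, … by `write_to_shelve` (see `lib/utils/tools.py` further down) and served back by index. An illustrative round trip; the path and counts here are invented for the sketch, and the exact `TrajData` keyword set is assumed from its usage elsewhere in this diff:

```python
from datasets.trajectory_dataset import TrajData, TrajDataShelve

# Build a small dataset and dump it to a shelve file (keys "0".."99").
data = TrajData('res/shapes', mode='ae_no_label_in_map', preprocessed=False,
                normalized=True, length=100)
data.dump_n('res/shapes/ae_no_label_in_map_demo_train.pik',
            data.train_dataset.datasets[0], n=100)

# Reload: TrajDataShelve asserts the file exists and reads d[str(item)].
cached = TrajDataShelve('res/shapes/ae_no_label_in_map_demo_train.pik')
sample, label = cached[0]
```
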
diff --git a/lib/models/generators/cnn.py b/lib/models/generators/cnn.py
index ac55354..076f1a6 100644
--- a/lib/models/generators/cnn.py
+++ b/lib/models/generators/cnn.py
@@ -1,19 +1,22 @@
-from random import choices, seed
-import numpy as np
-
-import torch
 from functools import reduce
 from operator import mul
+from random import choices, choice
+
+import torch
+
 from torch import nn
 from torch.optim import Adam
+from torchvision.datasets import MNIST
+from datasets.mnist import MyMNIST
 from datasets.trajectory_dataset import TrajData
-from lib.evaluation.classification import ROCEvaluation
 from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
 from lib.modules.utils import LightningBaseModule, Flatten
 
 import matplotlib.pyplot as plt
+import lib.variables as V
+from lib.visualization.generator_eval import GeneratorVisualizer
 
 
 class CNNRouteGeneratorModel(LightningBaseModule):
 
@@ -24,48 +27,71 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         return Adam(self.parameters(), lr=self.hparams.train_param.lr)
 
     def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        batch_x, alternative = batch_xy
+        batch_x, target = batch_xy
         generated_alternative, z, mu, logvar = self(batch_x)
-        element_wise_loss = self.criterion(generated_alternative, alternative)
-        # see Appendix B from VAE paper:
-        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
-        # https://arxiv.org/abs/1312.6114
-        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+        target = batch_x if 'ae' in self.hparams.data_param.mode else target
+        element_wise_loss = self.criterion(generated_alternative, target)
 
-        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
-        # Dimensional Resizing TODO: Does This make sense? Sanity Check it!
-        # kld_loss /= reduce(mul, self.in_shape)
-        # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size * 100
+        if 'vae' in self.hparams.data_param.mode:
+            # see Appendix B from VAE paper:
+            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
+            # https://arxiv.org/abs/1312.6114
+            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+            kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+            # Dimensional Resizing TODO: Does this make sense? Sanity-check it!
+            # kld_loss /= reduce(mul, self.in_shape)
+            # kld_loss *= self.hparams.data_param.dataset_length / self.hparams.train_param.batch_size
 
-        loss = kld_loss + element_wise_loss
+            loss = kld_loss + element_wise_loss
+        else:
+            loss = element_wise_loss
+            kld_loss = 0
         return dict(loss=loss, log=dict(element_wise_loss=element_wise_loss, loss=loss, kld_loss=kld_loss))
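For reference, the `kld_loss` line in the hunk above is the closed-form KL divergence between the diagonal-Gaussian posterior and the standard-normal prior from Appendix B of the cited paper, with `logvar` standing in for log σ²:

```latex
D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \sigma^{2}) \,\|\, \mathcal{N}(0, 1)\right)
    = -\frac{1}{2} \sum_{j=1}^{J} \left(1 + \log \sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\right)
```
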
 
     def _test_val_step(self, batch_xy, batch_nb, *args):
         batch_x, _ = batch_xy
-        map_array = batch_x[:, 0].unsqueeze(1)
-        trajectory = batch_x[:, 1].unsqueeze(1)
-        labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
+        if 'vae' in self.hparams.data_param.mode:
+            z, mu, logvar = self.encode(batch_x)
+        else:
+            z = self.encode(batch_x)
+            mu, logvar = z, z
 
-        _, mu, _ = self.encode(batch_x)
         generated_alternative = self.generate(mu)
-        return dict(maps=map_array, trajectories=trajectory, batch_nb=batch_nb, labels=labels,
-                    generated_alternative=generated_alternative, pred_label=-1)
+        return_dict = dict(input=batch_x, batch_nb=batch_nb, output=generated_alternative, z=z, mu=mu, logvar=logvar)
+
+        if 'hom' in self.hparams.data_param.mode:
+            labels = torch.full((batch_x.shape[0], 1), V.HOMOTOPIC)
+        elif 'alt' in self.hparams.data_param.mode:
+            labels = torch.full((batch_x.shape[0], 1), V.ALTERNATIVE)
+        elif 'vae' in self.hparams.data_param.mode:
+            labels = torch.full((batch_x.shape[0], 1), V.ANY)
+        elif 'ae' in self.hparams.data_param.mode:
+            labels = torch.full((batch_x.shape[0], 1), V.ANY)
+        else:
+            labels = batch_x[:, 2].unsqueeze(1).max(dim=-1).values.max(-1).values
+
+        return_dict.update(labels=self._move_to_model_device(labels))
+        return return_dict
 
     def _test_val_epoch_end(self, outputs, test=False):
-        val_restul_dict = self.generate_random()
+        plt.close('all')
 
-        from lib.visualization.generator_eval import GeneratorVisualizer
-        g = GeneratorVisualizer(**val_restul_dict)
-        fig = g.draw()
+        g = GeneratorVisualizer(choice(outputs))
+        fig = g.draw_io_bundle()
         self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
         plt.clf()
 
+        fig = g.draw_latent()
+        self.logger.log_image(f'{self.name}_Latent', fig, step=self.global_step)
+        plt.clf()
+
         return dict(epoch=self.current_epoch)
 
     def on_epoch_start(self):
-        self.dataset.seed(self.logger.version)
+        # self.dataset.seed(self.logger.version)
         # torch.random.manual_seed(self.logger.version)
         # np.random.seed(self.logger.version)
+        pass
 
     def validation_step(self, *args):
         return self._test_val_step(*args)
@@ -82,19 +108,23 @@ class CNNRouteGeneratorModel(LightningBaseModule):
 
     def __init__(self, *params, issubclassed=False):
         super(CNNRouteGeneratorModel, self).__init__(*params)
 
-        if not issubclassed:
+        if False:
             # Dataset
-            self.dataset = TrajData(self.hparams.data_param.map_root, mode='generator_all_in_map',
+            self.dataset = TrajData(self.hparams.data_param.map_root,
+                                    mode=self.hparams.data_param.mode,
+                                    preprocessed=self.hparams.data_param.use_preprocessed,
                                     length=self.hparams.data_param.dataset_length, normalized=True)
-            self.criterion = nn.MSELoss()
+        self.criterion = nn.MSELoss()
+
+        self.dataset = MyMNIST()
 
         # Additional Attributes
         # #######################################################
         self.in_shape = self.dataset.map_shapes_max
         self.use_res_net = self.hparams.model_param.use_res_net
         self.lat_dim = self.hparams.model_param.lat_dim
-        self.feature_dim = self.lat_dim * 10
+        self.feature_dim = self.lat_dim
+        self.out_channels = 1 if 'generator' in self.hparams.data_param.mode else self.in_shape[0]
         ########################################################
 
         # NN Nodes
@@ -119,7 +149,7 @@ class CNNRouteGeneratorModel(LightningBaseModule):
                                       conv_filters=self.hparams.model_param.filters[1],
                                       use_norm=self.hparams.model_param.use_norm,
                                       use_bias=self.hparams.model_param.use_bias)
-        self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=2, conv_padding=0,
+        self.enc_conv_1b = ConvModule(self.enc_conv_1a.shape, conv_kernel=3, conv_stride=1, conv_padding=0,
                                       conv_filters=self.hparams.model_param.filters[1],
                                       use_norm=self.hparams.model_param.use_norm,
                                       use_bias=self.hparams.model_param.use_bias)
@@ -137,20 +167,8 @@ class CNNRouteGeneratorModel(LightningBaseModule):
                                       use_norm=self.hparams.model_param.use_norm,
                                       use_bias=self.hparams.model_param.use_bias)
 
-        self.enc_res_3 = ResidualModule(self.enc_conv_2b.shape, ConvModule, 2, conv_kernel=7, conv_stride=1,
-                                        conv_padding=3, conv_filters=self.hparams.model_param.filters[2],
-                                        use_norm=self.hparams.model_param.use_norm,
-                                        use_bias=self.hparams.model_param.use_bias)
-        self.enc_conv_3a = ConvModule(self.enc_res_3.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[2],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-        self.enc_conv_3b = ConvModule(self.enc_conv_3a.shape, conv_kernel=7, conv_stride=1, conv_padding=0,
-                                      conv_filters=self.hparams.model_param.filters[2],
-                                      use_norm=self.hparams.model_param.use_norm,
-                                      use_bias=self.hparams.model_param.use_bias)
-
-        self.enc_flat = Flatten(self.enc_conv_3b.shape)
+        last_conv_shape = self.enc_conv_2b.shape
+        self.enc_flat = Flatten(last_conv_shape)
         self.enc_lin_1 = nn.Linear(self.enc_flat.shape, self.feature_dim)
 
         #
@@ -160,46 +178,43 @@
 
         #
         # Variational Bottleneck
-        self.mu = nn.Linear(self.feature_dim, self.lat_dim)
-        self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
+        if 'vae' in self.hparams.data_param.mode:
+            self.mu = nn.Linear(self.feature_dim, self.lat_dim)
+            self.logvar = nn.Linear(self.feature_dim, self.lat_dim)
+
+        #
+        # Linear Bottleneck
+        else:
+            self.z = nn.Linear(self.feature_dim, self.lat_dim)
 
         #
         # Alternative Generator
-        self.gen_lin_1 = nn.Linear(self.hparams.model_param.lat_dim, self.feature_dim)
+        self.gen_lin_1 = nn.Linear(self.lat_dim, self.enc_flat.shape)
 
-        self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
+        # self.gen_lin_2 = nn.Linear(self.feature_dim, self.enc_flat.shape)
 
-        self.reshape_to_last_conv = Flatten(self.enc_flat.shape, self.enc_conv_3b.shape)
+        self.reshape_to_last_conv = Flatten(self.enc_flat.shape, last_conv_shape)
 
-        self.gen_deconv_1a = DeConvModule(self.enc_conv_3b.shape, self.hparams.model_param.filters[2],
-                                          conv_padding=0, conv_kernel=11, conv_stride=1,
-                                          use_norm=self.hparams.model_param.use_norm)
-        self.gen_deconv_1b = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[2],
-                                          conv_padding=0, conv_kernel=9, conv_stride=2,
+        self.gen_deconv_1a = DeConvModule(last_conv_shape, self.hparams.model_param.filters[2],
+                                          conv_padding=1, conv_kernel=9, conv_stride=1,
                                           use_norm=self.hparams.model_param.use_norm)
 
-        self.gen_deconv_2a = DeConvModule(self.gen_deconv_1b.shape, self.hparams.model_param.filters[1],
-                                          conv_padding=0, conv_kernel=7, conv_stride=1,
-                                          use_norm=self.hparams.model_param.use_norm)
-        self.gen_deconv_2b = DeConvModule(self.gen_deconv_2a.shape, self.hparams.model_param.filters[1],
-                                          conv_padding=0, conv_kernel=7, conv_stride=1,
+        self.gen_deconv_2a = DeConvModule(self.gen_deconv_1a.shape, self.hparams.model_param.filters[1],
+                                          conv_padding=1, conv_kernel=7, conv_stride=1,
                                           use_norm=self.hparams.model_param.use_norm)
 
-        self.gen_deconv_3a = DeConvModule(self.gen_deconv_2b.shape, self.hparams.model_param.filters[0],
-                                          conv_padding=1, conv_kernel=5, conv_stride=1,
-                                          use_norm=self.hparams.model_param.use_norm)
-        self.gen_deconv_3b = DeConvModule(self.gen_deconv_3a.shape, self.hparams.model_param.filters[0],
-                                          conv_padding=1, conv_kernel=4, conv_stride=1,
-                                          use_norm=self.hparams.model_param.use_norm)
-
-        self.gen_deconv_out = DeConvModule(self.gen_deconv_3b.shape, 1, activation=None,
+        self.gen_deconv_out = DeConvModule(self.gen_deconv_2a.shape, self.out_channels, activation=None,
                                            conv_padding=0, conv_kernel=3, conv_stride=1,
                                            use_norm=self.hparams.model_param.use_norm)
 
     def forward(self, batch_x):
         #
         # Encode
-        z, mu, logvar = self.encode(batch_x)
+        if 'vae' in self.hparams.data_param.mode:
+            z, mu, logvar = self.encode(batch_x)
+        else:
+            z = self.encode(batch_x)
+            mu, logvar = z, z
 
         #
         # Generate
@@ -220,148 +235,46 @@ class CNNRouteGeneratorModel(LightningBaseModule):
         combined_tensor = self.enc_res_2(combined_tensor) if self.use_res_net else combined_tensor
         combined_tensor = self.enc_conv_2a(combined_tensor)
         combined_tensor = self.enc_conv_2b(combined_tensor)
-        combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
-        combined_tensor = self.enc_conv_3a(combined_tensor)
-        combined_tensor = self.enc_conv_3b(combined_tensor)
+        # combined_tensor = self.enc_res_3(combined_tensor) if self.use_res_net else combined_tensor
+        # combined_tensor = self.enc_conv_3a(combined_tensor)
+        # combined_tensor = self.enc_conv_3b(combined_tensor)
 
         combined_tensor = self.enc_flat(combined_tensor)
         combined_tensor = self.enc_lin_1(combined_tensor)
-        combined_tensor = self.enc_lin_2(combined_tensor)
-        combined_tensor = self.enc_norm(combined_tensor)
         combined_tensor = self.activation(combined_tensor)
+        combined_tensor = self.enc_lin_2(combined_tensor)
         combined_tensor = self.enc_norm(combined_tensor)
         combined_tensor = self.activation(combined_tensor)
 
         #
+        # Variational
         # Parameter and Sampling
-        mu = self.mu(combined_tensor)
-        logvar = self.logvar(combined_tensor)
-        z = self.reparameterize(mu, logvar)
-        return z, mu, logvar
+        if 'vae' in self.hparams.data_param.mode:
+            mu = self.mu(combined_tensor)
+            logvar = self.logvar(combined_tensor)
+            z = self.reparameterize(mu, logvar)
+            return z, mu, logvar
+        else:
+            #
+            # Linear Bottleneck
+            z = self.z(combined_tensor)
+            return z
 
     def generate(self, z):
         alt_tensor = self.gen_lin_1(z)
         alt_tensor = self.activation(alt_tensor)
-        alt_tensor = self.gen_lin_2(alt_tensor)
-        alt_tensor = self.activation(alt_tensor)
+        # alt_tensor = self.gen_lin_2(alt_tensor)
+        # alt_tensor = self.activation(alt_tensor)
         alt_tensor = self.reshape_to_last_conv(alt_tensor)
 
         alt_tensor = self.gen_deconv_1a(alt_tensor)
-        alt_tensor = self.gen_deconv_1b(alt_tensor)
+
         alt_tensor = self.gen_deconv_2a(alt_tensor)
-        alt_tensor = self.gen_deconv_2b(alt_tensor)
-        alt_tensor = self.gen_deconv_3a(alt_tensor)
-        alt_tensor = self.gen_deconv_3b(alt_tensor)
+
+        # alt_tensor = self.gen_deconv_3a(alt_tensor)
+        # alt_tensor = self.gen_deconv_3b(alt_tensor)
         alt_tensor = self.gen_deconv_out(alt_tensor)
         # alt_tensor = self.activation(alt_tensor)
-        alt_tensor = self.sigmoid(alt_tensor)
+        # alt_tensor = self.sigmoid(alt_tensor)
         return alt_tensor
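`encode()` above ends in `self.reparameterize(mu, logvar)`, which is defined outside this diff (presumably on `LightningBaseModule`). For readers, a sketch of the standard reparameterization trick that the name refers to:

```python
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # Writing z = mu + sigma * eps keeps gradients flowing into mu and logvar,
    # because the randomness is confined to eps ~ N(0, I).
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
```
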
-
-    def generate_random(self, n=12):
-
-        samples, alternatives = zip(*[self.dataset.test_dataset[choice]
-                                      for choice in choices(range(self.dataset.length), k=n)])
-        samples = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in samples]))
-        alternatives = self._move_to_model_device(torch.stack([torch.as_tensor(x) for x in alternatives]))
-
-        return self._test_val_step((samples, alternatives), -9999)
-
-
-class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
-
-    name = 'CNNRouteGeneratorDiscriminated'
-
-    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
-        batch_x, label = batch_xy
-
-        generated_alternative, z, mu, logvar = self(batch_x)
-        map_array, trajectory = batch_x
-
-        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
-        pred_label = self.discriminator(map_stack)
-        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
-
-        # see Appendix B from VAE paper:
-        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
-        # https://arxiv.org/abs/1312.6114
-        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
-        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
-        # Dimensional Resizing
-        kld_loss /= reduce(mul, self.in_shape)
-
-        loss = (kld_loss + discriminated_bce_loss) / 2
-        return dict(loss=loss, log=dict(loss=loss,
-                                        discriminated_bce_loss=discriminated_bce_loss,
-                                        kld_loss=kld_loss)
-                    )
-
-    def _test_val_step(self, batch_xy, batch_nb, *args):
-        batch_x, label = batch_xy
-
-        generated_alternative, z, mu, logvar = self(batch_x)
-        map_array, trajectory = batch_x
-
-        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
-        pred_label = self.discriminator(map_stack)
-
-        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
-        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
-                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)
-
-    def validation_step(self, *args):
-        return self._test_val_step(*args)
-
-    def validation_epoch_end(self, outputs: list):
-        return self._test_val_epoch_end(outputs)
-
-    def _test_val_epoch_end(self, outputs, test=False):
-        evaluation = ROCEvaluation(plot_roc=True)
-        pred_label = torch.cat([x['pred_label'] for x in outputs])
-        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
-        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
-
-        # Sci-py call ROC eval call is eval(true_label, prediction)
-        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
-        if test:
-            # self.logger.log_metrics(score_dict)
-            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
-            plt.clf()
-
-        maps, trajectories, labels, val_restul_dict = self.generate_random()
-
-        from lib.visualization.generator_eval import GeneratorVisualizer
-        g = GeneratorVisualizer(maps, trajectories, labels, val_restul_dict)
-        fig = g.draw()
-        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
-        plt.clf()
-
-        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
-
-    def test_step(self, *args):
-        return self._test_val_step(*args)
-
-    def test_epoch_end(self, outputs):
-        return self._test_val_epoch_end(outputs, test=True)
-
-    @property
-    def discriminator(self):
-        if self._disc is None:
-            raise RuntimeError('Set the Discriminator first; "set_discriminator(disc_model)')
-        return self._disc
-
-    def set_discriminator(self, disc_model):
-        if self._disc is not None:
-            raise RuntimeError('Discriminator has already been set... What are trying to do?')
-        self._disc = disc_model
-
-    def __init__(self, *params):
-        raise NotImplementedError
-        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
-
-        self._disc = None
-
-        self.criterion = nn.BCELoss()
-
-        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
-                                length=self.hparams.data_param.dataset_length, normalized=True)
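With the `if False:` guard in `__init__` above, the model currently trains on `MyMNIST()` instead of `TrajData`. A quick sketch of what that wrapper from `datasets/mnist.py` yields (it downloads into `res/` on first use):

```python
from datasets.mnist import MyMNIST

data = MyMNIST()              # torchvision MNIST under the hood, root='res'
x, y = data[0]
print(x.shape, x.dtype, y)    # (1, 28, 28) float32 plus the integer digit label
print(data.map_shapes_max)    # (1, 28, 28), standing in for TrajData's map shape
```
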
diff --git a/lib/models/generators/cnn_discriminated.py b/lib/models/generators/cnn_discriminated.py
new file mode 100644
index 0000000..9857a01
--- /dev/null
+++ b/lib/models/generators/cnn_discriminated.py
@@ -0,0 +1,116 @@
+from random import choices, seed
+import numpy as np
+
+import torch
+from functools import reduce
+from operator import mul
+
+from torch import nn
+from torch.optim import Adam
+
+from datasets.trajectory_dataset import TrajData
+from lib.evaluation.classification import ROCEvaluation
+from lib.models.generators.cnn import CNNRouteGeneratorModel
+from lib.modules.blocks import ConvModule, ResidualModule, DeConvModule
+from lib.modules.utils import LightningBaseModule, Flatten
+
+import matplotlib.pyplot as plt
+
+
+class CNNRouteGeneratorDiscriminated(CNNRouteGeneratorModel):
+
+    name = 'CNNRouteGeneratorDiscriminated'
+
+    def training_step(self, batch_xy, batch_nb, *args, **kwargs):
+        batch_x, label = batch_xy
+
+        generated_alternative, z, mu, logvar = self(batch_x)
+        map_array, trajectory = batch_x
+
+        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
+        pred_label = self.discriminator(map_stack)
+        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
+
+        # see Appendix B from VAE paper:
+        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
+        # https://arxiv.org/abs/1312.6114
+        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+        # Dimensional Resizing
+        kld_loss /= reduce(mul, self.in_shape)
+
+        loss = (kld_loss + discriminated_bce_loss) / 2
+        return dict(loss=loss, log=dict(loss=loss,
+                                        discriminated_bce_loss=discriminated_bce_loss,
+                                        kld_loss=kld_loss)
+                    )
+
+    def _test_val_step(self, batch_xy, batch_nb, *args):
+        batch_x, label = batch_xy
+
+        generated_alternative, z, mu, logvar = self(batch_x)
+        map_array, trajectory = batch_x
+
+        map_stack = torch.cat((map_array, trajectory, generated_alternative), dim=1)
+        pred_label = self.discriminator(map_stack)
+
+        discriminated_bce_loss = self.criterion(pred_label, label.float().unsqueeze(-1))
+        return dict(discriminated_bce_loss=discriminated_bce_loss, batch_nb=batch_nb,
+                    pred_label=pred_label, label=label, generated_alternative=generated_alternative)
+
+    def validation_step(self, *args):
+        return self._test_val_step(*args)
+
+    def validation_epoch_end(self, outputs: list):
+        return self._test_val_epoch_end(outputs)
+
+    def _test_val_epoch_end(self, outputs, test=False):
+        evaluation = ROCEvaluation(plot_roc=True)
+        pred_label = torch.cat([x['pred_label'] for x in outputs])
+        labels = torch.cat([x['label'] for x in outputs]).unsqueeze(1)
+        mean_losses = torch.stack([x['discriminated_bce_loss'] for x in outputs]).mean()
+
+        # The ROC evaluation is called as evaluation(true_label, prediction)
+        roc_auc, tpr, fpr = evaluation(labels.cpu().numpy(), pred_label.cpu().numpy(), )
+        if test:
+            # self.logger.log_metrics(score_dict)
+            self.logger.log_image(f'{self.name}_ROC-Curve', plt.gcf(), step=self.global_step)
+            plt.clf()
+
+        maps, trajectories, labels, val_result_dict = self.generate_random()
+
+        from lib.visualization.generator_eval import GeneratorVisualizer
+        g = GeneratorVisualizer(maps, trajectories, labels, val_result_dict)
+        fig = g.draw()
+        self.logger.log_image(f'{self.name}_Output', fig, step=self.global_step)
+        plt.clf()
+
+        return dict(mean_losses=mean_losses, roc_auc=roc_auc, epoch=self.current_epoch)
+
+    def test_step(self, *args):
+        return self._test_val_step(*args)
+
+    def test_epoch_end(self, outputs):
+        return self._test_val_epoch_end(outputs, test=True)
+
+    @property
+    def discriminator(self):
+        if self._disc is None:
+            raise RuntimeError('Set the discriminator first: "set_discriminator(disc_model)"')
+        return self._disc
+
+    def set_discriminator(self, disc_model):
+        if self._disc is not None:
+            raise RuntimeError('Discriminator has already been set... What are you trying to do?')
+        self._disc = disc_model
+
+    def __init__(self, *params):
+        raise NotImplementedError
+        super(CNNRouteGeneratorDiscriminated, self).__init__(*params, issubclassed=True)
+
+        self._disc = None
+
+        self.criterion = nn.BCELoss()
+
+        self.dataset = TrajData(self.hparams.data_param.map_root, mode='just_route', preprocessed=True,
+                                length=self.hparams.data_param.dataset_length, normalized=True)
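`ROCEvaluation` is project code not shown in this diff; per the comment above, its call contract is `evaluation(true_label, prediction)`. For orientation, the equivalent scikit-learn primitives on toy arrays (not repo data):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

y_true = np.array([1, 0, 1, 1, 0])
y_score = np.array([0.9, 0.2, 0.7, 0.6, 0.4])
fpr, tpr, _ = roc_curve(y_true, y_score)   # same (true, prediction) argument order
roc_auc = roc_auc_score(y_true, y_score)
```
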
diff --git a/lib/objects/map.py b/lib/objects/map.py
index 46a0423..726d292 100644
--- a/lib/objects/map.py
+++ b/lib/objects/map.py
@@ -189,5 +189,5 @@ class MapStorage(UserDict):
                           )
         for map_file in map_files:
-            current_map = Map().from_image(map_file, embedding_size=self.max_map_size)
+            current_map = Map.from_image(map_file, embedding_size=self.max_map_size)
             self.__setitem__(map_file.name, current_map)
diff --git a/lib/utils/config.py b/lib/utils/config.py
index ee790e7..673bcfb 100644
--- a/lib/utils/config.py
+++ b/lib/utils/config.py
@@ -5,7 +5,9 @@ from collections import defaultdict
 from configparser import ConfigParser
 from pathlib import Path
 
-from lib.models.generators.cnn import CNNRouteGeneratorModel, CNNRouteGeneratorDiscriminated
+from lib.models.generators.cnn import CNNRouteGeneratorModel
+from lib.models.generators.cnn_discriminated import CNNRouteGeneratorDiscriminated
+
 from lib.models.homotopy_classification.cnn_based import ConvHomDetector
 from lib.utils.model_io import ModelParameters
 from lib.utils.transforms import AsArray
diff --git a/lib/utils/logging.py b/lib/utils/logging.py
index deb3d38..8050687 100644
--- a/lib/utils/logging.py
+++ b/lib/utils/logging.py
@@ -37,7 +37,7 @@ class Logger(LightningLoggerBase):
     @property
     def outpath(self):
         # ToDo: Add further path modification such as dataset config etc.
-        return Path(self.config.train.outpath)
+        return Path(self.config.train.outpath) / self.config.data.mode
 
     def __init__(self, config: Config):
         """
diff --git a/lib/utils/tools.py b/lib/utils/tools.py
index 6516bd8..594fff4 100644
--- a/lib/utils/tools.py
+++ b/lib/utils/tools.py
@@ -9,6 +9,7 @@ def write_to_shelve(file_path, value):
     with shelve.open(str(file_path), protocol=pickle.HIGHEST_PROTOCOL) as f:
         new_key = str(len(f))
         f[new_key] = value
+        f.close()
 
 
 def load_from_shelve(file_path, key):
diff --git a/lib/variables.py b/lib/variables.py
index ab97da8..9b27a6c 100644
--- a/lib/variables.py
+++ b/lib/variables.py
@@ -1,9 +1,15 @@
 from pathlib import Path
 
 _ROOT = Path('..')
 
+# Labels for classes
 HOMOTOPIC = 1
 ALTERNATIVE = 0
+ANY = -1
+
+# Colors for img files
 WHITE = 255
 BLACK = 0
 
-DPI = 100
+# Variables for plotting
+PADDING = 0.25
+DPI = 50
diff --git a/lib/visualization/generator_eval.py b/lib/visualization/generator_eval.py
index 1241717..b895816 100644
--- a/lib/visualization/generator_eval.py
+++ b/lib/visualization/generator_eval.py
@@ -1,53 +1,106 @@
+from collections import defaultdict
+
 import matplotlib.pyplot as plt
+import matplotlib.cm as cmaps
 from mpl_toolkits.axisartist.axes_grid import ImageGrid
+from sklearn.cluster import Birch, DBSCAN, KMeans
+from sklearn.decomposition import PCA
+from sklearn.manifold import TSNE
+
 import lib.variables as V
+import numpy as np
 
 
 class GeneratorVisualizer(object):
-    def __init__(self, **kwargs):
-        # val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
-        self.alternatives = kwargs.get('generated_alternative')
-        self.labels = kwargs.get('labels')
-        self.trajectories = kwargs.get('trajectories')
-        self.maps = kwargs.get('maps')
+    def __init__(self, outputs, k=8):
+        d = defaultdict(list)
+        for key in outputs.keys():
+            try:
+                d[key] = outputs[key][:k].cpu().numpy()
+            except AttributeError:
+                d[key] = outputs[key][:k]
+            except TypeError:
+                self.batch_nb = outputs[key]
+        for key in d.keys():
+            self.__setattr__(key, d[key])
 
-        self._map_width, self._map_height = self.maps[0].squeeze().shape
+        # val_results = dict(discriminated_bce_loss, batch_nb, pred_label, label, generated_alternative)
+        self._map_width, self._map_height = self.input.shape[1], self.input.shape[2]
         self.column_dict_list = self._build_column_dict_list()
         self._cols = len(self.column_dict_list)
         self._rows = len(self.column_dict_list[0])
 
+        self.colormap = cmaps.tab20
+
     def _build_column_dict_list(self):
         trajectories = []
-        non_hom_alternatives = []
-        hom_alternatives = []
+        alternatives = []
 
-        for idx in range(self.alternatives.shape[0]):
-            image = (self.alternatives[idx]).cpu().numpy().squeeze()
-            label = self.labels[idx].item()
-            # Dirty and Quick hack incomming.
-            if label == V.HOMOTOPIC:
-                hom_alternatives.append(dict(image=image, label='Homotopic'))
-                non_hom_alternatives.append(None)
-            else:
-                non_hom_alternatives.append(dict(image=image, label='NonHomotopic'))
-                hom_alternatives.append(None)
-        for idx in range(max(len(hom_alternatives), len(non_hom_alternatives))):
-            image = (self.maps[idx] + self.trajectories[idx]).cpu().numpy().squeeze()
+        for idx in range(self.output.shape[0]):
+            image = (self.output[idx]).squeeze()
+            label = 'Homotopic' if self.labels[idx].item() == V.HOMOTOPIC else 'Alternative'
+            alternatives.append(dict(image=image, label=label))
+
+        for idx in range(len(alternatives)):
+            image = (self.input[idx]).squeeze()
             label = 'original'
             trajectories.append(dict(image=image, label=label))
-        return trajectories, hom_alternatives, non_hom_alternatives
+        return trajectories, alternatives
 
-    def draw(self):
-        padding = 0.25
-        additional_size = self._cols * padding + 3 * padding
-        width = (self._map_width * self._cols) / V.DPI + additional_size
-        height = (self._map_height * self._rows) / V.DPI + additional_size
+    @staticmethod
+    def cluster_data(data):
+
+        cluster = Birch()
+
+        labels = cluster.fit_predict(data)
+        return labels
+
+    def draw_latent(self):
+        plt.close('all')
+        clusterer = KMeans(10)
+        try:
+            labels = clusterer.fit_predict(self.logvar)
+        except ValueError:
+            fig = plt.figure()
+            return fig
+        if self.z.shape[-1] > 2:
+            fig, axs = plt.subplots(ncols=2, nrows=1)
+            transformers = [TSNE(2), PCA(2)]
+            for idx, transformer in enumerate(transformers):
+                transformed = transformer.fit_transform(self.z)
+
+                colored = self.colormap(labels)
+                ax = axs[idx]
+                ax.scatter(x=transformed[:, 0], y=transformed[:, 1], c=colored)
+                ax.set_title(transformer.__class__.__name__)
+                ax.set_xlim(np.min(transformed[:, 0]) * 1.1, np.max(transformed[:, 0]) * 1.1)
+                ax.set_ylim(np.min(transformed[:, 1]) * 1.1, np.max(transformed[:, 1]) * 1.1)
+        elif self.z.shape[-1] == 2:
+            fig, axs = plt.subplots()
+
+            # TODO: Build transformation for lat_dim_size >= 3
+            print('All predictions successfully gathered and shaped')
+            axs.set_xlim(np.min(self.z[:, 0]), np.max(self.z[:, 0]))
+            axs.set_ylim(np.min(self.z[:, 1]), np.max(self.z[:, 1]))
+            # ToDo: Insert Normalization
+            colored = self.colormap(labels)
+            plt.scatter(self.z[:, 0], self.z[:, 1], c=colored)
+        else:
+            raise NotImplementedError("Latent dimensions cannot be one-dimensional (yet).")
+
+        return fig
+
+    def draw_io_bundle(self):
+        width, height = self._cols * 5, self._rows * 5
+        additional_size = self._cols * V.PADDING + 3 * V.PADDING
+        # width = (self._map_width * self._cols) / V.DPI + additional_size
+        # height = (self._map_height * self._rows) / V.DPI + additional_size
         fig = plt.figure(figsize=(width, height), dpi=V.DPI)
         grid = ImageGrid(fig, 111,  # similar to subplot(111)
                          nrows_ncols=(self._rows, self._cols),
-                         axes_pad=padding,  # pad between axes in inch.
+                         axes_pad=V.PADDING,  # pad between axes in inch.
                          )
 
         for idx in range(len(grid.axes_all)):
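`draw_latent()` above clusters the latent codes and shows them under two 2-D projections. A standalone sketch of that pipeline on random stand-in data, using the same scikit-learn and matplotlib calls:

```python
import matplotlib.cm as cmaps
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

z = np.random.randn(128, 4)            # a batch of 4-d latent codes
labels = KMeans(10).fit_predict(z)     # pseudo-labels, used only for coloring
fig, axs = plt.subplots(ncols=2)
for ax, transformer in zip(axs, [TSNE(2), PCA(2)]):
    t = transformer.fit_transform(z)
    ax.scatter(t[:, 0], t[:, 1], c=cmaps.tab20(labels))
    ax.set_title(transformer.__class__.__name__)
```
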
diff --git a/main.py b/main.py
index 84fcbde..5b20fe4 100644
--- a/main.py
+++ b/main.py
@@ -33,12 +33,13 @@ main_arg_parser.add_argument("--main_seed", type=int, default=69, help="")
 
 # Data Parameters
 main_arg_parser.add_argument("--data_worker", type=int, default=10, help="")
-main_arg_parser.add_argument("--data_dataset_length", type=int, default=100000, help="")
+main_arg_parser.add_argument("--data_dataset_length", type=int, default=10000, help="")
 main_arg_parser.add_argument("--data_root", type=str, default='data', help="")
 main_arg_parser.add_argument("--data_map_root", type=str, default='res/shapes', help="")
 main_arg_parser.add_argument("--data_normalized", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--data_use_preprocessed", type=strtobool, default=True, help="")
+main_arg_parser.add_argument("--data_mode", type=str, default='ae_no_label_in_map', help="")
 
 # Transformations
 main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, default=False, help="")
@@ -46,7 +47,7 @@ main_arg_parser.add_argument("--transformations_to_tensor", type=strtobool, defa
 
 # Transformations
 main_arg_parser.add_argument("--train_outpath", type=str, default="output", help="")
 main_arg_parser.add_argument("--train_version", type=strtobool, required=False, help="")
-main_arg_parser.add_argument("--train_epochs", type=int, default=20, help="")
+main_arg_parser.add_argument("--train_epochs", type=int, default=200, help="")
 main_arg_parser.add_argument("--train_batch_size", type=int, default=164, help="")
 main_arg_parser.add_argument("--train_lr", type=float, default=0.002, help="")
 main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0, help="")
@@ -54,9 +55,9 @@ main_arg_parser.add_argument("--train_num_sanity_val_steps", type=int, default=0
 
 # Model
 main_arg_parser.add_argument("--model_type", type=str, default="CNNRouteGenerator", help="")
 main_arg_parser.add_argument("--model_activation", type=str, default="elu", help="")
-main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 64]", help="")
+main_arg_parser.add_argument("--model_filters", type=str, default="[16, 32, 32]", help="")
 main_arg_parser.add_argument("--model_classes", type=int, default=2, help="")
-main_arg_parser.add_argument("--model_lat_dim", type=int, default=8, help="")
+main_arg_parser.add_argument("--model_lat_dim", type=int, default=4, help="")
 main_arg_parser.add_argument("--model_use_bias", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_use_norm", type=strtobool, default=True, help="")
 main_arg_parser.add_argument("--model_use_res_net", type=strtobool, default=False, help="")
@@ -101,7 +102,7 @@ def run_lightning_loop(config_obj):
     model.init_weights(torch.nn.init.xavier_normal_)
     if model.name == 'CNNRouteGeneratorDiscriminated':
         # ToDo: Make this dependent on the used seed
-        path = Path(Path(config_obj.train.outpath) / 'classifier_cnn' / 'version_0')
+        path = logger.outpath / 'classifier_cnn' / 'version_0'
         disc_model = SavedLightningModels.load_checkpoint(path).restore()
         model.set_discriminator(disc_model)
 
@@ -111,13 +112,12 @@
                       show_progress_bar=True,
                       weights_save_path=logger.log_dir,
                       gpus=[0] if torch.cuda.is_available() else None,
-                      check_val_every_n_epoch=1,
-                      num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
+                      check_val_every_n_epoch=10,
+                      # num_sanity_val_steps=config_obj.train.num_sanity_val_steps,
                       # row_log_interval=(model.n_train_batches * 0.1),  # TODO: Better Value / Setting
                       # log_save_interval=(model.n_train_batches * 0.2),  # TODO: Better Value / Setting
                       checkpoint_callback=checkpoint_callback,
                       logger=logger,
-                      val_percent_check=0.025,
                       fast_dev_run=config_obj.main.debug,
                       early_stop_callback=None
                       )
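The new `--data_mode` flag supplies the mode string that `TrajDataset.__init__` validates (and that `Logger.outpath` now appends, so each mode logs into its own directory). A hypothetical pre-flight check mirroring that parser; the valid strings are copied from the assert in `datasets/trajectory_dataset.py` above:

```python
import argparse

VALID_MODES = ['generator_all_in_map', 'generator_hom_all_in_map',
               'generator_alt_all_in_map', 'ae_no_label_in_map',
               'generator_alt_no_label_in_map', 'classifier_all_in_map',
               'vae_no_label_in_map']

parser = argparse.ArgumentParser()
parser.add_argument("--data_mode", type=str, default='ae_no_label_in_map', help="")
args = parser.parse_args([])                  # default selects the plain-AE mode
assert args.data_mode in VALID_MODES
```
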
diff --git a/res/shapes/inverted_1.bmp b/res/shapes/inverted_1.bmp
deleted file mode 100644
index 6686a62..0000000
Binary files a/res/shapes/inverted_1.bmp and /dev/null differ
diff --git a/res/shapes/inverted_10.bmp b/res/shapes/inverted_10.bmp
deleted file mode 100644
index 5a9d012..0000000
Binary files a/res/shapes/inverted_10.bmp and /dev/null differ
diff --git a/res/shapes/inverted_2.bmp b/res/shapes/inverted_2.bmp
deleted file mode 100644
index 202b327..0000000
Binary files a/res/shapes/inverted_2.bmp and /dev/null differ
diff --git a/res/shapes/inverted_3.bmp b/res/shapes/inverted_3.bmp
deleted file mode 100644
index a7d945f..0000000
Binary files a/res/shapes/inverted_3.bmp and /dev/null differ
diff --git a/res/shapes/inverted_4.bmp b/res/shapes/inverted_4.bmp
deleted file mode 100644
index 4aad5c7..0000000
Binary files a/res/shapes/inverted_4.bmp and /dev/null differ
diff --git a/res/shapes/inverted_5.bmp b/res/shapes/inverted_5.bmp
deleted file mode 100644
index 4eb661e..0000000
Binary files a/res/shapes/inverted_5.bmp and /dev/null differ
diff --git a/res/shapes/inverted_6.bmp b/res/shapes/inverted_6.bmp
deleted file mode 100644
index cfc01f4..0000000
Binary files a/res/shapes/inverted_6.bmp and /dev/null differ
diff --git a/res/shapes/inverted_7.bmp b/res/shapes/inverted_7.bmp
deleted file mode 100644
index e0ac1ab..0000000
Binary files a/res/shapes/inverted_7.bmp and /dev/null differ
diff --git a/res/shapes/inverted_8.bmp b/res/shapes/inverted_8.bmp
deleted file mode 100644
index 8674237..0000000
Binary files a/res/shapes/inverted_8.bmp and /dev/null differ
diff --git a/res/shapes/inverted_9.bmp b/res/shapes/inverted_9.bmp
deleted file mode 100644
index 551e404..0000000
Binary files a/res/shapes/inverted_9.bmp and /dev/null differ
diff --git a/res/shapes/shapes_1.bmp b/res/shapes/shapes_1.bmp
deleted file mode 100644
index 8905ad5..0000000
Binary files a/res/shapes/shapes_1.bmp and /dev/null differ
diff --git a/res/shapes/shapes_10.bmp b/res/shapes/shapes_10.bmp
deleted file mode 100644
index 49ff557..0000000
Binary files a/res/shapes/shapes_10.bmp and /dev/null differ
diff --git a/res/shapes/shapes_2.bmp b/res/shapes/shapes_2.bmp
deleted file mode 100644
index d01e391..0000000
Binary files a/res/shapes/shapes_2.bmp and /dev/null differ
diff --git a/res/shapes/shapes_3.bmp b/res/shapes/shapes_3.bmp
deleted file mode 100644
index 52c6d7e..0000000
Binary files a/res/shapes/shapes_3.bmp and /dev/null differ
diff --git a/res/shapes/shapes_3.png b/res/shapes/shapes_3.png
deleted file mode 100644
index a93d81e..0000000
Binary files a/res/shapes/shapes_3.png and /dev/null differ
diff --git a/res/shapes/shapes_4.bmp b/res/shapes/shapes_4.bmp
deleted file mode 100644
index 46a848b..0000000
Binary files a/res/shapes/shapes_4.bmp and /dev/null differ
diff --git a/res/shapes/shapes_5.bmp b/res/shapes/shapes_5.bmp
deleted file mode 100644
index 44cd691..0000000
Binary files a/res/shapes/shapes_5.bmp and /dev/null differ
diff --git a/res/shapes/shapes_6.bmp b/res/shapes/shapes_6.bmp
deleted file mode 100644
index 6bc8c11..0000000
Binary files a/res/shapes/shapes_6.bmp and /dev/null differ
diff --git a/res/shapes/shapes_7.bmp b/res/shapes/shapes_7.bmp
deleted file mode 100644
index 4e6bd4d..0000000
Binary files a/res/shapes/shapes_7.bmp and /dev/null differ
diff --git a/res/shapes/shapes_8.bmp b/res/shapes/shapes_8.bmp
deleted file mode 100644
index e04ea55..0000000
Binary files a/res/shapes/shapes_8.bmp and /dev/null differ
diff --git a/res/shapes/shapes_9.bmp b/res/shapes/shapes_9.bmp
deleted file mode 100644
index a890335..0000000
Binary files a/res/shapes/shapes_9.bmp and /dev/null differ