import multiprocessing as mp
import shelve
from collections import defaultdict
from pathlib import Path
from random import choice
from typing import Union

import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset, Dataset
from torchvision.transforms import Normalize
from tqdm import tqdm

import lib.variables as V
from lib.objects.map import Map
from lib.utils.tools import write_to_shelve


class TrajDataShelve(Dataset):
    """Read-only dataset view over a shelve file of pre-generated samples."""

    @property
    def map_shape(self):
        return self[0][0].shape

    def __init__(self, file_path, **kwargs):
        assert Path(file_path).exists()
        super(TrajDataShelve, self).__init__()
        self._mutex = mp.Lock()
        self.file_path = str(file_path)

    def __len__(self):
        with self._mutex:
            with shelve.open(self.file_path) as d:
                length = len(d)
        return length

    def seed(self, seed=None):
        # Samples are pre-generated, so there is nothing left to seed.
        pass

    def __getitem__(self, item):
        with self._mutex:
            with shelve.open(self.file_path) as d:
                sample = d[str(item)]
        return sample


class TrajDataset(Dataset):

    @property
    def _last_label_init(self):
        # Default -1 is compared against V.ANY in __getitem__; only the two
        # single-class generator modes pin a fixed label.
        d = defaultdict(lambda: -1)
        d['generator_hom_all_in_map'] = V.ALTERNATIVE
        d['generator_alt_all_in_map'] = V.HOMOTOPIC
        return d[self.mode]

    @property
    def map_shape(self):
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
                 normalized=True, length=100000, mode='separated_arrays',
                 embedding_size=None, preserve_equal_samples=False, **kwargs):
        super(TrajDataset, self).__init__()
        # Note: 'mode' has to be passed explicitly -- the default value
        # 'separated_arrays' is not in the accepted set below.
        assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map',
                                'generator_alt_all_in_map', 'ae_no_label_in_map',
                                'generator_alt_no_label_in_map', 'classifier_all_in_map',
                                'vae_no_label_in_map']
        self.normalize = Normalize(0.5, 0.5) if normalized else lambda x: x
        self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        self.maps_root = Path(maps_root)  # accept both str and Path
        self._len = length
        self.last_label = self._last_label_init
        self.map = Map.from_image(self.maps_root / self.mapname, embedding_size=embedding_size)

    def __len__(self):
        return self._len

    def __getitem__(self, item):
        # 'item' is ignored: every access draws a fresh random trajectory pair.
        if self.mode.lower() == 'just_route':
            raise NotImplementedError
            # Unreachable sketch of the intended behaviour:
            trajectory = self.map.get_random_trajectory()
            trajectory_space = trajectory.draw_in_array(self.map.shape)
            label = choice([0, 1])
            map_array = torch.as_tensor(self.map.as_array).float()
            return (map_array, trajectory_space), label

        # Produce an alternative.
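        # Rejection sampling: with preserve_equal_samples set, pairs are
        # re-drawn until the homotopy label differs from self.last_label.
        # For the single-class generator modes, last_label stays pinned to one
        # class, which effectively filters for the other; in all other modes it
        # tracks the previous sample, so the emitted labels roughly alternate.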
        while True:
            trajectory = self.map.get_random_trajectory()
            alternative = self.map.generate_alternative(trajectory)
            label = self.map.are_homotopic(trajectory, alternative)
            if self.preserve_equal_samples and label == self.last_label:
                continue
            break

        # The _last_label_init property already resolves by mode, so no
        # further indexing is needed here.
        self.last_label = label if self._last_label_init == V.ANY else self._last_label_init

        if 'in_map' in self.mode.lower():
            map_array = self.map.as_array
            trajectory = trajectory.draw_in_array(self.map_shape)
            alternative = alternative.draw_in_array(self.map_shape)
            label_as_array = np.full_like(map_array, label)
            if self.mode == 'generator_all_in_map':
                return np.concatenate((map_array, trajectory, label_as_array)), alternative
            elif self.mode in ['vae_no_label_in_map', 'ae_no_label_in_map']:
                return np.sum((map_array, trajectory, alternative), axis=0), 0
            elif self.mode in ['generator_alt_no_label_in_map', 'generator_hom_no_label_in_map']:
                return np.concatenate((map_array, trajectory)), alternative
            elif self.mode == 'classifier_all_in_map':
                return np.concatenate((map_array, trajectory, alternative)), label

        elif self.mode == '_vectors':
            raise NotImplementedError
            # Unreachable sketch of the intended behaviour:
            return trajectory.vertices, alternative.vertices, label, self.mapname

        raise ValueError(f'Mode was: {self.mode}')

    def seed(self, seed):
        self.map.seed(seed)


class TrajData(object):

    @property
    def map_shapes(self):
        return [dataset.map_shape for dataset in self.train_dataset.datasets]

    @property
    def map_shapes_max(self):
        shapes = self.map_shapes
        shape_list = list(map(max, zip(*shapes)))
        # The '_all_in_map' modes stack two extra planes (trajectory and
        # label/alternative) on top of the map, so widen the channel dim.
        if '_all_in_map' in self.mode and not self.preprocessed:
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, map_root, length=100000, mode='separated_arrays',
                 normalized=True, preprocessed=False, **_):
        self.preprocessed = preprocessed
        self.normalized = normalized
        self.mode = mode
        self.maps_root = Path(map_root)
        self.length = length
        self.test_dataset = self._load_datasets('test')
        self.val_dataset = self._load_datasets('val')
        self.train_dataset = self._load_datasets('train')

    def _load_datasets(self, dataset_type=''):
        map_files = list(self.maps_root.glob('*.bmp'))
        # Find the max image size among the available maps; PIL reports
        # (width, height), which is reversed to (height, width) and prefixed
        # with a channel dimension of 1.
        max_map_size = (1,) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size
                                                              for map_file in map_files]))))
        if self.preprocessed:
            preprocessed_map_files = list(self.maps_root.glob('*.pik'))
            preprocessed_map_names = [p.name for p in preprocessed_map_files]
            datasets = []
            for map_file in map_files:
                # Equal split so that all maps are visited with the same frequency.
                equal_split = int(self.length // len(map_files)) or 5
                new_pik_name = f'{self.mode}_{map_file.name[:-4]}_{dataset_type}.pik'
                if dataset_type != 'train':
                    # Val/test splits are kept at 1% of the train split (min. 10).
                    equal_split = max(int(equal_split * 0.01), 10)
                if new_pik_name not in preprocessed_map_names:
                    traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name,
                                               length=equal_split, mode=self.mode,
                                               embedding_size=max_map_size,
                                               normalized=self.normalized,
                                               preserve_equal_samples=True)
                    self.dump_n(map_file.parent / new_pik_name, traj_dataset, n=equal_split)
                datasets.append(TrajDataShelve(map_file.parent / new_pik_name))
            return ConcatDataset(datasets)

        # Equal split so that all maps are visited with the same frequency.
        equal_split = int(self.length // len(map_files)) or 5
        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name,
                                          length=equal_split, mode=self.mode,
                                          embedding_size=max_map_size,
                                          normalized=self.normalized,
                                          preserve_equal_samples=True)
                              for map_file in map_files])

    def kill_em_all(self):
        # Delete all preprocessed shelve files so they get regenerated.
        for pik_file in self.maps_root.glob('*.pik'):
            pik_file.unlink()
            print(f'{pik_file.name} was deleted.')
        print('Done.')

    def seed(self, seed):
        for concat_dataset in [self.train_dataset, self.val_dataset, self.test_dataset]:
            for dataset in concat_dataset.datasets:
                dataset.seed(seed)

    def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
        assert str(file_path).endswith('.pik')
        mutex = mp.Lock()
        for i in tqdm(range(n), total=n, desc=f'Generating {n} Samples'):
            sample = traj_dataset[i]
            with mutex:  # guard the shelve write
                write_to_shelve(file_path, sample)
        print(f'{n} samples successfully dumped to "{file_path}"!')

    def get_datasets(self):
        return self.train_dataset, self.val_dataset, self.test_dataset
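

# Minimal usage sketch, not part of the original module: it assumes a 'maps'
# directory containing at least one *.bmp map file, run from the repo root so
# that the lib package is importable; path, length, mode, and batch size are
# illustrative only.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    traj_data = TrajData('maps', length=1000, mode='classifier_all_in_map')
    train_set, val_set, test_set = traj_data.get_datasets()
    loader = DataLoader(train_set, batch_size=32, shuffle=True)

    # Each batch is (map + trajectory + alternative planes, homotopy label).
    batch_x, batch_y = next(iter(loader))
    print(batch_x.shape, batch_y.shape)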