import shelve
from pathlib import Path
from typing import Union, List

import torch
from torch.utils.data import ConcatDataset, Dataset
from PIL import Image

from lib.objects.map import Map


class TrajDataset(Dataset):
    """Dataset of trajectory pairs generated on the fly for a single map.

    Each item is built on demand: a random trajectory is requested from the
    map, an alternative is generated for it, and the pair is labelled with the
    result of ``Map.are_homotopic``. Nothing is cached, so the ``item`` index
    passed to ``__getitem__`` does not select a specific sample.
    """

    @property
    def map_shape(self):
        """Shape of the map's array representation (``self.map.as_array.shape``)."""
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', length=100000,
                 all_in_map=True, embedding_size=None, preserve_equal_samples=False, **kwargs):
        """Load the map image and configure sample generation.

        Args:
            *args: Ignored.
            maps_root: Directory that contains the map image; ``str`` or ``Path``.
            mapname: Map file name; the ``.bmp`` suffix is appended when missing.
            length: Value reported by ``__len__``.
            all_in_map: If true, items are a stacked tensor plus an int label;
                otherwise the raw vertices, label and map name are returned.
            embedding_size: Forwarded unchanged to ``Map.from_image``.
            preserve_equal_samples: If true, consecutive items never carry the
                same label (samples are redrawn until the label changes).
            **kwargs: Ignored.
        """
        super(TrajDataset, self).__init__()
        self.preserve_equal_samples = preserve_equal_samples
        self.all_in_map = all_in_map
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Always store a Path: the signature accepts plain strings (and defaults
        # to ''), but the "/" join below is only defined for Path objects.
        self.maps_root = Path(maps_root)
        self._len = length
        # -1 never equals a real label, so the very first sample is always accepted.
        self.last_label = -1
        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)

    def __len__(self):
        """Return the configured (virtual) number of samples."""
        return self._len

    def __getitem__(self, item):
        """Generate one labelled trajectory pair.

        Args:
            item: Requested index; unused because samples are drawn randomly.

        Returns:
            If ``all_in_map`` is true: ``(tensor, label)`` where ``tensor`` is the
            concatenation of the map array, the trajectory mask and the
            alternative mask, and ``label`` is an ``int``.
            Otherwise: ``(trajectory.vertices, alternative.vertices, label, mapname)``.
        """
        trajectory = self.map.get_random_trajectory()
        while True:
            # TODO: Sanity Check this while true loop...
            # NOTE(review): there is no retry limit. With preserve_equal_samples
            # enabled this spins forever if generate_alternative never yields a
            # label different from the previous one.
            alternative = self.map.generate_alternative(trajectory)
            label = self.map.are_homotopic(trajectory, alternative)
            if not (self.preserve_equal_samples and label == self.last_label):
                break
        self.last_label = label

        if not self.all_in_map:
            return trajectory.vertices, alternative.vertices, label, self.mapname

        blank_trajectory_space = torch.zeros(self.map.shape)
        blank_alternative_space = torch.zeros(self.map.shape)
        for index in trajectory.vertices:
            blank_trajectory_space[index] = 1
        # The alternative mask must be drawn from the alternative's own vertices;
        # filling it from the trajectory made both masks identical.
        for index in alternative.vertices:
            blank_alternative_space[index] = 1

        map_array = torch.as_tensor(self.map.as_array).float()
        return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)


class TrajData(object):
    """Bundle one ``TrajDataset`` per ``*.bmp`` map file into a single ``ConcatDataset``."""

    @property
    def map_shapes(self):
        """List of ``map_shape`` values, one per wrapped dataset."""
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        """Element-wise maximum over all map shapes, as a list.

        When ``all_in_map`` is set the first entry is increased by 2, matching
        the two mask tensors that ``TrajDataset.__getitem__`` concatenates onto
        the map array.
        """
        shapes = self.map_shapes
        shape_list = list(map(max, zip(*shapes)))
        if self.all_in_map:
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        """Name of this class."""
        return self.__class__.__name__

    # NOTE(review): the default ``length=100.000`` is the float 100.0, not
    # 100000 (TrajDataset defaults to 100000). It looks like a typo but is kept
    # unchanged so existing callers see the same dataset size -- confirm intent.
    def __init__(self, *args, map_root: Union[Path, str] = '', length=100.000, all_in_map=True, **_):
        """Discover the map files and build the combined dataset.

        Args:
            *args: Ignored.
            map_root: Directory containing the ``*.bmp`` maps; falls back to
                ``res/maps`` relative to the working directory when empty.
            length: Total number of samples, split evenly across all maps.
            all_in_map: Forwarded to every ``TrajDataset``.
            **_: Ignored.

        Raises:
            FileNotFoundError: If the map directory contains no ``*.bmp`` file.
        """
        self.all_in_map = all_in_map
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        self.length = length
        self._dataset = self._load_datasets()

    def _load_datasets(self):
        """Create one ``TrajDataset`` per map file and concatenate them.

        Returns:
            A ``ConcatDataset`` in which every map contributes
            ``length // number_of_maps`` samples.

        Raises:
            FileNotFoundError: If ``self.maps_root`` contains no ``*.bmp`` file.
        """
        map_files = list(self.maps_root.glob('*.bmp'))
        if not map_files:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise FileNotFoundError(f'No *.bmp map files found in "{self.maps_root}".')
        equal_split = int(self.length // len(map_files))

        # find max image size among available maps:
        image_sizes = []
        for map_file in map_files:
            # Close every image again; only its size is needed here.
            with Image.open(map_file) as image:
                image_sizes.append(image.size)
        # zip(*sizes) groups all widths and all heights, which also works when
        # there is only a single map file.
        max_width, max_height = map(max, zip(*image_sizes))
        # PIL reports (width, height); the embedding size is (1, height, width).
        max_map_size = (1, max_height, max_width)

        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                          all_in_map=self.all_in_map, embedding_size=max_map_size,
                                          preserve_equal_samples=True)
                              for map_file in map_files])

    # NOTE: train, validation and test currently all expose the same dataset object.
    @property
    def train_dataset(self):
        """Dataset used for training."""
        return self._dataset

    @property
    def val_dataset(self):
        """Dataset used for validation."""
        return self._dataset

    @property
    def test_dataset(self):
        """Dataset used for testing."""
        return self._dataset

    def get_datasets(self):
        """Return the ``(train, val, test)`` datasets as a tuple."""
        return self._dataset, self._dataset, self._dataset