import shelve
from pathlib import Path
from random import choice
from typing import Union, List

import torch
from PIL import Image
from torch.utils.data import ConcatDataset, Dataset

from lib.objects.map import Map


class TrajDataset(Dataset):
    """Randomly generated trajectory samples drawn from a single map image.

    Samples are generated on access; the index passed to ``__getitem__`` is
    ignored and ``length`` only fixes the value reported by ``__len__``.

    The layout of one sample depends on ``mode`` (matched case-insensitively):

    * ``'just_route'``: ``((map_array, trajectory_raster), label)`` where
      ``label`` is picked at random from ``{0, 1}``.
    * ``'separated_arrays'``:
      ``((map_array, trajectory_raster, int(label)), alternative_raster)``.
    * ``'all_in_map'``: ``(torch.cat((map_array, trajectory_raster,
      alternative_raster)), int(label))``.
    * ``'vectors'``: ``(trajectory.vertices, alternative.vertices, label,
      mapname)``.

    For every mode except ``'just_route'`` the label is whatever
    ``Map.are_homotopic(trajectory, alternative)`` returns.
    """

    _MODES = ('vectors', 'all_in_map', 'separated_arrays', 'just_route')

    @property
    def map_shape(self):
        """Shape of the array representation of the underlying map."""
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', length=100000,
                 mode='separated_arrays', embedding_size=None, preserve_equal_samples=False, **kwargs):
        """Load the map image and store the sampling configuration.

        Args:
            *args: Accepted for call compatibility and ignored.
            maps_root: Directory that contains the map image.
            mapname: Map file name; ``.bmp`` is appended when missing.
            length: Value reported by ``__len__``.
            mode: One of ``'vectors'``, ``'all_in_map'``,
                ``'separated_arrays'`` or ``'just_route'`` (any letter case).
            embedding_size: Forwarded unchanged to ``Map.from_image``.
            preserve_equal_samples: When true, a sample is re-drawn until its
                label differs from the label of the previously returned one.
            **kwargs: Accepted for call compatibility and ignored.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        super().__init__()
        # Normalise once so that every later comparison is case-insensitive.
        mode = mode.lower()
        if mode not in self._MODES:
            raise ValueError(f'Unknown mode "{mode}", expected one of {self._MODES}')
        self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Coerce to Path: the "/" join below is not defined for plain strings.
        self.maps_root = Path(maps_root)
        self._len = length
        # Sentinel that differs from any label produced for the first sample.
        self.last_label = -1
        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)

    def __len__(self):
        """Return the configured (nominal) number of samples."""
        return self._len

    def _rasterize(self, trajectory):
        """Return a zero tensor of the map's shape with the trajectory's vertices set to 1."""
        canvas = torch.zeros(self.map.shape)
        for index in trajectory.vertices:
            canvas[index] = 1
        return canvas

    def __getitem__(self, item):
        """Generate one random sample; ``item`` is not used.

        Returns:
            A tuple whose layout depends on ``self.mode`` (see the class
            docstring).

        Raises:
            ValueError: If ``self.mode`` was changed to an unsupported value
                after construction.
        """
        if self.mode == 'just_route':
            trajectory = self.map.get_random_trajectory()
            label = choice([0, 1])
            trajectory_space = self._rasterize(trajectory)
            map_array = torch.as_tensor(self.map.as_array).float()
            return (map_array, trajectory_space), label

        # TODO: Sanity check this loop - with preserve_equal_samples it only
        #  terminates once the map yields a label different from the last one.
        while True:
            trajectory = self.map.get_random_trajectory()
            alternative = self.map.generate_alternative(trajectory)
            label = self.map.are_homotopic(trajectory, alternative)
            if not (self.preserve_equal_samples and label == self.last_label):
                break
        self.last_label = label

        if self.mode in ('all_in_map', 'separated_arrays'):
            trajectory_space = self._rasterize(trajectory)
            alternative_space = self._rasterize(alternative)
            map_array = torch.as_tensor(self.map.as_array).float()
            if self.mode == 'separated_arrays':
                return (map_array, trajectory_space, int(label)), alternative_space
            # 'all_in_map': stack map, trajectory and alternative along dim 0.
            return torch.cat((map_array, trajectory_space, alternative_space)), int(label)
        if self.mode == 'vectors':
            return trajectory.vertices, alternative.vertices, label, self.mapname
        raise ValueError(f'Unknown mode "{self.mode}"')


class TrajData(object):
    """Bundle of one ``TrajDataset`` per ``*.bmp`` map found in a directory.

    The same concatenated dataset is exposed as train, validation and test
    split.
    """

    @property
    def map_shapes(self):
        """List of ``map_shape`` of every wrapped per-map dataset."""
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        """Per-dimension maximum over all map shapes, as a list.

        In ``'all_in_map'`` mode the first dimension is increased by 2,
        because samples concatenate two additional arrays onto the map along
        dim 0.
        """
        shapes = self.map_shapes
        shape_list = list(map(max, zip(*shapes)))
        if self.mode == 'all_in_map':
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        """Name of this class."""
        return self.__class__.__name__

    def __init__(self, map_root, length=100000, mode='separated_arrays', **_):
        """Collect all maps below ``map_root`` and build the datasets.

        Args:
            map_root: Directory that is searched for ``*.bmp`` files.
            length: Total nominal length, split evenly across the maps.
            mode: Sample layout, forwarded to every ``TrajDataset``
                (matched case-insensitively).
            **_: Additional keyword arguments are accepted and ignored.
        """
        # Normalised so the comparison in map_shapes_max is case-insensitive,
        # in line with how TrajDataset treats its mode.
        self.mode = mode.lower()
        self.maps_root = Path(map_root)
        self.length = length
        self._dataset = self._load_datasets()

    def _load_datasets(self):
        """Build one ``TrajDataset`` per map file and concatenate them.

        Raises:
            FileNotFoundError: If ``maps_root`` contains no ``*.bmp`` file.
        """
        # Sorted so the order of the sub-datasets does not depend on the
        # file system's listing order.
        map_files = sorted(self.maps_root.glob('*.bmp'))
        if not map_files:
            raise FileNotFoundError(f'No *.bmp map files found in "{self.maps_root}"')
        equal_split = int(self.length // len(map_files))

        # find max image size among available maps:
        sizes = []
        for map_file in map_files:
            # Close each image right away; only its size is needed here.
            with Image.open(map_file) as image:
                sizes.append(image.size)
        # zip(*sizes) groups all widths and all heights, which also works
        # when there is only a single map file.
        max_width, max_height = map(max, zip(*sizes))
        max_map_size = (1, max_height, max_width)

        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                          mode=self.mode, embedding_size=max_map_size,
                                          preserve_equal_samples=True)
                              for map_file in map_files])

    @property
    def train_dataset(self):
        """The concatenated dataset (shared by all splits)."""
        return self._dataset

    @property
    def val_dataset(self):
        """The concatenated dataset (shared by all splits)."""
        return self._dataset

    @property
    def test_dataset(self):
        """The concatenated dataset (shared by all splits)."""
        return self._dataset

    def get_datasets(self):
        """Return ``(train, val, test)``; all three are the same dataset object."""
        return self._dataset, self._dataset, self._dataset