import shelve
from pathlib import Path
from random import choice
from typing import List, Union

import torch
from PIL import Image
from torch.utils.data import ConcatDataset, Dataset

from lib.objects.map import Map


class TrajDataset(Dataset):
    """Dataset that samples a random trajectory and an alternative to it on one map.

    Every ``__getitem__`` call draws fresh samples from ``self.map``; the index
    argument is ignored, so items are not reproducible per index.
    """

    @property
    def map_shape(self):
        """Shape of the map's array representation (``self.map.as_array.shape``)."""
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', length=100000,
                 all_in_map=True, embedding_size=None, **kwargs):
        """Load the map image ``maps_root / mapname``.

        Args:
            maps_root: Directory containing the map image; ``str`` is accepted.
            mapname: Map file name; ``'.bmp'`` is appended when missing.
            length: Value reported by ``__len__``.
            all_in_map: If true, items are ``(stacked_tensor, homotopy_label)``;
                otherwise raw vertex data is returned (see ``__getitem__``).
            embedding_size: Forwarded unchanged to ``Map.from_image``.
        """
        super().__init__()
        self.all_in_map = all_in_map
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Normalise to Path so str roots (including the '' default) support the `/` join below.
        self.maps_root = Path(maps_root)
        self._len = length
        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname,
                                                embedding_size=embedding_size)

    def __len__(self):
        """Return the configured ``length``."""
        return self._len

    def __getitem__(self, item):
        """Sample a trajectory and an alternative generated from it.

        Args:
            item: Ignored; each call draws new random samples.

        Returns:
            If ``all_in_map``: ``(stacked, label)`` where ``stacked`` is the map
            array concatenated along dim 0 with a trajectory mask and an
            alternative mask (each ``torch.zeros(self.map.shape)`` with the
            respective vertices set to 1), and ``label`` is
            ``int(self.map.are_homotopic(trajectory, alternative))``.
            Otherwise: ``(trajectory.vertices, alternative.vertices, label, self.mapname)``
            where ``label`` is drawn uniformly from ``{0, 1}``.
        """
        trajectory = self.map.get_random_trajectory()
        alternative = self.map.generate_alternative(trajectory)
        # NOTE(review): in the raw-vertex branch this label is random and not tied to
        # the sampled pair; confirm this is intended.
        label = choice([0, 1])
        if self.all_in_map:
            blank_trajectory_space = torch.zeros(self.map.shape)
            blank_alternative_space = torch.zeros(self.map.shape)
            for index in trajectory.vertices:
                blank_trajectory_space[index] = 1
            # Each mask marks its own path's vertices, so the two channels can differ.
            for index in alternative.vertices:
                blank_alternative_space[index] = 1

            map_array = torch.as_tensor(self.map.as_array).float()
            label = self.map.are_homotopic(trajectory, alternative)
            return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
        else:
            return trajectory.vertices, alternative.vertices, label, self.mapname


class TrajData(object):
    """Build one ``TrajDataset`` per ``*.bmp`` map in ``maps_root`` and concatenate them.

    NOTE(review): ``train_dataset``, ``val_dataset`` and ``test_dataset`` all return
    the same ``ConcatDataset`` object; there is no split.
    """

    @property
    def map_shapes(self):
        """``map_shape`` of every per-map dataset, in concatenation order."""
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        """Element-wise maximum over ``map_shapes``, as a list.

        With ``all_in_map`` the first dimension is increased by 2, presumably for the
        two mask tensors ``TrajDataset.__getitem__`` concatenates along dim 0
        (assumes each mask contributes 1 along that dim — TODO confirm).
        """
        shapes = self.map_shapes
        shape_list = list(map(max, zip(*shapes)))
        if self.all_in_map:
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        """Name of this class."""
        return self.__class__.__name__

    def __init__(self, *args, map_root: Union[Path, str] = '', length=100000, all_in_map=True, **_):
        """Discover maps and build the concatenated dataset.

        Args:
            map_root: Directory with ``*.bmp`` maps; defaults to ``res/maps``
                relative to the working directory when empty.
            length: Total sample count, split evenly (floor) across the maps.
                Defaults to 100000, matching ``TrajDataset``.
            all_in_map: Forwarded to every ``TrajDataset``.

        Raises:
            FileNotFoundError: if no ``*.bmp`` files exist in the map directory.
        """
        self.all_in_map = all_in_map
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        self.length = length
        self._dataset = self._load_datasets()

    def _load_datasets(self):
        """Create one ``TrajDataset`` per map file and wrap them in a ``ConcatDataset``.

        All datasets receive the same ``embedding_size``: ``(1, max_height, max_width)``
        taken over every map image.

        Raises:
            FileNotFoundError: if ``self.maps_root`` contains no ``*.bmp`` files.
        """
        # Sorted so dataset order (and ConcatDataset indexing) is reproducible.
        map_files = sorted(self.maps_root.glob('*.bmp'))
        if not map_files:
            raise FileNotFoundError(f'No *.bmp map files found in {self.maps_root}')
        equal_split = int(self.length // len(map_files))

        # PIL reports (width, height); close each image once its size is read.
        sizes = []
        for map_file in map_files:
            with Image.open(map_file) as image:
                sizes.append(image.size)
        # zip(*) groups widths and heights; this also works for a single map.
        max_width, max_height = (max(dim) for dim in zip(*sizes))
        max_map_size = (1, max_height, max_width)

        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                          all_in_map=self.all_in_map, embedding_size=max_map_size)
                              for map_file in map_files])

    @property
    def train_dataset(self):
        """The concatenated dataset (same object as val/test)."""
        return self._dataset

    @property
    def val_dataset(self):
        """The concatenated dataset (same object as train/test)."""
        return self._dataset

    @property
    def test_dataset(self):
        """The concatenated dataset (same object as train/val)."""
        return self._dataset

    def get_datasets(self):
        """Return ``(train, val, test)``; all three are the same dataset object."""
        return self._dataset, self._dataset, self._dataset