"""Datasets that sample random trajectories on ``.bmp`` maps."""
import shelve
from pathlib import Path
from random import choice
from typing import List, Union

import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset, Dataset

import lib.variables as V
from lib.objects.map import Map

# Sampling modes understood by TrajDataset; compared case-insensitively.
_VALID_MODES = ('vectors', 'all_in_map', 'separated_arrays', 'just_route')


class TrajDataset(Dataset):
    """Random sampler of trajectories drawn on a single map image.

    The index passed to ``__getitem__`` is ignored: every access draws a new
    random trajectory from the underlying ``Map``, so ``length`` only
    determines the value reported by ``len()``.

    Args:
        *args: Ignored; accepted for call-site compatibility.
        maps_root: Directory that contains the map image.
        mapname: File name of the map; ``.bmp`` is appended when missing.
        normalized: When true, arrays returned in ``'separated_arrays'`` mode
            are divided by ``V.WHITE``. Other modes are not normalised.
        length: Value reported by ``len()``.
        mode: One of ``'vectors'``, ``'all_in_map'``, ``'separated_arrays'``
            or ``'just_route'`` (case-insensitive).
        embedding_size: Forwarded unchanged to ``Map.from_image``.
        preserve_equal_samples: When true, a freshly drawn sample is rejected
            whenever its label equals the label of the previously returned
            sample, so consecutive samples always carry different labels.
        **kwargs: Ignored; accepted for call-site compatibility.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """

    @property
    def map_shape(self):
        """Shape of the map's array representation (``map.as_array.shape``)."""
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', normalized=True,
                 length=100000, mode='separated_arrays', embedding_size=None,
                 preserve_equal_samples=False, **kwargs):
        super().__init__()

        # Normalise once so that every later comparison sees the same
        # spelling; an explicit raise (unlike ``assert``) survives ``-O``.
        mode = mode.lower()
        if mode not in _VALID_MODES:
            raise ValueError(f'Unknown mode {mode!r}; expected one of {_VALID_MODES}')

        self.normalized = normalized
        self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Coerce to Path: the "/" join below is not defined for plain strings.
        self.maps_root = Path(maps_root)
        self._len = length
        # -1 never equals a real label, so the first sample is always accepted.
        self.last_label = -1

        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname,
                                                embedding_size=embedding_size)

    def __len__(self):
        """Return the configured nominal dataset length."""
        return self._len

    def __getitem__(self, item):
        """Draw a fresh random sample; ``item`` is ignored.

        Returns, depending on the mode:
            * ``'just_route'``: ``((map_tensor, trajectory_array), label)``
              where ``label`` is a random choice of 0 or 1.
            * ``'separated_arrays'``:
              ``((map_array, trajectory_array, label), alternative_array)``.
            * ``'all_in_map'``: ``(concatenated_array, label)`` where the map,
              trajectory and alternative arrays are joined with
              ``np.concatenate``.
            * ``'vectors'``:
              ``(trajectory.vertices, alternative.vertices, label, mapname)``.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        # Tolerate a mixed-case value assigned to ``self.mode`` after init.
        mode = self.mode.lower()

        if mode == 'just_route':
            trajectory = self.map.get_random_trajectory()
            # NOTE(review): this branch uses ``self.map.shape`` while the other
            # modes use ``self.map_shape`` -- confirm both describe the same shape.
            trajectory_space = trajectory.draw_in_array(self.map.shape)
            label = choice([0, 1])
            map_array = torch.as_tensor(self.map.as_array).float()
            return (map_array, trajectory_space), label

        # Rejection sampling: redraw while the label repeats the previous one
        # (only when ``preserve_equal_samples`` is enabled).
        while True:
            trajectory = self.map.get_random_trajectory()
            alternative = self.map.generate_alternative(trajectory)
            label = self.map.are_homotopic(trajectory, alternative)
            if not (self.preserve_equal_samples and label == self.last_label):
                break
        self.last_label = label

        if mode == 'vectors':
            return trajectory.vertices, alternative.vertices, label, self.mapname

        if mode not in ('all_in_map', 'separated_arrays'):
            raise ValueError

        map_array = self.map.as_array
        trajectory = trajectory.draw_in_array(self.map_shape)
        alternative = alternative.draw_in_array(self.map_shape)

        if mode == 'all_in_map':
            return np.concatenate((map_array, trajectory, alternative)), label

        # mode == 'separated_arrays'
        if self.normalized:
            map_array = map_array / V.WHITE
            trajectory = trajectory / V.WHITE
            alternative = alternative / V.WHITE
        return (map_array, trajectory, label), alternative


class TrajData(object):
    """Bundle of one ``TrajDataset`` per ``.bmp`` map found in a directory.

    The per-map datasets are combined in a ``ConcatDataset``; the same
    combined dataset is exposed as train, validation and test split.

    Args:
        map_root: Directory that is searched (non-recursively) for ``*.bmp``.
        length: Total nominal length, split evenly across the maps (each map
            gets at least one sample).
        mode: Sampling mode forwarded to every ``TrajDataset``
            (case-insensitive).
        normalized: Forwarded to every ``TrajDataset``.
        **_: Ignored; accepted for call-site compatibility.

    Raises:
        FileNotFoundError: If ``map_root`` contains no ``*.bmp`` file.
    """

    @property
    def map_shapes(self):
        """List with the ``map_shape`` of every contained dataset."""
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        """Element-wise maximum over all map shapes, as a list.

        In ``'separated_arrays'`` and ``'all_in_map'`` mode the first entry
        is increased by 2 -- presumably to account for the trajectory and
        alternative arrays that accompany the map; confirm against consumers.
        """
        shapes = self.map_shapes
        shape_list = list(map(max, zip(*shapes)))
        if self.mode in ['separated_arrays', 'all_in_map']:
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        """Name of the concrete class."""
        return self.__class__.__name__

    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_):
        self.normalized = normalized
        # Lower-cased so the comparison in ``map_shapes_max`` agrees with the
        # case-insensitive handling inside TrajDataset.
        self.mode = mode.lower()
        self.maps_root = Path(map_root)
        self.length = length
        self._dataset = self._load_datasets()

    def _load_datasets(self):
        """Create one ``TrajDataset`` per map file and concatenate them."""
        # Sorted because glob order is arbitrary; keeps the layout reproducible.
        map_files = sorted(self.maps_root.glob('*.bmp'))
        if not map_files:
            raise FileNotFoundError(f'No *.bmp map files found in {self.maps_root}')

        equal_split = int(self.length // len(map_files)) or 1

        # Find the max image size among the available maps. Each image is
        # opened in a ``with`` block so its file handle is released again.
        sizes = []
        for map_file in map_files:
            with Image.open(map_file) as image:
                sizes.append(image.size)
        # ``zip(*sizes)`` groups widths and heights, which also works when
        # there is only a single map file.
        max_width, max_height = map(max, zip(*sizes))
        # PIL reports (width, height); the embedding is (1, height, width).
        max_map_size = (1, max_height, max_width)

        return ConcatDataset([
            TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                        mode=self.mode, embedding_size=max_map_size,
                        normalized=self.normalized, preserve_equal_samples=True)
            for map_file in map_files
        ])

    @property
    def train_dataset(self):
        """Combined dataset (identical for every split)."""
        return self._dataset

    @property
    def val_dataset(self):
        """Combined dataset (identical for every split)."""
        return self._dataset

    @property
    def test_dataset(self):
        """Combined dataset (identical for every split)."""
        return self._dataset

    def get_datasets(self):
        """Return ``(train, val, test)``; all three are the same dataset."""
        return self._dataset, self._dataset, self._dataset