from pathlib import Path
from random import choice
from typing import Union

import torch
from torch.utils.data import ConcatDataset, Dataset

from lib.objects.map import Map


class TrajDataset(Dataset):

    @property
    def map_shape(self):
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
                 length=100_000, all_in_map=True, **kwargs):
        super(TrajDataset, self).__init__()
        self.all_in_map = all_in_map
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        self.maps_root = Path(maps_root)
        self._len = int(length)
        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname)

    def __len__(self):
        return self._len

    def __getitem__(self, item):
        trajectory = self.map.get_random_trajectory()
        alternative = self.map.generate_alternative(trajectory)
        label = choice([0, 1])
        if self.all_in_map:
            map_array = torch.as_tensor(self.map.as_array)
            # Rasterize both paths onto blank copies of the map.
            blank_trajectory_space = torch.zeros_like(map_array)
            blank_trajectory_space[trajectory.vertices] = 1
            blank_alternative_space = torch.zeros_like(map_array)
            blank_alternative_space[alternative.vertices] = 1
            # The random label is replaced by the true homotopy relation.
            label = self.map.are_homotopic(trajectory, alternative)
            return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), label
        else:
            return trajectory.vertices, alternative.vertices, label, self.mapname


class TrajData(object):

    @property
    def map_shapes(self):
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        # Element-wise maximum over all map shapes, e.g. as a common padding target.
        shapes = self.map_shapes
        return list(map(max, zip(*shapes)))

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, *args, map_root: Union[Path, str] = '',
                 length=100_000, all_in_map=True, **_):
        self.all_in_map = all_in_map
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        # self.length must be set before the datasets are built.
        self.length = length
        self._dataset = self._load_datasets()

    def _load_datasets(self):
        map_files = list(self.maps_root.glob('*.bmp'))
        # Split the requested total length evenly across all available maps.
        equal_split = int(self.length) // len(map_files)
        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name,
                                          length=equal_split, all_in_map=self.all_in_map)
                              for map_image in map_files])

    @property
    def train_dataset(self):
        return self._dataset

    @property
    def val_dataset(self):
        return self._dataset

    @property
    def test_dataset(self):
        return self._dataset

    def get_datasets(self):
        return self._dataset, self._dataset, self._dataset
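

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a 'res/maps' directory containing at least
# one *.bmp map; the path, length, and batch size below are illustrative only.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    data = TrajData(map_root='res/maps', length=1_000, all_in_map=True)
    train, _val, _test = data.get_datasets()

    # batch_size=1 because maps (and thus sample tensors) may differ in shape.
    loader = DataLoader(train, batch_size=1, shuffle=True)
    sample, label = next(iter(loader))
    print(data.name, data.map_shapes_max, sample.shape, label)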