import shelve
from pathlib import Path
from typing import Union

import torch
from torch.utils.data import ConcatDataset, Dataset

from datasets.utils import DatasetMapping
from lib.modules.model_parts import Generator
from lib.objects.map import Map


class TrajPairDataset(Dataset):
    """Dataset of (trajectory, alternative) pairs for one map.

    Wraps one record produced by the trajectory generator: a single reference
    trajectory, a list of alternative trajectories, and a per-alternative label.
    """

    @property
    def map_shape(self):
        """Shape of the map's array representation (delegates to ``Map.as_array``)."""
        return self.map.as_array.shape

    def __init__(self, data):
        """Build the dataset from one shelve record.

        :param data: mapping with keys ``'alternatives'``, ``'trajectory'``,
            ``'labels'`` and ``'map'`` (itself holding ``'name'`` and ``'map'``).
        """
        super().__init__()
        self.alternatives = data['alternatives']
        self.trajectory = data['trajectory']
        self.labels = data['labels']
        # Strip a leading 'map_' file-name prefix so the stored name is canonical.
        raw_name = data['map']['name']
        self.mapname = raw_name[4:] if raw_name.startswith('map_') else raw_name
        self.map = data['map']['map']

    def __len__(self):
        return len(self.alternatives)

    def __getitem__(self, item):
        # One sample = reference vertices, alternative vertices, label, map name.
        return (self.trajectory.vertices,
                self.alternatives[item].vertices,
                self.labels[item],
                self.mapname)


class TrajPairData(object):
    """Builds (on demand) and loads a trajectory-pair dataset for one map,
    exposing shuffled, disjoint train/val/test views over it.
    """

    @property
    def map_shapes(self):
        """Map shapes of all concatenated sub-datasets."""
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        """Element-wise maximum over all map shapes.

        Returns a list (the original returned a one-shot ``map`` iterator,
        which was silently exhausted after a single iteration).
        """
        shapes = self.map_shapes
        return [max(dim) for dim in zip(*shapes)]

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, data_root, map_root: Union[Path, str] = '', mapname='tate_sw',
                 trajectories=1000, alternatives=10,
                 train_val_test_split=(0.6, 0.2, 0.2),
                 rebuild=False, equal_samples=True, **_):
        """
        :param data_root: directory holding (or receiving) the generated shelve files.
        :param map_root: directory with ``<mapname>.bmp`` map images;
            defaults to ``./res/maps`` when empty.
        :param mapname: map to generate/load data for.
        :param trajectories: number of reference trajectories to generate.
        :param alternatives: alternatives generated per trajectory.
        :param train_val_test_split: fractions for the train/val/test partition.
        :param rebuild: delete and regenerate an existing dataset file.
        :param equal_samples: forwarded to the generator (class balancing).
        """
        self.rebuild = rebuild
        self.equal_samples = equal_samples
        self._alternatives = alternatives
        self._trajectories = trajectories
        self.mapname = mapname
        self.train_split, self.val_split, self.test_split = train_val_test_split
        self.data_root = Path(data_root)
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()

    def _build_data_on_demand(self):
        """Ensure the shelve dataset file for ``self.mapname`` exists; generate it if not.

        :returns: True (so callers can assert the build succeeded).
        """
        # Check the map directory before trying to read from it (the original
        # asserted existence only after already loading the image).
        assert self.maps_root.exists()
        map_object = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
        dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
        if dataset_file.exists() and self.rebuild:
            dataset_file.unlink()
        if not dataset_file.exists():
            generator = Generator(self.data_root, map_object)
            generator.generate_n_trajectories_m_alternatives(
                self._trajectories, self._alternatives, self.mapname,
                equal_samples=self.equal_samples)
        return True

    def _load_dataset(self):
        """Load the shelve file and produce shuffled, disjoint split index maps.

        :returns: tuple of (concatenated dataset, train indices, val indices, test indices).
        """
        assert self._build_data_on_demand()
        with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
            dataset = ConcatDataset([TrajPairDataset(d[key]) for key in d.keys() if key != 'map'])
        indices = torch.randperm(len(dataset))
        train_size = int(len(dataset) * self.train_split)
        val_size = int(len(dataset) * self.val_split)
        # BUGFIX: the original sliced val as indices[train_size:val_size]
        # (empty whenever val_size < train_size) and test as indices[test_size:]
        # (overlapping the training indices). Slice contiguous, disjoint ranges.
        train_map = indices[:train_size]
        val_map = indices[train_size:train_size + val_size]
        test_map = indices[train_size + val_size:]
        return dataset, train_map, val_map, test_map

    @property
    def train_dataset(self):
        return DatasetMapping(self._dataset, self._train_map)

    @property
    def val_dataset(self):
        return DatasetMapping(self._dataset, self._val_map)

    @property
    def test_dataset(self):
        return DatasetMapping(self._dataset, self._test_map)

    def get_datasets(self):
        """Convenience accessor returning (train, val, test) dataset views."""
        return self.train_dataset, self.val_dataset, self.test_dataset