import math
import shelve
from pathlib import Path

import torch
from torch.utils.data import ConcatDataset, Dataset

from lib.objects.map import Map
from preprocessing.generator import Generator


class TrajDataset(Dataset):
    """Pair one reference trajectory with each of its alternatives.

    Item ``i`` is ``(trajectory.vertices, alternatives[i].vertices, labels[i])``.
    """

    def __init__(self, data):
        """
        Args:
            data: mapping with the keys ``'alternatives'``, ``'trajectory'`` and
                ``'labels'``. ``alternatives`` and ``labels`` are indexed in
                parallel (presumably equal length — TODO confirm against Generator).
        """
        super().__init__()
        self.alternatives = data['alternatives']
        self.trajectory = data['trajectory']
        self.labels = data['labels']

    def __len__(self):
        return len(self.alternatives)

    def __getitem__(self, item):
        return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item]


class DataSetMapping(Dataset):
    """Read-only view that exposes ``dataset[mapping[i]]`` as item ``i``."""

    def __init__(self, dataset, mapping):
        """
        Args:
            dataset: any indexable dataset.
            mapping: 1-D sequence (e.g. an index tensor) of indices into ``dataset``.
        """
        self._dataset = dataset
        self._mapping = mapping

    def __len__(self):
        return len(self._mapping)

    def __getitem__(self, item):
        # int() turns a 0-d index tensor into a plain Python int, so the wrapped
        # dataset always receives an ordinary integer index.
        return self._dataset[int(self._mapping[item])]


class TrajData:
    """Build (or load) the trajectory dataset for one map and split it randomly
    into disjoint train / validation / test subsets."""

    # shelve's dbm backend may append an extension to the base filename and may
    # create several files (none for gdbm, '.db' for ndbm, '.dat'/'.dir'/'.bak'
    # for dbm.dumb). All of them are checked/removed together.
    _SHELVE_SUFFIXES = ('', '.db', '.dat', '.dir', '.bak')

    @property
    def name(self):
        return type(self).__name__

    def __init__(self, data_root, mapname='tate_sw', trajectories=1000, alternatives=10,
                 train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
        """
        Args:
            data_root: directory holding the ``<mapname>.pik`` shelve.
            mapname: map name; the image is read from ``res/maps/<mapname>.bmp``.
            trajectories: number of trajectories to generate when building data.
            alternatives: number of alternatives per trajectory when building data.
            train_val_test_split: fractions for (train, val, test); must be
                non-negative and sum to at most 1.
            rebuild: delete any existing shelve and regenerate it.
            equal_samples: forwarded to the generator.

        Raises:
            ValueError: if ``train_val_test_split`` is invalid.
        """
        self.rebuild = rebuild
        self.equal_samples = equal_samples
        self._alternatives = alternatives
        self._trajectories = trajectories
        self.mapname = mapname
        self.train_split, self.val_split, self.test_split = train_val_test_split
        # Fail fast, before the (potentially slow) data generation.
        self._check_splits()
        self.data_root = Path(data_root)
        self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()

    def _check_splits(self):
        """Raise ValueError unless the split fractions are non-negative and sum to <= 1."""
        splits = (self.train_split, self.val_split, self.test_split)
        if any(fraction < 0 for fraction in splits) or sum(splits) > 1 + 1e-9:
            raise ValueError(f'train_val_test_split must be non-negative fractions '
                             f'summing to at most 1, got {splits}')

    def _shelve_files(self):
        """Return every path the shelve for this map may occupy on disk."""
        base = self.data_root / f'{self.mapname}.pik'
        return [base.with_name(base.name + suffix) for suffix in self._SHELVE_SUFFIXES]

    def _build_data_on_demand(self):
        """Generate the dataset shelve for ``self.mapname`` unless it already exists.

        With ``rebuild=True`` any existing shelve files are deleted first, so the
        data is regenerated.

        Returns:
            True (kept for backward compatibility).

        Raises:
            FileNotFoundError: if generation is needed but the map image is missing.
        """
        existing = [path for path in self._shelve_files() if path.exists()]
        if existing and self.rebuild:
            for path in existing:
                path.unlink()
            existing = []
        if not existing:
            # NOTE: resolved relative to the current working directory.
            maps_root = Path() / 'res' / 'maps'
            map_file = maps_root / f'{self.mapname}.bmp'
            # Validate before use (an assert would be stripped under -O).
            if not map_file.exists():
                raise FileNotFoundError(f'Map image not found: {map_file.resolve()}')
            map_object = Map(self.mapname).from_image(map_file)
            generator = Generator(self.data_root, map_object)
            generator.generate_n_trajectories_m_alternatives(self._trajectories, self._alternatives,
                                                             self.mapname,
                                                             equal_samples=self.equal_samples)
        return True

    def _split_indices(self, n):
        """Split a random permutation of ``range(n)`` into disjoint train/val/test index tensors.

        Split sizes are ``int(n * fraction)``. When the three fractions sum to 1,
        the test split also takes the samples lost to int() truncation, so every
        sample lands in exactly one split.
        """
        indices = torch.randperm(n)
        train_end = int(n * self.train_split)
        val_end = train_end + int(n * self.val_split)
        if math.isclose(self.train_split + self.val_split + self.test_split, 1.0):
            test_end = n
        else:
            test_end = val_end + int(n * self.test_split)
        return indices[:train_end], indices[train_end:val_end], indices[val_end:test_end]

    def _load_dataset(self):
        """Ensure the data exists, load it and draw a random train/val/test split.

        Returns:
            ``(dataset, train_map, val_map, test_map)``. The three maps are disjoint
            slices of one random permutation of ``range(len(dataset))``.
        """
        # Called outside of ``assert``: under ``python -O`` the assert statement,
        # and therefore the data generation inside it, would be skipped entirely.
        if not self._build_data_on_demand():
            raise RuntimeError(f'Could not build the dataset for map {self.mapname!r}')
        with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
            # Sorted so the concatenation order (which sample each index refers to)
            # does not depend on the dbm backend's key iteration order.
            dataset = ConcatDataset([TrajDataset(d[key]) for key in sorted(d) if key != 'map'])
        train_map, val_map, test_map = self._split_indices(len(dataset))
        return dataset, train_map, val_map, test_map

    @property
    def train_dataset(self):
        return DataSetMapping(self._dataset, self._train_map)

    @property
    def val_dataset(self):
        return DataSetMapping(self._dataset, self._val_map)

    @property
    def test_dataset(self):
        return DataSetMapping(self._dataset, self._test_map)

    def get_datasets(self):
        """Return ``(train_dataset, val_dataset, test_dataset)``."""
        return self.train_dataset, self.val_dataset, self.test_dataset