# hom_traj_gen/datasets/paired_dataset.py
# (snapshot: 2020-02-21 09:44:09 +01:00)
import shelve
from pathlib import Path
from typing import Union
import torch
from torch.utils.data import Dataset, ConcatDataset
from datasets.utils import DatasetMapping
from lib.modules.model_parts import Generator
from lib.objects.map import Map
class TrajPairDataset(Dataset):
    """Dataset pairing one reference trajectory with several alternatives.

    Built from a single shelve record: every item shares the same reference
    trajectory and map, while the alternative and its label vary per index.
    """

    def __init__(self, data):
        """Unpack a shelve record dict.

        :param data: dict with keys 'alternatives' (sequence), 'trajectory',
                     'labels' (sequence aligned with alternatives) and 'map'
                     (dict with 'name' and 'map' entries).
        """
        super(TrajPairDataset, self).__init__()
        self.alternatives = data['alternatives']
        self.trajectory = data['trajectory']
        self.labels = data['labels']
        # Strip a leading 'map_' prefix from the map name, if present.
        raw_name = data['map']['name']
        self.mapname = raw_name[4:] if raw_name.startswith('map_') else raw_name
        self.map = data['map']['map']

    @property
    def map_shape(self):
        # Shape of the map's array representation.
        return self.map.as_array.shape

    def __len__(self):
        # One sample per alternative trajectory.
        return len(self.alternatives)

    def __getitem__(self, item):
        # (reference vertices, alternative vertices, label, map name)
        alternative = self.alternatives[item]
        return self.trajectory.vertices, alternative.vertices, self.labels[item], self.mapname
class TrajPairData(object):
    """Builds (on demand) and serves a shelve-backed trajectory-pair dataset.

    Generates `trajectories` reference trajectories with `alternatives`
    alternatives each for map `mapname`, persists them under `data_root`,
    and exposes disjoint train/val/test views via `DatasetMapping`.
    """

    @property
    def map_shapes(self):
        # Shape of every concatenated sub-dataset's map.
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        # Element-wise maximum over all map shapes.
        # BUGFIX: previously returned a lazy `map` iterator, which was
        # exhausted after a single use; return a reusable list instead.
        return [max(dims) for dims in zip(*self.map_shapes)]

    @property
    def name(self):
        # Used by callers as a human-readable identifier for this data source.
        return self.__class__.__name__

    def __init__(self, data_root, map_root: Union[Path, str] = '', mapname='tate_sw', trajectories=1000, alternatives=10,
                 train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
        """Prepare (building if necessary) the dataset and its split indices.

        :param data_root: directory holding/receiving the shelve files.
        :param map_root: directory with `<mapname>.bmp`; defaults to ./res/maps.
        :param mapname: map identifier (also the shelve file stem).
        :param trajectories: number of reference trajectories to generate.
        :param alternatives: alternatives generated per trajectory.
        :param train_val_test_split: fractions for the three splits.
        :param rebuild: delete and regenerate an existing dataset file.
        :param equal_samples: balance positive/negative samples on generation.
        """
        self.rebuild = rebuild
        self.equal_samples = equal_samples
        self._alternatives = alternatives
        self._trajectories = trajectories
        self.mapname = mapname
        self.train_split, self.val_split, self.test_split = train_val_test_split
        self.data_root = Path(data_root)
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()

    def _build_data_on_demand(self):
        """Generate the shelve dataset file if missing (or if rebuilding)."""
        # BUGFIX: check the maps directory *before* reading the map image
        # from it (the assert previously ran after Map.from_image).
        assert self.maps_root.exists()
        map_object = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
        dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
        if dataset_file.exists() and self.rebuild:
            dataset_file.unlink()
        if not dataset_file.exists():
            generator = Generator(self.data_root, map_object)
            generator.generate_n_trajectories_m_alternatives(self._trajectories, self._alternatives,
                                                             self.mapname, equal_samples=self.equal_samples)
        return True

    def _load_dataset(self):
        """Load all shelve records and shuffle indices into three splits.

        :return: (dataset, train_indices, val_indices, test_indices)
        """
        assert self._build_data_on_demand()
        with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
            dataset = ConcatDataset([TrajPairDataset(d[key]) for key in d.keys() if key != 'map'])
        indices = torch.randperm(len(dataset))
        train_size = int(len(dataset) * self.train_split)
        val_size = int(len(dataset) * self.val_split)
        # BUGFIX: the original slices were `indices[train_size:val_size]`
        # (empty whenever val_split < train_split) and `indices[test_size:]`
        # (overlapping the training set). Use contiguous, disjoint slices;
        # the test split takes the remainder so every index is used once.
        train_map = indices[:train_size]
        val_map = indices[train_size:train_size + val_size]
        test_map = indices[train_size + val_size:]
        return dataset, train_map, val_map, test_map

    @property
    def train_dataset(self):
        return DatasetMapping(self._dataset, self._train_map)

    @property
    def val_dataset(self):
        return DatasetMapping(self._dataset, self._val_map)

    @property
    def test_dataset(self):
        return DatasetMapping(self._dataset, self._test_map)

    def get_datasets(self):
        """Convenience accessor for all three split views."""
        return self.train_dataset, self.val_dataset, self.test_dataset