hom_traj_gen/dataset/dataset.py

import shelve
from pathlib import Path
from random import choice
from typing import Union

import torch
from torch.utils.data import ConcatDataset, Dataset

from lib.objects.map import Map
from lib.preprocessing.generator import Generator


class TrajPairDataset(Dataset):

    def __init__(self, data):
        super(TrajPairDataset, self).__init__()
        self.alternatives = data['alternatives']
        self.trajectory = data['trajectory']
        self.labels = data['labels']
        map_name = data['map']['name']
        self.mapname = map_name[4:] if map_name.startswith('map_') else map_name

    def __len__(self):
        return len(self.alternatives)

    def __getitem__(self, item):
        return self.trajectory.vertices, self.alternatives[item].vertices, self.labels[item], self.mapname


class DatasetMapping(Dataset):

    def __init__(self, dataset, mapping):
        self._dataset = dataset
        self._mapping = mapping

    def __len__(self):
        return self._mapping.shape[0]

    def __getitem__(self, item):
        return self._dataset[self._mapping[item]]


class TrajPairData(object):

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, data_root, map_root: Union[Path, str] = '', mapname='tate_sw', trajectories=1000,
                 alternatives=10, train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
        self.rebuild = rebuild
        self.equal_samples = equal_samples
        self._alternatives = alternatives
        self._trajectories = trajectories
        self.mapname = mapname
        self.train_split, self.val_split, self.test_split = train_val_test_split
        self.data_root = Path(data_root)
        # Use the given map root, falling back to the bundled maps folder; the parameter was previously ignored.
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        self._dataset = None
        self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()

    def _build_data_on_demand(self):
        # Check the maps folder before trying to read the bitmap from it.
        assert self.maps_root.exists()
        map_object = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')
        dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
        if dataset_file.exists() and self.rebuild:
            dataset_file.unlink()
        if not dataset_file.exists():
            generator = Generator(self.data_root, map_object)
            generator.generate_n_trajectories_m_alternatives(self._trajectories, self._alternatives,
                                                             self.mapname, equal_samples=self.equal_samples)
        return True

    def _load_dataset(self):
        assert self._build_data_on_demand()
        with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
            dataset = ConcatDataset([TrajPairDataset(d[key]) for key in d.keys() if key != 'map'])
        indices = torch.randperm(len(dataset))
        train_size = int(len(dataset) * self.train_split)
        val_size = int(len(dataset) * self.val_split)
        # Contiguous, non-overlapping slices of the shuffled indices; the remainder goes to test.
        train_map = indices[:train_size]
        val_map = indices[train_size:train_size + val_size]
        test_map = indices[train_size + val_size:]
        return dataset, train_map, val_map, test_map

    @property
    def train_dataset(self):
        return DatasetMapping(self._dataset, self._train_map)

    @property
    def val_dataset(self):
        return DatasetMapping(self._dataset, self._val_map)

    @property
    def test_dataset(self):
        return DatasetMapping(self._dataset, self._test_map)

    def get_datasets(self):
        return self.train_dataset, self.val_dataset, self.test_dataset
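

# Example usage (sketch, not part of the original module): build the paired dataset and
# fetch the three splits. The `data_root` value below is an assumption, adjust it to the
# local shelve/map location.
#
#   data = TrajPairData(data_root='data', mapname='tate_sw', trajectories=1000, alternatives=10)
#   train_set, val_set, test_set = data.get_datasets()
#   trajectory, alternative, label, mapname = train_set[0]

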
class TrajDataset(Dataset):

    def __init__(self, data_root, maps_root: Union[Path, str] = '', mapname='tate_sw', length=100_000, **_):
        super(TrajDataset, self).__init__()
        self.mapname = mapname
        # Mirror TrajPairData: accept a path-like map root and fall back to the bundled maps folder.
        self.maps_root = Path(maps_root) if maps_root else Path() / 'res' / 'maps'
        self.data_root = data_root
        # __len__ must return an int, so coerce the requested length here.
        self._len = int(length)
        self._map_obj = Map(self.mapname).from_image(self.maps_root / f'{self.mapname}.bmp')

    def __len__(self):
        return self._len

    def __getitem__(self, item):
        trajectory = self._map_obj.get_random_trajectory()
        label = choice([0, 1])
        return trajectory.vertices, None, label, self.mapname

    @property
    def train_dataset(self):
        return self

    @property
    def val_dataset(self):
        return self

    @property
    def test_dataset(self):
        return self

    def get_datasets(self):
        return self, self, self
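

if __name__ == '__main__':
    # Minimal smoke test (sketch): assumes res/maps/tate_sw.bmp exists and that the lib
    # package is importable. Paths and the small `length` are illustrative assumptions.
    dataset = TrajDataset(data_root='data', maps_root=Path('res') / 'maps', mapname='tate_sw', length=4)
    vertices, _, label, mapname = dataset[0]
    print(mapname, label, type(vertices))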