# Source metadata: 97 lines, 3.4 KiB, Python
import shelve
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from torch.utils.data import ConcatDataset, Dataset
|
|
|
|
from lib.objects.map import Map
|
|
from preprocessing.generator import Generator
|
|
|
|
|
|
class TrajDataset(Dataset):
    """Dataset pairing one reference trajectory with a set of alternative
    trajectories and a per-alternative label.

    Each sample is ``(reference_vertices, alternative_vertices, label)``.
    """

    def __init__(self, data):
        # ``data`` must provide the 'alternatives', 'trajectory' and
        # 'labels' entries (e.g. one record from the shelve file).
        super().__init__()
        self.alternatives = data['alternatives']
        self.trajectory = data['trajectory']
        self.labels = data['labels']

    def __len__(self):
        # One sample per alternative trajectory.
        return len(self.alternatives)

    def __getitem__(self, item):
        reference = self.trajectory.vertices
        candidate = self.alternatives[item].vertices
        return reference, candidate, self.labels[item]
|
|
|
|
|
|
class DataSetMapping(Dataset):
    """Index-remapped view of another dataset.

    ``mapping`` is an index tensor/array (anything with ``.shape`` and
    integer indexing); item ``i`` of this view is item ``mapping[i]`` of
    the wrapped dataset. Used to expose train/val/test subsets without
    copying the underlying data.
    """

    def __init__(self, dataset, mapping):
        self._dataset = dataset
        self._mapping = mapping

    def __len__(self):
        # The view is exactly as long as its index mapping.
        return self._mapping.shape[0]

    def __getitem__(self, item):
        original_index = self._mapping[item]
        return self._dataset[original_index]
|
|
|
|
|
|
class TrajData(object):
    """Builds (on demand) and loads the trajectory dataset for one map, and
    exposes disjoint train/val/test splits as index-mapped dataset views.
    """

    @property
    def name(self):
        # Class name doubles as an identifier for this data source.
        return self.__class__.__name__

    def __init__(self, data_root, mapname='tate_sw', trajectories=1000, alternatives=10,
                 train_val_test_split=(0.6, 0.2, 0.2), rebuild=False, equal_samples=True, **_):
        """
        :param data_root: directory holding (or receiving) the shelve dataset file
        :param mapname: base name of the map image and of the dataset file
        :param trajectories: number of trajectories to generate when building
        :param alternatives: number of alternatives per trajectory when building
        :param rebuild: if True, delete an existing dataset file and regenerate it
        :param train_val_test_split: fractions for the train/val/test partitions;
            assumed to sum to (at most) 1.0 -- TODO confirm callers guarantee this
        :param equal_samples: forwarded to the generator -- semantics defined there
        """
        self.rebuild = rebuild
        self.equal_samples = equal_samples
        self._alternatives = alternatives
        self._trajectories = trajectories
        self.mapname = mapname
        self.train_split, self.val_split, self.test_split = train_val_test_split
        self.data_root = Path(data_root)
        self._dataset = None
        self._dataset, self._train_map, self._val_map, self._test_map = self._load_dataset()

    def _build_data_on_demand(self):
        """Generate the dataset file if it is missing (or a rebuild was requested).

        Returns True so it can be used in an assert at the call site.
        """
        maps_root = Path() / 'res' / 'maps'
        # Check the maps directory before trying to read an image out of it,
        # so a missing resource tree fails with a clear assertion.
        assert maps_root.exists()
        map_object = Map(self.mapname).from_image(maps_root / f'{self.mapname}.bmp')
        # NOTE(review): shelve may store under a backend-dependent filename
        # (e.g. with an extra extension), so this existence check can be
        # unreliable on some platforms -- verify on the deployment dbm backend.
        dataset_file = Path(self.data_root) / f'{self.mapname}.pik'
        if dataset_file.exists() and self.rebuild:
            dataset_file.unlink()
        if not dataset_file.exists():
            generator = Generator(self.data_root, map_object)
            generator.generate_n_trajectories_m_alternatives(self._trajectories, self._alternatives,
                                                             self.mapname, equal_samples=self.equal_samples)
        return True

    def _split_indices(self, indices):
        """Partition a permutation of sample indices into disjoint
        train/val/test index maps according to the configured fractions.

        Any remainder left by integer truncation goes to the test split so
        no sample is dropped.
        """
        n = indices.shape[0]
        train_size = int(n * self.train_split)
        val_size = int(n * self.val_split)
        # BUGFIX: the slices must use cumulative offsets. The previous code
        # used indices[train_size:val_size] (empty whenever val < train) and
        # indices[test_size:], which overlapped the training partition.
        train_map = indices[:train_size]
        val_map = indices[train_size:train_size + val_size]
        test_map = indices[train_size + val_size:]
        return train_map, val_map, test_map

    def _load_dataset(self):
        """Open the shelve file, concatenate all per-trajectory datasets
        (skipping the 'map' entry) and shuffle indices into three splits.
        """
        assert self._build_data_on_demand()
        with shelve.open(str(self.data_root / f'{self.mapname}.pik')) as d:
            dataset = ConcatDataset([TrajDataset(d[key]) for key in d.keys() if key != 'map'])
        # Random permutation so the contiguous split slices are shuffled.
        indices = torch.randperm(len(dataset))
        train_map, val_map, test_map = self._split_indices(indices)
        return dataset, train_map, val_map, test_map

    @property
    def train_dataset(self):
        return DataSetMapping(self._dataset, self._train_map)

    @property
    def val_dataset(self):
        return DataSetMapping(self._dataset, self._val_map)

    @property
    def test_dataset(self):
        return DataSetMapping(self._dataset, self._test_map)

    def get_datasets(self):
        """Return the (train, val, test) dataset views as a tuple."""
        return self.train_dataset, self.val_dataset, self.test_dataset
|