92 lines
2.9 KiB
Python
92 lines
2.9 KiB
Python
import shelve
|
|
from pathlib import Path
|
|
from typing import Union
|
|
|
|
import torch
|
|
from random import choice
|
|
from torch.utils.data import ConcatDataset, Dataset
|
|
|
|
from lib.objects.map import Map
|
|
from lib.preprocessing.generator import Generator
|
|
|
|
|
|
class TrajDataset(Dataset):
    """Dataset that samples random trajectory pairs from a single map image.

    Samples are drawn on the fly from ``self.map``. The index passed to
    ``__getitem__`` is ignored, so every access yields a fresh random sample.
    ``__len__`` only fixes how many samples make up one epoch.
    """

    @property
    def map_shape(self):
        """Shape of the underlying map array (``self.map.as_array.shape``)."""
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
                 length=100.000, all_in_map=True, **kwargs):
        """Load the map image and store the sampling configuration.

        Args:
            *args: Accepted for call-site compatibility; unused.
            maps_root: Directory containing the map images. A ``str`` is
                accepted and converted to :class:`pathlib.Path`.
            mapname: Map file name; ``'.bmp'`` is appended when missing.
            length: Number of samples reported by ``__len__``. Coerced to
                ``int`` because ``len()`` and ``ConcatDataset`` need an int.
            all_in_map: If True, items are the map stacked with trajectory
                masks; otherwise raw vertex data is returned.
            **kwargs: Accepted for call-site compatibility; unused.
        """
        super().__init__()
        self.all_in_map = all_in_map
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Wrap in Path so a str root (the default is '') supports the '/' join below.
        self.maps_root = Path(maps_root)
        # NOTE(review): the default ``100.000`` is the float 100.0, not 100 000;
        # int() keeps len() working either way -- confirm which count was intended.
        self._len = int(length)

        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname)

    def __len__(self):
        """Return the configured number of samples per epoch."""
        return self._len

    def __getitem__(self, item):
        """Sample a random trajectory and an alternative generated from it.

        Args:
            item: Ignored; each call draws a new random sample.

        Returns:
            If ``all_in_map``: ``(stacked, label)``. ``stacked`` is
            ``torch.cat`` (dim 0) of the map tensor, a map-sized mask with 1 at
            ``trajectory.vertices``, and a map-sized mask with 1 at
            ``alternative.vertices``. ``label`` is
            ``self.map.are_homotopic(trajectory, alternative)``.
            Otherwise: ``(trajectory.vertices, alternative.vertices, label,
            self.mapname)``.
        """
        trajectory = self.map.get_random_trajectory()
        alternative = self.map.generate_alternative(trajectory)
        # Drawn in both branches (as before) so the global RNG stream is unchanged.
        # NOTE(review): in the non-all_in_map branch this coin flip is returned
        # as the label and is not derived from the pair -- confirm that is intended.
        label = choice([0, 1])

        if not self.all_in_map:
            return trajectory.vertices, alternative.vertices, label, self.mapname

        # as_tensor: zeros_like needs a tensor, and as_array may be e.g. a numpy
        # array; as_tensor returns an existing tensor unchanged.
        map_array = torch.as_tensor(self.map.as_array)

        trajectory_space = torch.zeros_like(map_array)
        trajectory_space[trajectory.vertices] = 1

        alternative_space = torch.zeros_like(map_array)
        # Mark the alternative's own vertices (not the trajectory's).
        alternative_space[alternative.vertices] = 1

        label = self.map.are_homotopic(trajectory, alternative)
        # NOTE(review): cat along dim 0 presumably expects a leading channel
        # axis on the map array -- confirm against Map.as_array.
        return torch.cat((map_array, trajectory_space, alternative_space)), label
|
|
|
|
|
|
class TrajData(object):
    """Collection of ``TrajDataset``s, one per ``*.bmp`` map in a directory.

    The requested total ``length`` is split evenly (floor division) across
    the maps. The per-map datasets are joined in a ``ConcatDataset``.
    """

    @property
    def map_shapes(self):
        """List of ``map_shape`` of every per-map dataset, in load order."""
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        """Element-wise maximum over all map shapes, as a tuple.

        A tuple (rather than a one-shot ``map`` iterator) can be iterated,
        indexed and unpacked any number of times.
        """
        return tuple(map(max, zip(*self.map_shapes)))

    @property
    def name(self):
        """Name of this class (``TrajData`` unless subclassed)."""
        return self.__class__.__name__

    def __init__(self, *args, map_root: Union[Path, str] = '', length=100.000, all_in_map=True, **_):
        """Discover the map images and build the concatenated dataset.

        Args:
            *args: Accepted for call-site compatibility; unused.
            map_root: Directory to scan for ``*.bmp`` maps. When empty, falls
                back to ``res/maps`` relative to the working directory.
            length: Total number of samples, split evenly across the maps.
            all_in_map: Forwarded to every ``TrajDataset``.
            **_: Ignored.

        Raises:
            FileNotFoundError: If the maps directory has no ``*.bmp`` files.
        """
        self.all_in_map = all_in_map
        self.maps_root = Path(map_root) if map_root else Path() / 'res' / 'maps'
        # Must be assigned before _load_datasets(), which reads self.length.
        self.length = length
        self._dataset = self._load_datasets()

    def _load_datasets(self):
        """Build one ``TrajDataset`` per map file and concatenate them.

        Raises:
            FileNotFoundError: If ``self.maps_root`` contains no ``*.bmp`` files
                (instead of a bare ZeroDivisionError from the split below).
        """
        # sorted(): glob order is filesystem-dependent; sorting keeps the
        # dataset order reproducible across machines.
        map_files = sorted(self.maps_root.glob('*.bmp'))
        if not map_files:
            raise FileNotFoundError(f'No *.bmp map files found in {self.maps_root}')
        # int(): length may be a float (default 100.000); per-dataset lengths
        # must be ints for len() / ConcatDataset.
        equal_split = int(self.length) // len(map_files)
        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_image.name, length=equal_split,
                                          all_in_map=self.all_in_map) for map_image in map_files])

    # NOTE(review): train, val and test all return the same dataset object --
    # there is no split; confirm this is intended.
    @property
    def train_dataset(self):
        """Training dataset (currently the full concatenated dataset)."""
        return self._dataset

    @property
    def val_dataset(self):
        """Validation dataset (currently the full concatenated dataset)."""
        return self._dataset

    @property
    def test_dataset(self):
        """Test dataset (currently the full concatenated dataset)."""
        return self._dataset

    def get_datasets(self):
        """Return ``(train, val, test)``; currently the same object three times."""
        return self._dataset, self._dataset, self._dataset
|