128 lines
4.4 KiB
Python
128 lines
4.4 KiB
Python
import shelve
|
|
from pathlib import Path
|
|
from typing import Union, List
|
|
|
|
import torch
|
|
from random import choice
|
|
from torch.utils.data import ConcatDataset, Dataset
|
|
import numpy as np
|
|
|
|
from lib.objects.map import Map
|
|
import lib.variables as V
|
|
from PIL import Image
|
|
|
|
|
|
class TrajDataset(Dataset):
    """Randomly generated trajectory samples drawn on a single map image.

    ``__getitem__`` ignores the requested index: every access draws a fresh
    random trajectory from ``self.map``. ``length`` only controls what
    ``len()`` reports.

    Output per ``mode`` (matched case-insensitively):

    * ``'just_route'``: ``(map_tensor, trajectory_space), label``, where
      ``label`` is a random 0/1 placeholder.
    * ``'separated_arrays'``: ``(map_array, trajectory, label), alternative``.
      The arrays are divided by ``V.WHITE`` when ``normalized`` is true.
    * ``'all_in_map'``: ``stacked, label``. Map, trajectory and alternative
      are joined with ``np.concatenate`` along the first axis.
    * ``'vectors'``: ``(trajectory.vertices, alternative.vertices, label, mapname)``.
    """

    _MODES = ('vectors', 'all_in_map', 'separated_arrays', 'just_route')

    @property
    def map_shape(self):
        """Shape of ``self.map.as_array``; trajectories are drawn into arrays of this shape."""
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', normalized=True,
                 length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
                 **kwargs):
        """Load the map image and store the sampling configuration.

        Args:
            *args: Accepted and ignored.
            maps_root: Directory containing the map image. A ``str`` is
                converted to ``Path``.
            mapname: Map file name. ``'.bmp'`` is appended when missing.
            normalized: In ``'separated_arrays'`` mode, divide the returned
                arrays by ``V.WHITE``.
            length: Value reported by ``len()``.
            mode: One of ``_MODES``, case-insensitive.
            embedding_size: Forwarded to ``Map.from_image``.
            preserve_equal_samples: When true, a freshly drawn sample is
                rejected if its label equals the previous sample's label.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        super().__init__()
        if mode.lower() not in self._MODES:
            raise ValueError(f'unknown mode {mode!r}; expected one of {self._MODES}')
        self.normalized = normalized
        self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Coerce so that the `/` join below also works for str input (including the '' default).
        self.maps_root = Path(maps_root)
        self._len = length
        # Label of the previously returned sample; -1 means "none yet".
        self.last_label = -1

        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)

    def __len__(self):
        """Return the configured ``length``; it does not depend on the map."""
        return self._len

    def __getitem__(self, item):
        """Generate a new random sample. ``item`` is ignored.

        Raises:
            ValueError: If ``self.mode`` was changed to an unsupported value after construction.
        """
        # Normalize once so that every branch below agrees with the
        # case-insensitive validation done in __init__.
        mode = self.mode.lower()

        if mode == 'just_route':
            trajectory = self.map.get_random_trajectory()
            # NOTE(review): this branch draws with `self.map.shape`, while the others use
            # `self.map_shape` (= `self.map.as_array.shape`); confirm both agree.
            trajectory_space = trajectory.draw_in_array(self.map.shape)
            # Placeholder label: chosen at random, not derived from the trajectory.
            label = choice([0, 1])
            map_array = torch.as_tensor(self.map.as_array).float()
            return (map_array, trajectory_space), label

        while True:
            trajectory = self.map.get_random_trajectory()
            alternative = self.map.generate_alternative(trajectory)
            label = self.map.are_homotopic(trajectory, alternative)
            # With preserve_equal_samples, keep redrawing until the label differs from
            # the previous sample's, so two consecutive samples never share a label.
            if not (self.preserve_equal_samples and label == self.last_label):
                break
        self.last_label = label

        if mode in ('all_in_map', 'separated_arrays'):
            map_array = self.map.as_array
            trajectory = trajectory.draw_in_array(self.map_shape)
            alternative = alternative.draw_in_array(self.map_shape)
            if mode == 'separated_arrays':
                if self.normalized:
                    map_array = map_array / V.WHITE
                    trajectory = trajectory / V.WHITE
                    alternative = alternative / V.WHITE
                return (map_array, trajectory, label), alternative
            # NOTE(review): 'all_in_map' does not apply `self.normalized`; confirm this is intended.
            return np.concatenate((map_array, trajectory, alternative)), label

        if mode == 'vectors':
            return trajectory.vertices, alternative.vertices, label, self.mapname

        raise ValueError(f'unsupported mode {self.mode!r}')
|
|
|
|
|
|
class TrajData(object):
    """Build one ``TrajDataset`` per ``*.bmp`` map found in a directory.

    The per-map datasets are wrapped in a single ``ConcatDataset``. That same
    object is returned as the train, validation and test dataset.
    """

    @property
    def map_shapes(self):
        """List of ``map_shape`` for every per-map dataset, in concatenation order."""
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        """Element-wise maximum of ``map_shapes``, returned as a list.

        In ``'separated_arrays'`` and ``'all_in_map'`` modes the first entry is
        increased by 2.
        """
        shape_list = [max(axis) for axis in zip(*self.map_shapes)]
        if self.mode in ['separated_arrays', 'all_in_map']:
            # NOTE(review): presumably reserves two extra leading channels for the
            # trajectory and alternative arrays; confirm for 'separated_arrays',
            # which returns them unstacked.
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        """Class name of this data container."""
        return self.__class__.__name__

    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_):
        """Store the configuration and build the concatenated dataset.

        Args:
            map_root: Directory scanned for ``*.bmp`` map images.
            length: Total number of samples, split evenly across the maps.
            mode: Forwarded to every ``TrajDataset``.
            normalized: Forwarded to every ``TrajDataset``.
            **_: Accepted and ignored.

        Raises:
            FileNotFoundError: If ``map_root`` contains no ``*.bmp`` files.
        """
        self.normalized = normalized
        self.mode = mode
        self.maps_root = Path(map_root)
        self.length = length
        self._dataset = self._load_datasets()

    @staticmethod
    def _max_map_size(sizes):
        """Return ``(1, max_height, max_width)`` for PIL-style ``(width, height)`` sizes.

        Also works for a single size, unlike ``map(max, *sizes)``.
        """
        return (1, ) + tuple(reversed([max(axis) for axis in zip(*sizes)]))

    def _load_datasets(self):
        """Create a ``ConcatDataset`` with one ``TrajDataset`` per map image.

        ``self.length`` is divided evenly across the maps, with at least one
        sample per map. Every map gets the largest size found as its
        ``embedding_size``.

        Raises:
            FileNotFoundError: If ``self.maps_root`` contains no ``*.bmp`` files.
        """
        # Sorted so that the concatenation order (index -> map) is reproducible.
        map_files = sorted(self.maps_root.glob('*.bmp'))
        if not map_files:
            raise FileNotFoundError(f'no *.bmp map files found in {self.maps_root}')
        equal_split = int(self.length // len(map_files)) or 1

        # Find the max image size among the available maps, closing each file after reading its size.
        sizes = []
        for map_file in map_files:
            with Image.open(map_file) as image:
                sizes.append(image.size)
        max_map_size = self._max_map_size(sizes)

        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                          mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                          preserve_equal_samples=True)
                              for map_file in map_files])

    @property
    def train_dataset(self):
        """The shared concatenated dataset."""
        return self._dataset

    @property
    def val_dataset(self):
        """The shared concatenated dataset (same object as ``train_dataset``)."""
        return self._dataset

    @property
    def test_dataset(self):
        """The shared concatenated dataset (same object as ``train_dataset``)."""
        return self._dataset

    def get_datasets(self):
        """Return ``(train, val, test)``; all three are the same dataset object."""
        return self._dataset, self._dataset, self._dataset
|