Files
hom_traj_gen/datasets/trajectory_dataset.py
2020-03-08 23:46:02 +01:00

128 lines
4.4 KiB
Python

import shelve
from pathlib import Path
from typing import Union, List
import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
from lib.objects.map import Map
from PIL import Image
class TrajDataset(Dataset):
    """Randomly sampled trajectory pairs over a single bitmap map.

    Every ``__getitem__`` call draws a fresh random trajectory from
    ``self.map``; the ``item`` index is ignored, so ``length`` only controls
    how many samples one pass over the dataset yields.

    Supported ``mode`` values (matched case-insensitively):

    * ``'just_route'``       -> ``((map, trajectory_mask), label)``; ``label`` is
      drawn with ``random.choice([0, 1])`` and is NOT derived from the data.
    * ``'separated_arrays'`` -> ``((map, trajectory_mask, label), alternative_mask)``.
    * ``'all_in_map'``       -> ``(torch.cat((map, trajectory_mask, alternative_mask)), label)``,
      concatenated along dim 0.
    * ``'vectors'``          -> ``(trajectory.vertices, alternative.vertices, label, mapname)``.

    In the mask modes ``label`` is ``int(self.map.are_homotopic(...))``; in
    ``'vectors'`` mode it is the raw ``are_homotopic`` result.
    """

    _MODES = ('vectors', 'all_in_map', 'separated_arrays', 'just_route')

    @property
    def map_shape(self):
        """Shape of the underlying map array (``self.map.as_array.shape``)."""
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw',
                 length=100000, mode='separated_arrays', embedding_size=None,
                 preserve_equal_samples=False, **kwargs):
        """Load ``mapname`` from ``maps_root`` and configure sampling.

        Args:
            maps_root: Directory containing the map bitmaps (str or Path).
            mapname: Map file name; ``.bmp`` is appended if missing.
            length: Value reported by ``len()``.
            mode: One of ``_MODES`` (case-insensitive); selects the sample layout.
            embedding_size: Forwarded to ``Map.from_image``.
            preserve_equal_samples: If true, resample until the label differs
                from the previously returned one (alternating labels).
            *args, **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``mode`` is not a supported mode.
        """
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if mode.lower() not in self._MODES:
            raise ValueError(f'Unknown mode {mode!r}; expected one of {self._MODES}')
        self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Normalize to Path: a plain str (as the annotation allows) would fail
        # on the ``/`` join below.
        self.maps_root = Path(maps_root)
        self._len = length
        # Label of the previously returned sample; -1 means "none yet".
        self.last_label = -1
        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname,
                                                embedding_size=embedding_size)

    def __len__(self):
        return self._len

    def _rasterize(self, vertices):
        """Return a zero tensor of ``self.map.shape`` with 1 at every vertex index."""
        canvas = torch.zeros(self.map.shape)
        for index in vertices:
            canvas[index] = 1
        return canvas

    def __getitem__(self, item):
        """Draw one random sample laid out according to ``self.mode``.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        # Normalize once so every branch agrees with the case-insensitive
        # validation done in ``__init__``.
        mode = self.mode.lower()

        if mode == 'just_route':
            trajectory = self.map.get_random_trajectory()
            label = choice([0, 1])
            blank_trajectory_space = self._rasterize(trajectory.vertices)
            map_array = torch.as_tensor(self.map.as_array).float()
            return (map_array, blank_trajectory_space), label

        # NOTE(review): with preserve_equal_samples set, this loop never ends if
        # are_homotopic keeps returning the previous label.
        while True:
            trajectory = self.map.get_random_trajectory()
            alternative = self.map.generate_alternative(trajectory)
            label = self.map.are_homotopic(trajectory, alternative)
            if not (self.preserve_equal_samples and label == self.last_label):
                break
        self.last_label = label

        if mode in ('all_in_map', 'separated_arrays'):
            blank_trajectory_space = self._rasterize(trajectory.vertices)
            blank_alternative_space = self._rasterize(alternative.vertices)
            map_array = torch.as_tensor(self.map.as_array).float()
            if mode == 'separated_arrays':
                return (map_array, blank_trajectory_space, int(label)), blank_alternative_space
            return torch.cat((map_array, blank_trajectory_space, blank_alternative_space)), int(label)
        if mode == 'vectors':
            return trajectory.vertices, alternative.vertices, label, self.mapname
        raise ValueError(f'Unknown mode {self.mode!r}')
class TrajData(object):
    """One ``TrajDataset`` per ``*.bmp`` map in ``map_root``, concatenated.

    The resulting ``ConcatDataset`` is served unchanged as the train,
    validation and test split.
    """

    @property
    def map_shapes(self):
        """Shape of each per-map array, in dataset order."""
        return [dataset.map_shape for dataset in self._dataset.datasets]

    @property
    def map_shapes_max(self):
        """Element-wise maximum over :attr:`map_shapes`, as a list.

        In ``'all_in_map'`` mode (any case, matching ``TrajDataset``) the first
        dimension is increased by 2, because that mode concatenates the
        trajectory and alternative masks onto the map along dim 0.
        """
        shape_list = [max(dims) for dims in zip(*self.map_shapes)]
        if self.mode.lower() == 'all_in_map':
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, map_root, length=100000, mode='separated_arrays', **_):
        """Build the concatenated dataset from every ``*.bmp`` in ``map_root``.

        Args:
            map_root: Directory to scan for ``*.bmp`` map files.
            length: Total sample budget, split evenly across the maps.
            mode: Sample layout forwarded to each ``TrajDataset``.

        Raises:
            FileNotFoundError: If ``map_root`` contains no ``*.bmp`` files.
        """
        self.mode = mode
        self.maps_root = Path(map_root)
        self.length = length
        self._dataset = self._load_datasets()

    @staticmethod
    def _embedding_size(image_sizes):
        """Return ``(1, max_height, max_width)`` for PIL ``(width, height)`` sizes."""
        max_width, max_height = (max(dim) for dim in zip(*image_sizes))
        return 1, max_height, max_width

    def _load_datasets(self):
        """Create one ``TrajDataset`` per map and wrap them in a ``ConcatDataset``.

        Every map is embedded into the largest map size found, so all samples
        share one shape.

        Raises:
            FileNotFoundError: If no ``*.bmp`` files are found.
        """
        # Sorted so dataset order (and ConcatDataset indexing) does not depend
        # on the filesystem's directory-listing order.
        map_files = sorted(self.maps_root.glob('*.bmp'))
        if not map_files:
            raise FileNotFoundError(f'No *.bmp map files found in {self.maps_root}')
        # Integer split: any remainder of length / len(map_files) is dropped.
        equal_split = int(self.length // len(map_files))
        sizes = []
        for map_file in map_files:
            # Context manager closes the file handle Image.open keeps open.
            with Image.open(map_file) as image:
                sizes.append(image.size)
        max_map_size = self._embedding_size(sizes)
        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name,
                                          length=equal_split, mode=self.mode,
                                          embedding_size=max_map_size,
                                          preserve_equal_samples=True)
                              for map_file in map_files])

    @property
    def train_dataset(self):
        return self._dataset

    @property
    def val_dataset(self):
        return self._dataset

    @property
    def test_dataset(self):
        return self._dataset

    def get_datasets(self):
        """Return ``(train, val, test)``, which are all the same dataset object."""
        return self._dataset, self._dataset, self._dataset