# hom_traj_gen/datasets/trajectory_dataset.py

import multiprocessing as mp
import shelve
from collections import defaultdict
from pathlib import Path
from random import choice
from typing import Union

import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset, Dataset
from torchvision.transforms import Normalize
from tqdm import tqdm

import lib.variables as V
from lib.objects.map import Map
from lib.utils.tools import write_to_shelve


class TrajDataShelve(Dataset):
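    """Read-only Dataset over a shelve file of pre-generated samples.

    Samples are looked up by their stringified index; the file is assumed to
    have been written by TrajData.dump_n via write_to_shelve.
    """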

    @property
    def map_shape(self):
        return self[0][0].shape

    def __init__(self, file_path, **kwargs):
        assert Path(file_path).exists()
        super(TrajDataShelve, self).__init__()
        self._mutex = mp.Lock()
        self.file_path = str(file_path)

    def __len__(self):
        # The lock serializes shelve access; the context manager already
        # closes the file, so no explicit close() is needed.
        with self._mutex:
            with shelve.open(self.file_path) as d:
                return len(d)

    def seed(self, seed=None):
        # Samples are pre-generated, so seeding is a no-op; the optional
        # argument keeps the signature compatible with TrajDataset.seed.
        pass

    def __getitem__(self, item):
        with self._mutex:
            with shelve.open(self.file_path) as d:
                return d[str(item)]


class TrajDataset(Dataset):
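    """On-the-fly dataset of trajectory/alternative pairs for a single map.

    The sample layout is controlled by `mode`; labels indicate whether a
    trajectory and its generated alternative are homotopic.
    """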

    @property
    def _last_label_init(self):
        # Resolve the initial value of `last_label` for the current mode;
        # the default of -1 is assumed to correspond to V.ANY.
        d = defaultdict(lambda: -1)
        d['generator_hom_all_in_map'] = V.ALTERNATIVE
        d['generator_alt_all_in_map'] = V.HOMOTOPIC
        return d[self.mode]

    @property
    def map_shape(self):
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', normalized=True,
                 length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
                 **kwargs):
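        """
        :param maps_root: Directory that contains the map image.
        :param mapname: Map file name; a '.bmp' suffix is appended if missing.
        :param normalized: If True, keep a Normalize(0.5, 0.5) transform in
            self.normalize, otherwise the identity.
        :param length: Number of samples reported by __len__.
        :param mode: Sample layout; must be one of the modes asserted below.
        :param embedding_size: Forwarded to Map.from_image.
        :param preserve_equal_samples: If True, re-sample until the label
            alternates, so that both classes occur equally often.
        """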
        super(TrajDataset, self).__init__()
        # 'generator_hom_no_label_in_map' is included here because __getitem__
        # handles it alongside 'generator_alt_no_label_in_map'.
        assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map', 'generator_alt_all_in_map',
                                'ae_no_label_in_map', 'generator_alt_no_label_in_map',
                                'generator_hom_no_label_in_map', 'classifier_all_in_map', 'vae_no_label_in_map']
        self.normalize = Normalize(0.5, 0.5) if normalized else lambda x: x
        self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Coerce to Path so joining with the map name works for str arguments.
        self.maps_root = Path(maps_root)
        self._len = length
        self.last_label = self._last_label_init
        self.map = Map.from_image(self.maps_root / self.mapname, embedding_size=embedding_size)

    def __len__(self):
        return self._len

    def __getitem__(self, item):
        if self.mode.lower() == 'just_route':
            # Not implemented yet; intended to return a random trajectory
            # drawn into the map together with a random label.
            raise NotImplementedError

        # Produce an alternative to a random trajectory. With
        # preserve_equal_samples, re-sample until the label differs from the
        # previous one so that both classes occur equally often.
        while True:
            trajectory = self.map.get_random_trajectory()
            alternative = self.map.generate_alternative(trajectory)
            label = self.map.are_homotopic(trajectory, alternative)
            if self.preserve_equal_samples and label == self.last_label:
                continue
            else:
                break

        # _last_label_init is already resolved for the current mode.
        self.last_label = label if self._last_label_init == V.ANY else self._last_label_init

        if 'in_map' in self.mode.lower():
            map_array = self.map.as_array
            trajectory = trajectory.draw_in_array(self.map_shape)
            alternative = alternative.draw_in_array(self.map_shape)
            label_as_array = np.full_like(map_array, label)
            if self.mode == 'generator_all_in_map':
                return np.concatenate((map_array, trajectory, label_as_array)), alternative
            elif self.mode in ['vae_no_label_in_map', 'ae_no_label_in_map']:
                return np.sum((map_array, trajectory, alternative), axis=0), 0
            elif self.mode in ['generator_alt_no_label_in_map', 'generator_hom_no_label_in_map']:
                return np.concatenate((map_array, trajectory)), alternative
            elif self.mode == 'classifier_all_in_map':
                return np.concatenate((map_array, trajectory, alternative)), label
        elif self.mode == '_vectors':
            # Not implemented yet; intended to return the raw vertex lists
            # (trajectory.vertices, alternative.vertices, label, self.mapname).
            raise NotImplementedError
        raise ValueError(f'Mode was: {self.mode}')

    def seed(self, seed):
        self.map.seed(seed)


class TrajData(object):
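    """Bundles train/val/test ConcatDatasets over all *.bmp maps in map_root.

    With preprocessed=True, samples are generated once, dumped into shelve
    (.pik) files and afterwards served from disk via TrajDataShelve.
    """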

    @property
    def map_shapes(self):
        return [dataset.map_shape for dataset in self.train_dataset.datasets]

    @property
    def map_shapes_max(self):
        shapes = self.map_shapes
        shape_list = list(map(max, zip(*shapes)))
        if '_all_in_map' in self.mode and not self.preprocessed:
            # Raw '_all_in_map' samples stack two extra planes (trajectory
            # plus label or alternative) onto the single map channel.
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, preprocessed=False, **_):
        self.preprocessed = preprocessed
        self.normalized = normalized
        self.mode = mode
        self.maps_root = Path(map_root)
        self.length = length
        self.test_dataset = self._load_datasets('test')
        self.val_dataset = self._load_datasets('val')
        self.train_dataset = self._load_datasets('train')

    def _load_datasets(self, dataset_type=''):
        map_files = list(self.maps_root.glob('*.bmp'))
        # Find the maximum extent among the available maps; PIL sizes are
        # (width, height), which is reversed to (height, width) and prefixed
        # with a single channel dimension.
        sizes = [Image.open(map_file).size for map_file in map_files]
        max_map_size = (1,) + tuple(reversed(tuple(max(dim) for dim in zip(*sizes))))
        # Split the requested length evenly so that every map is visited with
        # the same frequency.
        equal_split = int(self.length // len(map_files)) or 5
        if self.preprocessed:
            preprocessed_map_files = list(self.maps_root.glob('*.pik'))
            preprocessed_map_names = [p.name for p in preprocessed_map_files]
            datasets = []
            if dataset_type != 'train':
                # Validation and test splits are far smaller than the
                # training split.
                equal_split = max(int(equal_split * 0.01), 10)
            for map_file in map_files:
                new_pik_name = f'{self.mode}_{map_file.name[:-4]}_{dataset_type}.pik'
                if new_pik_name not in preprocessed_map_names:
                    traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                               mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                               preserve_equal_samples=True)
                    self.dump_n(map_file.parent / new_pik_name, traj_dataset, n=equal_split)
                dataset = TrajDataShelve(map_file.parent / new_pik_name)
                datasets.append(dataset)
            return ConcatDataset(datasets)
        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                          mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                          preserve_equal_samples=True)
                              for map_file in map_files])

    def kill_em_all(self):
        # Delete all cached shelve (.pik) files so they get regenerated.
        for pik_file in self.maps_root.glob('*.pik'):
            pik_file.unlink()
            print(pik_file.name, 'was deleted.')
        print('Done.')

    def seed(self, seed):
        # Seed every per-map dataset in each of the three splits.
        for concat_dataset in [self.train_dataset, self.test_dataset, self.val_dataset]:
            for dataset in concat_dataset.datasets:
                dataset.seed(seed)

    def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
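        """Pre-generate n samples from traj_dataset into a shelve (.pik) file.

        write_to_shelve is assumed to store each sample under its running
        index, matching the str(item) lookup in TrajDataShelve.__getitem__.
        """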
        assert str(file_path).endswith('.pik')
        # Note: this lock is local to the method and only guards the shelve
        # writes within this single process.
        mutex = mp.Lock()
        for i in tqdm(range(n), total=n, desc=f'Generating {n} Samples'):
            sample = traj_dataset[i]
            with mutex:
                write_to_shelve(file_path, sample)
        print(f'{n} samples successfully dumped to "{file_path}"!')

    def get_datasets(self):
        return self.train_dataset, self.val_dataset, self.test_dataset
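

if __name__ == '__main__':
    # Minimal usage sketch, not part of the module itself. It assumes a
    # directory 'data/maps' containing at least one *.bmp map file; adjust
    # path, mode and length to your setup.
    traj_data = TrajData('data/maps', length=100, mode='classifier_all_in_map')
    train, val, test = traj_data.get_datasets()
    x, y = train[0]
    print(f'train: {len(train)} samples, sample shape: {np.asarray(x).shape}, label: {y}')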