217 lines
8.2 KiB
Python
217 lines
8.2 KiB
Python
import shelve
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from typing import Union
|
|
|
|
from torchvision.transforms import Normalize
|
|
|
|
import multiprocessing as mp
|
|
|
|
import torch
|
|
from random import choice
|
|
from torch.utils.data import ConcatDataset, Dataset
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
from lib.objects.map import Map
|
|
import lib.variables as V
|
|
from PIL import Image
|
|
|
|
from lib.utils.tools import write_to_shelve
|
|
|
|
|
|
class TrajDataShelve(Dataset):
    """Dataset view over samples previously written to a ``shelve`` file.

    Sample ``i`` is looked up under the key ``str(i)``. Every access opens the shelve, reads and
    closes it again while holding a ``multiprocessing.Lock``, so concurrent readers are serialised.
    """

    @property
    def map_shape(self):
        """Shape of the first element (the model input) of sample 0."""
        return self[0][0].shape

    def __init__(self, file_path, **kwargs):
        """Wrap the existing shelve file at ``file_path``; extra keyword arguments are ignored."""
        assert Path(file_path).exists()
        super(TrajDataShelve, self).__init__()
        # A multiprocessing lock (not a threading one) so it can be shared with worker processes.
        self._mutex = mp.Lock()
        self.file_path = str(file_path)

    def __len__(self):
        """Number of entries stored in the shelve."""
        # Lock first, then open; on exit the shelve is closed before the lock is released,
        # and both are released even if opening/reading raises.
        with self._mutex, shelve.open(self.file_path) as db:
            return len(db)

    def seed(self, seed=None):
        """No-op: samples are pre-generated, so there is no random state to seed.

        Accepts (and ignores) a ``seed`` argument so it can be called like ``TrajDataset.seed``.
        """

    def __getitem__(self, item):
        """Return the sample stored under ``str(item)``.

        Negative ``int`` indices count from the end, assuming keys are ``'0'..str(len - 1)``
        (the layout ``__len__`` and ``ConcatDataset`` indexing rely on).

        Raises:
            KeyError: if no sample is stored under the resolved key.
        """
        # The context managers guarantee the lock is released even when the lookup raises;
        # previously a missing key left the lock held and deadlocked every later access.
        with self._mutex, shelve.open(self.file_path) as db:
            if isinstance(item, int) and item < 0:
                item += len(db)
            return db[str(item)]
|
|
|
|
|
|
class TrajDataset(Dataset):
    """Generate trajectory samples on the fly from a single map image.

    Every ``__getitem__`` call draws a new random trajectory on ``self.map`` plus a generated
    alternative; the requested index is ignored. The sample layout depends on ``mode``.
    """

    # Modes accepted by the constructor (compared case-insensitively).
    _MODES = ('generator_all_in_map', 'generator_hom_all_in_map', 'generator_alt_all_in_map',
              'generator_alt_no_label_in_map', 'generator_hom_no_label_in_map',
              'ae_no_label_in_map', 'vae_no_label_in_map', 'classifier_all_in_map')

    @property
    def _last_label_init(self):
        """Initial (and, for fixed-label modes, permanent) value of ``last_label``.

        ``generator_hom_all_in_map`` maps to ``V.ALTERNATIVE`` and ``generator_alt_all_in_map`` to
        ``V.HOMOTOPIC``; with ``preserve_equal_samples`` every sample carrying that label is then
        redrawn. All other modes map to ``-1``, which ``__getitem__`` compares against ``V.ANY``
        (presumably ``-1`` — TODO confirm) to let ``last_label`` follow the previous sample.
        """
        fixed_labels = {'generator_hom_all_in_map': V.ALTERNATIVE,
                        'generator_alt_all_in_map': V.HOMOTOPIC}
        return fixed_labels.get(self.mode.lower(), -1)

    @property
    def map_shape(self):
        """Shape of the map's array representation."""
        return self.map.as_array.shape

    def __init__(self, *args, maps_root: Union[Path, str] = '', mapname='tate_sw', normalized=True,
                 length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
                 **kwargs):
        """Load the map and configure sampling.

        Args:
            maps_root: Directory holding the map images; a ``str`` is converted to ``Path``.
            mapname: Map file name; ``.bmp`` is appended when missing.
            normalized: Selects ``Normalize(0.5, 0.5)`` or the identity for ``self.normalize``
                (NOTE(review): ``self.normalize`` is not applied in ``__getitem__``).
            length: Value reported by ``__len__``.
            mode: One of ``_MODES``. NOTE(review): the default ``'separated_arrays'`` is rejected.
            embedding_size: Forwarded to ``Map.from_image``.
            preserve_equal_samples: If true, redraw while the label equals ``last_label``.
            *args, **kwargs: Ignored.
        """
        super(TrajDataset, self).__init__()
        assert mode.lower() in self._MODES, f'Unsupported mode: {mode}'
        self.normalize = Normalize(0.5, 0.5) if normalized else lambda x: x
        self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        # Path() so that the str form allowed by the annotation supports the `/` join below.
        self.maps_root = Path(maps_root)
        self._len = length
        self.last_label = self._last_label_init

        self.map = Map.from_image(self.maps_root / self.mapname, embedding_size=embedding_size)

    def __len__(self):
        return self._len

    def __getitem__(self, item):
        """Draw a fresh random sample; ``item`` is ignored.

        Returns, depending on ``mode`` (arrays joined with ``np.concatenate`` along axis 0):
            * ``generator_all_in_map``: ``(map ++ trajectory ++ label-filled map, alternative)``
            * ``ae_no_label_in_map`` / ``vae_no_label_in_map``: ``(map + trajectory + alternative, 0)``
            * ``generator_alt_no_label_in_map`` / ``generator_hom_no_label_in_map``:
              ``(map ++ trajectory, alternative)``
            * ``classifier_all_in_map``: ``(map ++ trajectory ++ alternative, label)``

        Raises:
            NotImplementedError: for ``'just_route'`` and ``'_vectors'``.
            ValueError: for every other mode, including ``generator_hom/alt_all_in_map``.
        """
        mode = self.mode.lower()
        if mode == 'just_route':
            raise NotImplementedError

        # Produce an alternative; with preserve_equal_samples, redraw until the label differs
        # from last_label.
        while True:
            trajectory = self.map.get_random_trajectory()
            alternative = self.map.generate_alternative(trajectory)
            label = self.map.are_homotopic(trajectory, alternative)
            if not (self.preserve_equal_samples and label == self.last_label):
                break

        # _last_label_init already is the label value (not a mapping): fixed-label modes keep
        # that constant, all other modes remember this sample's label.
        last_label_init = self._last_label_init
        self.last_label = label if last_label_init == V.ANY else last_label_init

        if 'in_map' in mode:
            map_array = self.map.as_array
            trajectory = trajectory.draw_in_array(self.map_shape)
            alternative = alternative.draw_in_array(self.map_shape)

            if mode == 'generator_all_in_map':
                label_as_array = np.full_like(map_array, label)
                return np.concatenate((map_array, trajectory, label_as_array)), alternative
            if mode in ('vae_no_label_in_map', 'ae_no_label_in_map'):
                return np.sum((map_array, trajectory, alternative), axis=0), 0
            if mode in ('generator_alt_no_label_in_map', 'generator_hom_no_label_in_map'):
                return np.concatenate((map_array, trajectory)), alternative
            if mode == 'classifier_all_in_map':
                return np.concatenate((map_array, trajectory, alternative)), label
        elif mode == '_vectors':
            raise NotImplementedError

        raise ValueError(f'Mode was: {self.mode}')

    def seed(self, seed):
        """Seed the map's random trajectory generation."""
        self.map.seed(seed)
|
|
|
|
|
|
class TrajData(object):
    """Build train/val/test ``ConcatDataset``s with one sub-dataset per ``*.bmp`` map in ``map_root``.

    With ``preprocessed=True`` samples are generated once, dumped to
    ``<mode>_<map stem>_<split>.pik`` shelve files next to the maps and read back through
    ``TrajDataShelve``; otherwise ``TrajDataset`` generates samples on the fly.
    """

    @property
    def map_shapes(self):
        """``map_shape`` of every sub-dataset of the training split."""
        return [dataset.map_shape for dataset in self.train_dataset.datasets]

    @property
    def map_shapes_max(self):
        """Element-wise maximum over ``map_shapes``, as a list.

        For ``*_all_in_map`` modes without preprocessing, the first entry is increased by 2.
        Generated datasets report the bare map shape, while shelve datasets report the stored
        input shape. The +2 presumably accounts for the two extra arrays concatenated onto the
        map in ``TrajDataset.__getitem__`` — TODO confirm.
        """
        shape_list = [max(dims) for dims in zip(*self.map_shapes)]
        if '_all_in_map' in self.mode and not self.preprocessed:
            shape_list[0] += 2
        return shape_list

    @property
    def name(self):
        return self.__class__.__name__

    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, preprocessed=False, **_):
        """Load all three splits.

        Args:
            map_root: Directory containing the ``*.bmp`` maps (and any ``*.pik`` dumps).
            length: Total number of samples, spread evenly across maps.
            mode: Sample layout forwarded to ``TrajDataset``.
            normalized: Forwarded to ``TrajDataset``.
            preprocessed: Dump samples to shelve files and read them back instead of
                generating them on the fly.
        """
        self.preprocessed = preprocessed
        self.normalized = normalized
        self.mode = mode
        self.maps_root = Path(map_root)
        self.length = length
        self.test_dataset = self._load_datasets('test')
        self.val_dataset = self._load_datasets('val')
        self.train_dataset = self._load_datasets('train')

    def _load_datasets(self, dataset_type=''):
        """Build the ``ConcatDataset`` for one split (``'train'``, ``'val'`` or ``'test'``).

        Raises:
            FileNotFoundError: if ``maps_root`` contains no ``*.bmp`` files.
        """
        map_files = list(self.maps_root.glob('*.bmp'))
        if not map_files:
            raise FileNotFoundError(f'No "*.bmp" map files found in "{self.maps_root}".')

        # Find the max image size among available maps. PIL reports (width, height), so it is
        # reversed into (1, height, width). zip() also handles a single map, which
        # map(max, *sizes) did not. `with` closes each image file.
        sizes = []
        for map_file in map_files:
            with Image.open(map_file) as image:
                sizes.append(image.size)
        max_map_size = (1, ) + tuple(reversed([max(dims) for dims in zip(*sizes)]))

        # Set the equal split so that all maps are visited with the same frequency.
        equal_split = int(self.length // len(map_files)) or 5

        if not self.preprocessed:
            return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                              mode=self.mode, embedding_size=max_map_size,
                                              normalized=self.normalized, preserve_equal_samples=True)
                                  for map_file in map_files])

        # Preprocessed val/test splits are much smaller than the training split.
        if dataset_type != 'train':
            equal_split = max(int(equal_split * 0.01), 10)
        preprocessed_map_names = {p.name for p in self.maps_root.glob('*.pik')}

        datasets = []
        for map_file in map_files:
            pik_path = map_file.parent / f'{self.mode}_{map_file.stem}_{dataset_type}.pik'
            # Generate and dump samples only once; later runs reuse the existing file.
            if pik_path.name not in preprocessed_map_names:
                traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                           mode=self.mode, embedding_size=max_map_size,
                                           normalized=self.normalized, preserve_equal_samples=True)
                self.dump_n(pik_path, traj_dataset, n=equal_split)
            datasets.append(TrajDataShelve(pik_path))
        return ConcatDataset(datasets)

    def kill_em_all(self):
        """Delete every ``*.pik`` file in ``maps_root``."""
        for pik_file in self.maps_root.glob('*.pik'):
            pik_file.unlink()
            print(pik_file.name, ' was deleted.')
        print('Done.')

    def seed(self, seed):
        """Seed every on-the-fly generator in the train, val and test splits."""
        for split in (self.train_dataset, self.test_dataset, self.val_dataset):
            for dataset in split.datasets:
                # Shelve-backed datasets replay pre-generated samples (their seed() does nothing),
                # so only TrajDataset instances are seeded.
                if isinstance(dataset, TrajDataset):
                    dataset.seed(seed)

    def dump_n(self, file_path, traj_dataset: 'TrajDataset', n=100000):
        """Write ``traj_dataset[0..n-1]`` to the shelve at ``file_path`` via ``write_to_shelve``."""
        assert str(file_path).endswith('.pik')
        # No lock here: the previous lock was created per call and never shared, so it excluded nothing.
        for i in tqdm(range(n), total=n, desc=f'Generating {n} Samples'):
            write_to_shelve(file_path, traj_dataset[i])

        print(f'{n} samples successfully dumped to "{file_path}"!')

    def get_datasets(self):
        """Return ``(train, val, test)`` datasets."""
        return self.train_dataset, self.val_dataset, self.test_dataset
|