Offline Datasets, ResNet optionality

Si11ium
2020-03-12 18:32:23 +01:00
parent 2f99341cc3
commit bb47e07566
11 changed files with 638 additions and 140 deletions

View File

@@ -6,7 +6,7 @@ import torch
from torch.utils.data import Dataset, ConcatDataset
from datasets.utils import DatasetMapping
-from lib.modules.model_parts import Generator
+from lib.preprocessing.generator import Generator
from lib.objects.map import Map

View File

@@ -2,15 +2,50 @@ import shelve
from pathlib import Path
from typing import Union, List
+import multiprocessing as mp
import torch
from random import choice
from torch.utils.data import ConcatDataset, Dataset
import numpy as np
+from tqdm import tqdm
from lib.objects.map import Map
import lib.variables as V
+from PIL import Image
+from lib.utils.tools import write_to_shelve
+class TrajDataShelve(Dataset):
+
+    @property
+    def map_shape(self):
+        return self[0][0].shape
+
+    def __init__(self, file_path, **kwargs):
+        super(TrajDataShelve, self).__init__()
+        self._mutex = mp.Lock()
+        self.file_path = str(file_path)
+
+    def __len__(self):
+        self._mutex.acquire()
+        with shelve.open(self.file_path) as d:
+            length = len(d)
+        self._mutex.release()
+        return length
+
+    def seed(self, seed):
+        # no-op: samples are pre-rendered; matches the TrajDataset.seed interface
+        pass
+
+    def __getitem__(self, item):
+        self._mutex.acquire()
+        with shelve.open(self.file_path) as d:
+            sample = d[str(item)]
+        self._mutex.release()
+        return sample
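
TrajDataShelve reads each sample back under a stringified integer key (d[str(item)]), so whatever produces the shelve has to write samples in exactly that layout. The real write_to_shelve helper lives in lib.utils.tools and is not part of this diff; a minimal sketch under that assumption:

import shelve

def write_to_shelve(file_path, value):
    # hypothetical stand-in for lib.utils.tools.write_to_shelve:
    # append under the next free stringified integer key, which is
    # the layout TrajDataShelve.__getitem__ expects
    with shelve.open(str(file_path)) as d:
        d[str(len(d))] = value
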
class TrajDataset(Dataset):
@@ -22,14 +57,15 @@ class TrajDataset(Dataset):
                 length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
                 **kwargs):
        super(TrajDataset, self).__init__()
-        assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
+        assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map',
+                                'classifier_all_in_map']
        self.normalized = normalized
        self.preserve_equal_samples = preserve_equal_samples
        self.mode = mode
        self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
        self.maps_root = maps_root
        self._len = length
-        self.last_label = -1
+        self.last_label = V.ALTERNATIVE if 'hom' in self.mode else choice([-1, V.ALTERNATIVE, V.HOMOTOPIC])
        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
@@ -39,6 +75,7 @@ class TrajDataset(Dataset):
    def __getitem__(self, item):
        if self.mode.lower() == 'just_route':
            raise NotImplementedError
        trajectory = self.map.get_random_trajectory()
        trajectory_space = trajectory.draw_in_array(self.map.shape)
        label = choice([0, 1])
@@ -54,37 +91,41 @@ class TrajDataset(Dataset):
            else:
                break
-        self.last_label = label
-        if self.mode.lower() in ['all_in_map', 'separated_arrays']:
+        self.last_label = label if self.mode != 'generator_hom_all_in_map' else V.ALTERNATIVE
+        if self.mode.lower() in ['classifier_all_in_map', 'generator_all_in_map', 'generator_hom_all_in_map']:
            map_array = self.map.as_array
            trajectory = trajectory.draw_in_array(self.map_shape)
            alternative = alternative.draw_in_array(self.map_shape)
-            if self.mode == 'separated_arrays':
-                if self.normalized:
-                    map_array = map_array / V.WHITE
-                    trajectory = trajectory / V.WHITE
-                    alternative = alternative / V.WHITE
-                return (map_array, trajectory, label), alternative
-            else:
            label_as_array = np.full_like(map_array, label)
            if self.normalized:
                map_array = map_array / V.WHITE
                trajectory = trajectory / V.WHITE
                alternative = alternative / V.WHITE
+            if self.mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map']:
+                return np.concatenate((map_array, trajectory, label_as_array)), alternative
+            elif self.mode == 'classifier_all_in_map':
+                return np.concatenate((map_array, trajectory, alternative)), label
-        elif self.mode == 'vectors':
+        elif self.mode == '_vectors':
            raise NotImplementedError
            return trajectory.vertices, alternative.vertices, label, self.mapname
        else:
-            raise ValueError
+            raise ValueError(f'Mode was: {self.mode}')

    def seed(self, seed):
        self.map.seed(seed)
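
The surviving *_all_in_map modes differ only in what fills the third plane and what is returned as the target: the generator variants stack map, trajectory and a constant label plane and yield the alternative trajectory, while the classifier variant stacks map, trajectory and alternative and yields the label. A usage sketch; the maps folder and map name are placeholders:

from pathlib import Path

ds = TrajDataset(maps_root=Path('res/maps'), mapname='tate',
                 mode='classifier_all_in_map', length=1000, normalized=True)
features, label = ds[0]
# features is np.concatenate((map_array, trajectory, alternative)) along
# the channel axis; with single-channel (1, H, W) maps it is (3, H, W)
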
class TrajData(object):
    @property
    def map_shapes(self):
-        return [dataset.map_shape for dataset in self._dataset.datasets]
+        return [dataset.map_shape for dataset in self._train_dataset.datasets]

    @property
    def map_shapes_max(self):
        shapes = self.map_shapes
        shape_list = list(map(max, zip(*shapes)))
-        if self.mode in ['separated_arrays', 'all_in_map']:
+        if '_all_in_map' in self.mode:
            shape_list[0] += 2
        return shape_list
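
map_shapes_max takes the element-wise maximum over the per-map shapes, then widens the channel dimension by the two extra planes each *_all_in_map sample stacks on top of the map. A worked example, assuming single-channel maps:

shapes = [(1, 110, 90), (1, 100, 120)]      # per-map (C, H, W)
shape_list = list(map(max, zip(*shapes)))   # [1, 110, 120]
shape_list[0] += 2                          # [3, 110, 120] -> network input channels
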
@@ -92,36 +133,81 @@ class TrajData(object):
    def name(self):
        return self.__class__.__name__

-    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_):
+    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, preprocessed=False, **_):
+        self.preprocessed = preprocessed
        self.normalized = normalized
        self.mode = mode
        self.maps_root = Path(map_root)
        self.length = length
-        self._dataset = self._load_datasets()
+        self._train_dataset = self._load_datasets('train')
+        self._val_dataset = self._load_datasets('val')
+        self._test_dataset = self._load_datasets('test')
-    def _load_datasets(self):
+    def _load_datasets(self, dataset_type=''):
        map_files = list(self.maps_root.glob('*.bmp'))
        equal_split = int(self.length // len(map_files)) or 1
+        # validation and test splits only need a small fraction of the samples
+        if dataset_type != 'train':
+            equal_split = max(int(equal_split * 0.01), 1)
        # find max image size among available maps:
        max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
+        if self.preprocessed:
+            preprocessed_map_files = list(self.maps_root.glob('*.pik'))
+            preprocessed_map_names = [p.name for p in preprocessed_map_files]
+            datasets = []
+            for map_file in map_files:
+                new_pik_name = f'{dataset_type}_{map_file.stem}.pik'
+                if new_pik_name not in preprocessed_map_names:
+                    traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
+                                               mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
+                                               preserve_equal_samples=True)
+                    self.dump_n(map_file.parent / new_pik_name, traj_dataset, n=equal_split)
+                dataset = TrajDataShelve(map_file.parent / new_pik_name)
+                datasets.append(dataset)
+            return ConcatDataset(datasets)
        return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                          mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                          preserve_equal_samples=True)
                              for map_file in map_files])
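
With preprocessed=True each map gets one shelve per split, named f'{dataset_type}_{map_file.stem}.pik'. For hypothetical maps tate.bmp and oval.bmp the cache directory would end up like this (note that some dbm backends append their own suffix such as .db, in which case the glob('*.pik') check would miss the files):

# res/maps/
#     tate.bmp, oval.bmp                 <- source maps
#     train_tate.pik, train_oval.pik     <- full-length training shelves
#     val_tate.pik,   val_oval.pik       <- 1%-sized validation shelves
#     test_tate.pik,  test_oval.pik      <- 1%-sized test shelves
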
+    def kill_em_all(self):
+        for pik_file in self.maps_root.glob('*.pik'):
+            pik_file.unlink()
+            print(pik_file.name, 'was deleted.')
+        print('Done.')
+
+    def seed(self, seed):
+        for concat_dataset in [self._train_dataset, self._val_dataset, self._test_dataset]:
+            for dataset in concat_dataset.datasets:
+                dataset.seed(seed)
+
+    def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
+        assert str(file_path).endswith('.pik')
+        processes = mp.cpu_count() - 1
+        mutex = mp.Lock()
+        with mp.Pool(processes) as pool:
+            async_results = [pool.apply_async(traj_dataset.__getitem__, kwds=dict(item=i)) for i in range(n)]
+            # collecting results in submission order keeps shelve keys 0..n-1
+            # aligned with the item indices TrajDataShelve reads back
+            for result_obj in tqdm(async_results, total=n, desc=f'Generating {n} Samples'):
+                sample = result_obj.get()
+                mutex.acquire()
+                write_to_shelve(file_path, sample)
+                mutex.release()
+        print(f'{n} samples successfully dumped to "{file_path}"!')
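
Once dumped, a shelve file can be served like any map-style dataset. A consumption sketch; the path is a placeholder, and num_workers stays at 0 because the mp.Lock held by TrajDataShelve cannot be pickled into spawn-based DataLoader workers:

from torch.utils.data import DataLoader

loader = DataLoader(TrajDataShelve('res/maps/train_tate.pik'),
                    batch_size=32, shuffle=True, num_workers=0)
batch = next(iter(loader))
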
    @property
    def train_dataset(self):
-        return self._dataset
+        return self._train_dataset

    @property
    def val_dataset(self):
-        return self._dataset
+        return self._val_dataset

    @property
    def test_dataset(self):
-        return self._dataset
+        return self._test_dataset

    def get_datasets(self):
-        return self._dataset, self._dataset, self._dataset
+        return self._train_dataset, self._val_dataset, self._test_dataset
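
Taken together, the offline path can be exercised end to end; a sketch assuming a res/maps folder of .bmp files (folder name is a placeholder):

data = TrajData('res/maps', length=100000,
                mode='classifier_all_in_map', preprocessed=True)
train, val, test = data.get_datasets()  # shelve-backed ConcatDatasets
data.seed(1337)      # reseeds TrajDataset maps; a no-op for shelves
data.kill_em_all()   # deletes every cached *.pik for a fresh rebuild
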