Offline Datasets, ResNet optionality
@@ -6,7 +6,7 @@ import torch
 from torch.utils.data import Dataset, ConcatDataset

 from datasets.utils import DatasetMapping
-from lib.modules.model_parts import Generator
+from lib.preprocessing.generator import Generator
 from lib.objects.map import Map

@@ -2,15 +2,50 @@ import shelve
 from pathlib import Path
 from typing import Union, List

+import multiprocessing as mp
+
 import torch
 from random import choice
 from torch.utils.data import ConcatDataset, Dataset
 import numpy as np
 from tqdm import tqdm

 from lib.objects.map import Map
 import lib.variables as V
+from PIL import Image

+from lib.utils.tools import write_to_shelve
+
+
+class TrajDataShelve(Dataset):
+
+    @property
+    def map_shape(self):
+        return self[0][0].shape
+
+    def __init__(self, file_path, **kwargs):
+        super(TrajDataShelve, self).__init__()
+        self._mutex = mp.Lock()
+        self.file_path = str(file_path)
+
+    def __len__(self):
+        self._mutex.acquire()
+        with shelve.open(self.file_path) as d:
+            length = len(d)
+        self._mutex.release()
+        return length
+
+    def seed(self, seed):
+        pass
+
+    def __getitem__(self, item):
+        self._mutex.acquire()
+        with shelve.open(self.file_path) as d:
+            sample = d[str(item)]
+        self._mutex.release()
+        return sample
+
+
 class TrajDataset(Dataset):
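TrajDataShelve reads samples back with d[str(item)], so the write_to_shelve helper imported from lib.utils.tools (not shown in this commit) presumably appends each sample under the next stringified integer index. A minimal sketch under that assumption:

    import shelve

    def write_to_shelve(file_path, value):
        # Assumed key scheme: consecutive integer keys stored as strings,
        # matching TrajDataShelve's d[str(item)] lookup.
        with shelve.open(str(file_path)) as d:
            d[str(len(d))] = value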
@@ -22,14 +57,15 @@ class TrajDataset(Dataset):
                  length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
                  **kwargs):
         super(TrajDataset, self).__init__()
-        assert mode.lower() in ['vectors', 'all_in_map', 'separated_arrays', 'just_route']
+        assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map',
+                                'classifier_all_in_map']
         self.normalized = normalized
         self.preserve_equal_samples = preserve_equal_samples
         self.mode = mode
         self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
         self.maps_root = maps_root
         self._len = length
-        self.last_label = -1
+        self.last_label = V.ALTERNATIVE if 'hom' in self.mode else choice([-1, V.ALTERNATIVE, V.HOMOTOPIC])

         self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
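Note the comma at the end of the first line of the new assert list: adjacent string literals concatenate implicitly in Python, so omitting it would silently merge the last two mode names into a single, never-matching entry:

    modes = ['generator_all_in_map', 'generator_hom_all_in_map'
             'classifier_all_in_map']  # missing comma
    assert len(modes) == 2
    assert modes[1] == 'generator_hom_all_in_mapclassifier_all_in_map'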
@@ -39,6 +75,7 @@ class TrajDataset(Dataset):
     def __getitem__(self, item):

         if self.mode.lower() == 'just_route':
             raise NotImplementedError
         trajectory = self.map.get_random_trajectory()
+        trajectory_space = trajectory.draw_in_array(self.map.shape)
         label = choice([0, 1])
@@ -54,37 +91,41 @@ class TrajDataset(Dataset):
             else:
                 break

-        self.last_label = label
-        if self.mode.lower() in ['all_in_map', 'separated_arrays']:
+        self.last_label = label if self.mode != 'generator_hom_all_in_map' else V.ALTERNATIVE
+        if self.mode.lower() in ['classifier_all_in_map', 'generator_all_in_map']:
             map_array = self.map.as_array
             trajectory = trajectory.draw_in_array(self.map_shape)
             alternative = alternative.draw_in_array(self.map_shape)
-            if self.mode == 'separated_arrays':
-                if self.normalized:
-                    map_array = map_array / V.WHITE
-                    trajectory = trajectory / V.WHITE
-                    alternative = alternative / V.WHITE
-                return (map_array, trajectory, label), alternative
-            else:
-                label_as_array = np.full_like(map_array, label)
+            label_as_array = np.full_like(map_array, label)
             if self.normalized:
                 map_array = map_array / V.WHITE
                 trajectory = trajectory / V.WHITE
                 alternative = alternative / V.WHITE
+            if self.mode == 'generator_all_in_map':
+                return np.concatenate((map_array, trajectory, label_as_array)), alternative
+            elif self.mode == 'classifier_all_in_map':
+                return np.concatenate((map_array, trajectory, alternative)), label

-        elif self.mode == 'vectors':
+        elif self.mode == '_vectors':
+            raise NotImplementedError
             return trajectory.vertices, alternative.vertices, label, self.mapname

         else:
-            raise ValueError
+            raise ValueError(f'Mode was: {self.mode}')

     def seed(self, seed):
         self.map.seed(seed)


 class TrajData(object):
     @property
     def map_shapes(self):
-        return [dataset.map_shape for dataset in self._dataset.datasets]
+        return [dataset.map_shape for dataset in self._train_dataset.datasets]

     @property
     def map_shapes_max(self):
         shapes = self.map_shapes
         shape_list = list(map(max, zip(*shapes)))
-        if self.mode in ['separated_arrays', 'all_in_map']:
+        if '_all_in_map' in self.mode:
             shape_list[0] += 2
         return shape_list
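Both *_all_in_map modes return a channel-stacked array: the map plane plus two extra planes, either trajectory and a constant label plane (generator) or trajectory and the alternative route (classifier). That is exactly why map_shapes_max adds 2 to the channel dimension above. A small shape sketch, assuming one-channel 64x64 maps:

    import numpy as np

    map_array   = np.zeros((1, 64, 64))       # map plane
    trajectory  = np.zeros((1, 64, 64))       # drawn route
    alternative = np.zeros((1, 64, 64))       # drawn alternative route
    label_plane = np.full_like(map_array, 1)  # constant label plane

    generator_x  = np.concatenate((map_array, trajectory, label_plane))
    classifier_x = np.concatenate((map_array, trajectory, alternative))
    assert generator_x.shape == classifier_x.shape == (3, 64, 64)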
@@ -92,36 +133,81 @@ class TrajData(object):
     def name(self):
         return self.__class__.__name__

-    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, **_):
+    def __init__(self, map_root, length=100000, mode='separated_arrays', normalized=True, preprocessed=False, **_):
+        self.preprocessed = preprocessed
         self.normalized = normalized
         self.mode = mode
         self.maps_root = Path(map_root)
         self.length = length
-        self._dataset = self._load_datasets()
+        self._train_dataset = self._load_datasets('train')
+        self._val_dataset = self._load_datasets('val')
+        self._test_dataset = self._load_datasets('test')

-    def _load_datasets(self):
+    def _load_datasets(self, dataset_type=''):
         map_files = list(self.maps_root.glob('*.bmp'))
         equal_split = int(self.length // len(map_files)) or 1
+        if dataset_type != 'train':
+            # val/test splits get 1% of the per-map sample budget
+            equal_split = max(int(equal_split * 0.01), 1)

         # find max image size among available maps:
         max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))

+        if self.preprocessed:
+            preprocessed_map_files = list(self.maps_root.glob('*.pik'))
+            preprocessed_map_names = [p.name for p in preprocessed_map_files]
+            datasets = []
+            for map_file in map_files:
+                new_pik_name = f'{dataset_type}_{str(map_file.name)[:-4]}.pik'
+                if new_pik_name not in preprocessed_map_names:
+                    traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
+                                               mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
+                                               preserve_equal_samples=True)
+                    self.dump_n(map_file.parent / new_pik_name, traj_dataset, n=equal_split)
+
+                dataset = TrajDataShelve(map_file.parent / new_pik_name)
+                datasets.append(dataset)
+            return ConcatDataset(datasets)
         return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                           mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                           preserve_equal_samples=True)
                               for map_file in map_files])
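With preprocessed=True, the first run per split generates equal_split samples per map, dumps them into a <split>_<mapname>.pik shelf, and every later run reads straight from disk through TrajDataShelve. A hypothetical construction (the maps directory and mode are assumptions):

    data = TrajData('res/maps', length=100000, mode='generator_all_in_map',
                    normalized=True, preprocessed=True)
    train, val, test = data.get_datasets()
    print(data.map_shapes_max)  # channel dim includes the +2 extra planes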
+    def kill_em_all(self):
+        for pik_file in self.maps_root.glob('*.pik'):
+            pik_file.unlink()
+            print(pik_file.name, 'was deleted.')
+        print('Done.')
+
+    def seed(self, seed):
+        for concat_dataset in [self._train_dataset, self._val_dataset, self._test_dataset]:
+            for dataset in concat_dataset.datasets:
+                dataset.seed(seed)
+
+    def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
+        assert str(file_path).endswith('.pik')
+        processes = mp.cpu_count() - 1
+        mutex = mp.Lock()
+        with mp.Pool(processes) as pool:
+            async_results = [pool.apply_async(traj_dataset.__getitem__, kwds=dict(item=i)) for i in range(n)]
+
+            for result_obj in tqdm(async_results, total=n, desc=f'Generating {n} Samples'):
+                sample = result_obj.get()
+                mutex.acquire()
+                write_to_shelve(file_path, sample)
+                mutex.release()
+        print(f'{n} samples successfully dumped to "{file_path}"!')
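dump_n fans TrajDataset.__getitem__ out over a process pool, which works because every sample is generated independently and the dataset object is pickled into each worker; results are then collected and written sequentially in the parent process. If pickling the dataset ever becomes a problem, a single-process fallback with the same shelf layout could look like this (hypothetical helper, not part of the commit):

    from tqdm import tqdm

    def dump_n_single(file_path, traj_dataset, n=1000):
        # Same '.pik' shelf layout as dump_n, but no Pool involved.
        for i in tqdm(range(n), desc=f'Generating {n} Samples'):
            write_to_shelve(file_path, traj_dataset[i])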
     @property
     def train_dataset(self):
-        return self._dataset
+        return self._train_dataset

     @property
     def val_dataset(self):
-        return self._dataset
+        return self._val_dataset

     @property
     def test_dataset(self):
-        return self._dataset
+        return self._test_dataset

     def get_datasets(self):
-        return self._dataset, self._dataset, self._dataset
+        return self._train_dataset, self._val_dataset, self._test_dataset
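With the properties above now returning three genuinely distinct splits instead of aliasing one shared dataset, each split can feed its own loader. A hypothetical wiring:

    from torch.utils.data import DataLoader

    data = TrajData('res/maps', mode='classifier_all_in_map', preprocessed=True)
    train_loader = DataLoader(data.train_dataset, batch_size=32, shuffle=True)
    val_loader   = DataLoader(data.val_dataset, batch_size=32)
    test_loader  = DataLoader(data.test_dataset, batch_size=32)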