fig clf inserted and not resize on kld
This commit is contained in:
29
datasets/mnist.py
Normal file
29
datasets/mnist.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from torchvision.datasets import MNIST
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyMNIST(MNIST):
|
||||
|
||||
@property
|
||||
def map_shapes_max(self):
|
||||
return np.asarray(self.test_dataset[0][0]).shape
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(MyMNIST, self).__init__('res', train=False, download=True)
|
||||
pass
|
||||
|
||||
def __getitem__(self, item):
|
||||
image = super(MyMNIST, self).__getitem__(item)
|
||||
return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
return self.__class__('res', train=True, download=True)
|
||||
|
||||
@property
|
||||
def test_dataset(self):
|
||||
return self.__class__('res', train=False, download=True)
|
||||
|
||||
@property
|
||||
def val_dataset(self):
|
||||
return self.__class__('res', train=False, download=True)
|
||||
@@ -1,6 +1,9 @@
|
||||
import shelve
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Union, List
|
||||
from typing import Union
|
||||
|
||||
from torchvision.transforms import Normalize
|
||||
|
||||
import multiprocessing as mp
|
||||
|
||||
@@ -24,16 +27,17 @@ class TrajDataShelve(Dataset):
|
||||
return self[0][0].shape
|
||||
|
||||
def __init__(self, file_path, **kwargs):
|
||||
assert Path(file_path).exists()
|
||||
super(TrajDataShelve, self).__init__()
|
||||
self._mutex = mp.Lock()
|
||||
self.file_path = str(file_path)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
self._mutex.acquire()
|
||||
with shelve.open(self.file_path) as d:
|
||||
length = len(d)
|
||||
self._mutex.release()
|
||||
d.close()
|
||||
self._mutex.release()
|
||||
return length
|
||||
|
||||
def seed(self):
|
||||
@@ -43,12 +47,20 @@ class TrajDataShelve(Dataset):
|
||||
self._mutex.acquire()
|
||||
with shelve.open(self.file_path) as d:
|
||||
sample = d[str(item)]
|
||||
self._mutex.release()
|
||||
d.close()
|
||||
self._mutex.release()
|
||||
return sample
|
||||
|
||||
|
||||
class TrajDataset(Dataset):
|
||||
|
||||
@property
|
||||
def _last_label_init(self):
|
||||
d = defaultdict(lambda: -1)
|
||||
d['generator_hom_all_in_map'] = V.ALTERNATIVE
|
||||
d['generator_alt_all_in_map'] = V.HOMOTOPIC
|
||||
return d[self.mode]
|
||||
|
||||
@property
|
||||
def map_shape(self):
|
||||
return self.map.as_array.shape
|
||||
@@ -57,17 +69,18 @@ class TrajDataset(Dataset):
|
||||
length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
|
||||
**kwargs):
|
||||
super(TrajDataset, self).__init__()
|
||||
assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map'
|
||||
'classifier_all_in_map']
|
||||
self.normalized = normalized
|
||||
assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map', 'generator_alt_all_in_map',
|
||||
'ae_no_label_in_map',
|
||||
'generator_alt_no_label_in_map', 'classifier_all_in_map', 'vae_no_label_in_map']
|
||||
self.normalize = Normalize(0.5, 0.5) if normalized else lambda x: x
|
||||
self.preserve_equal_samples = preserve_equal_samples
|
||||
self.mode = mode
|
||||
self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
|
||||
self.maps_root = maps_root
|
||||
self._len = length
|
||||
self.last_label = V.ALTERNATIVE if 'hom' in self.mode else choice([-1, V.ALTERNATIVE, V.HOMOTOPIC])
|
||||
self.last_label = self._last_label_init
|
||||
|
||||
self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
|
||||
self.map = Map.from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
@@ -82,6 +95,7 @@ class TrajDataset(Dataset):
|
||||
map_array = torch.as_tensor(self.map.as_array).float()
|
||||
return (map_array, trajectory_space), label
|
||||
|
||||
# Produce an alternative.
|
||||
while True:
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
alternative = self.map.generate_alternative(trajectory)
|
||||
@@ -91,18 +105,19 @@ class TrajDataset(Dataset):
|
||||
else:
|
||||
break
|
||||
|
||||
self.last_label = label if self.mode != ['generator_hom_all_in_map'] else V.ALTERNATIVE
|
||||
if self.mode.lower() in ['classifier_all_in_map', 'generator_all_in_map']:
|
||||
self.last_label = label if self._last_label_init == V.ANY else self._last_label_init[self.mode]
|
||||
if 'in_map' in self.mode.lower():
|
||||
map_array = self.map.as_array
|
||||
trajectory = trajectory.draw_in_array(self.map_shape)
|
||||
alternative = alternative.draw_in_array(self.map_shape)
|
||||
label_as_array = np.full_like(map_array, label)
|
||||
if self.normalized:
|
||||
map_array = map_array / V.WHITE
|
||||
trajectory = trajectory / V.WHITE
|
||||
alternative = alternative / V.WHITE
|
||||
|
||||
if self.mode == 'generator_all_in_map':
|
||||
return np.concatenate((map_array, trajectory, label_as_array)), alternative
|
||||
elif self.mode in ['vae_no_label_in_map', 'ae_no_label_in_map']:
|
||||
return np.sum((map_array, trajectory, alternative), axis=0), 0
|
||||
elif self.mode in ['generator_alt_no_label_in_map', 'generator_hom_no_label_in_map']:
|
||||
return np.concatenate((map_array, trajectory)), alternative
|
||||
elif self.mode == 'classifier_all_in_map':
|
||||
return np.concatenate((map_array, trajectory, alternative)), label
|
||||
|
||||
@@ -119,13 +134,13 @@ class TrajDataset(Dataset):
|
||||
class TrajData(object):
|
||||
@property
|
||||
def map_shapes(self):
|
||||
return [dataset.map_shape for dataset in self._train_dataset.datasets]
|
||||
return [dataset.map_shape for dataset in self.train_dataset.datasets]
|
||||
|
||||
@property
|
||||
def map_shapes_max(self):
|
||||
shapes = self.map_shapes
|
||||
shape_list = list(map(max, zip(*shapes)))
|
||||
if '_all_in_map' in self.mode:
|
||||
if '_all_in_map' in self.mode and not self.preprocessed:
|
||||
shape_list[0] += 2
|
||||
return shape_list
|
||||
|
||||
@@ -139,14 +154,13 @@ class TrajData(object):
|
||||
self.mode = mode
|
||||
self.maps_root = Path(map_root)
|
||||
self.length = length
|
||||
self._test_dataset = self._load_datasets('train')
|
||||
self._val_dataset = self._load_datasets('val')
|
||||
self._train_dataset = self._load_datasets('test')
|
||||
self.test_dataset = self._load_datasets('test')
|
||||
self.val_dataset = self._load_datasets('val')
|
||||
self.train_dataset = self._load_datasets('train')
|
||||
|
||||
def _load_datasets(self, dataset_type=''):
|
||||
|
||||
map_files = list(self.maps_root.glob('*.bmp'))
|
||||
equal_split = int(self.length // len(map_files)) or 1
|
||||
|
||||
# find max image size among available maps:
|
||||
max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
|
||||
@@ -156,10 +170,11 @@ class TrajData(object):
|
||||
preprocessed_map_names = [p.name for p in preprocessed_map_files]
|
||||
datasets = []
|
||||
for map_file in map_files:
|
||||
new_pik_name = f'{dataset_type}_{str(map_file.name)[:-3]}.pik'
|
||||
equal_split = int(self.length // len(map_files)) or 5
|
||||
new_pik_name = f'{self.mode}_{map_file.name[:-4]}_{dataset_type}.pik'
|
||||
if dataset_type != 'train':
|
||||
equal_split *= 0.01
|
||||
if not [f'{new_pik_name[:-3]}.bmp' in preprocessed_map_names]:
|
||||
equal_split = max(int(equal_split * 0.01), 10)
|
||||
if not new_pik_name in preprocessed_map_names:
|
||||
traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
||||
mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
|
||||
preserve_equal_samples=True)
|
||||
@@ -168,6 +183,9 @@ class TrajData(object):
|
||||
dataset = TrajDataShelve(map_file.parent / new_pik_name)
|
||||
datasets.append(dataset)
|
||||
return ConcatDataset(datasets)
|
||||
|
||||
# Set the equal split so that all maps are visited with the same frequency
|
||||
equal_split = int(self.length // len(map_files)) or 5
|
||||
return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
|
||||
mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
|
||||
preserve_equal_samples=True)
|
||||
@@ -185,29 +203,14 @@ class TrajData(object):
|
||||
|
||||
def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
|
||||
assert str(file_path).endswith('.pik')
|
||||
processes = mp.cpu_count() - 1
|
||||
mutex = mp.Lock()
|
||||
with mp.Pool(processes) as pool:
|
||||
async_results = [pool.apply_async(traj_dataset.__getitem__, kwds=dict(item=i)) for i in range(n)]
|
||||
for i in tqdm(range(n), total=n, desc=f'Generating {n} Samples'):
|
||||
sample = traj_dataset[i]
|
||||
mutex.acquire()
|
||||
write_to_shelve(file_path, sample)
|
||||
mutex.release()
|
||||
|
||||
for result_obj in tqdm(async_results, total=n, desc=f'Generating {n} Samples'):
|
||||
sample = result_obj.get()
|
||||
mutex.acquire()
|
||||
write_to_shelve(file_path, sample)
|
||||
mutex.release()
|
||||
print(f'{n} samples sucessfully dumped to "{file_path}"!')
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
return self._train_dataset
|
||||
|
||||
@property
|
||||
def val_dataset(self):
|
||||
return self._val_dataset
|
||||
|
||||
@property
|
||||
def test_dataset(self):
|
||||
return self._test_dataset
|
||||
print(f'{n} samples successfully dumped to "{file_path}"!')
|
||||
|
||||
def get_datasets(self):
|
||||
return self._train_dataset, self._val_dataset, self._test_dataset
|
||||
|
||||
Reference in New Issue
Block a user