fig.clf() inserted and no resize on KLD

Steffen Illium
2020-03-13 21:52:33 +01:00
parent bb47e07566
commit 2305c8e54a
33 changed files with 403 additions and 279 deletions

datasets/mnist.py (new file, 29 lines)

@@ -0,0 +1,29 @@
+from torchvision.datasets import MNIST
+import numpy as np
+
+
+class MyMNIST(MNIST):
+
+    @property
+    def map_shapes_max(self):
+        return np.asarray(self.test_dataset[0][0]).shape
+
+    def __init__(self, *args, train=False, **kwargs):
+        # Forward the split flag so the dataset properties below take effect.
+        super(MyMNIST, self).__init__('res', train=train, download=True)
+
+    def __getitem__(self, item):
+        image = super(MyMNIST, self).__getitem__(item)
+        return np.expand_dims(np.asarray(image[0]), axis=0).astype(np.float32), image[1]
+
+    @property
+    def train_dataset(self):
+        return self.__class__('res', train=True, download=True)
+
+    @property
+    def test_dataset(self):
+        return self.__class__('res', train=False, download=True)
+
+    @property
+    def val_dataset(self):
+        return self.__class__('res', train=False, download=True)
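Quick sanity check for the wrapper above: it yields channel-first float32 numpy arrays, so it drops straight into a DataLoader. A minimal usage sketch (the import path follows the new file's location; the rest is plain PyTorch/torchvision):

from torch.utils.data import DataLoader

from datasets.mnist import MyMNIST

data = MyMNIST()                                  # test split by default
print(data.map_shapes_max)                        # (1, 28, 28): channel-first sample shape
loader = DataLoader(data.train_dataset, batch_size=32, shuffle=True)
images, labels = next(iter(loader))               # images: float32 tensor, shape (32, 1, 28, 28)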


@@ -1,6 +1,9 @@
 import shelve
+from collections import defaultdict
 from pathlib import Path
-from typing import Union, List
+from typing import Union
+from torchvision.transforms import Normalize
 import multiprocessing as mp
@@ -24,16 +27,17 @@ class TrajDataShelve(Dataset):
         return self[0][0].shape

     def __init__(self, file_path, **kwargs):
         assert Path(file_path).exists()
         super(TrajDataShelve, self).__init__()
         self._mutex = mp.Lock()
         self.file_path = str(file_path)

     def __len__(self):
         self._mutex.acquire()
         with shelve.open(self.file_path) as d:
             length = len(d)
-            self._mutex.release()
+        d.close()
+        self._mutex.release()
         return length

     def seed(self):
@@ -43,12 +47,20 @@ class TrajDataShelve(Dataset):
         self._mutex.acquire()
         with shelve.open(self.file_path) as d:
             sample = d[str(item)]
-            self._mutex.release()
+        d.close()
+        self._mutex.release()
         return sample


 class TrajDataset(Dataset):

+    @property
+    def _last_label_init(self):
+        d = defaultdict(lambda: -1)
+        d['generator_hom_all_in_map'] = V.ALTERNATIVE
+        d['generator_alt_all_in_map'] = V.HOMOTOPIC
+        return d[self.mode]
+
     @property
     def map_shape(self):
         return self.map.as_array.shape
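The shelve access pattern fixed above, in isolation: every read goes through one process-wide lock, so concurrent workers cannot corrupt the underlying dbm file. A standalone sketch (the file path argument is hypothetical; the with-statement is equivalent to the manual acquire()/release() pair in the diff):

import multiprocessing as mp
import shelve

_mutex = mp.Lock()

def read_sample(file_path: str, item: int):
    with _mutex:                        # same effect as acquire()/release()
        with shelve.open(file_path) as d:
            return d[str(item)]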
@@ -57,17 +69,18 @@ class TrajDataset(Dataset):
                  length=100000, mode='separated_arrays', embedding_size=None, preserve_equal_samples=False,
                  **kwargs):
         super(TrajDataset, self).__init__()
-        assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map'
-                                'classifier_all_in_map']
-        self.normalized = normalized
+        assert mode.lower() in ['generator_all_in_map', 'generator_hom_all_in_map', 'generator_alt_all_in_map',
+                                'ae_no_label_in_map',
+                                'generator_alt_no_label_in_map', 'classifier_all_in_map', 'vae_no_label_in_map']
+        self.normalize = Normalize(0.5, 0.5) if normalized else lambda x: x
         self.preserve_equal_samples = preserve_equal_samples
         self.mode = mode
         self.mapname = mapname if mapname.endswith('.bmp') else f'{mapname}.bmp'
         self.maps_root = maps_root
         self._len = length
-        self.last_label = V.ALTERNATIVE if 'hom' in self.mode else choice([-1, V.ALTERNATIVE, V.HOMOTOPIC])
-        self.map = Map(self.mapname).from_image(self.maps_root / self.mapname, embedding_size=embedding_size)
+        self.last_label = self._last_label_init
+        self.map = Map.from_image(self.maps_root / self.mapname, embedding_size=embedding_size)

     def __len__(self):
         return self._len
@@ -82,6 +95,7 @@ class TrajDataset(Dataset):
             map_array = torch.as_tensor(self.map.as_array).float()
             return (map_array, trajectory_space), label

+        # Produce an alternative.
         while True:
             trajectory = self.map.get_random_trajectory()
             alternative = self.map.generate_alternative(trajectory)
@@ -91,18 +105,19 @@ class TrajDataset(Dataset):
             else:
                 break

-        self.last_label = label if self.mode != ['generator_hom_all_in_map'] else V.ALTERNATIVE
-        if self.mode.lower() in ['classifier_all_in_map', 'generator_all_in_map']:
+        self.last_label = label if self._last_label_init == V.ANY else self._last_label_init
+        if 'in_map' in self.mode.lower():
             map_array = self.map.as_array
             trajectory = trajectory.draw_in_array(self.map_shape)
             alternative = alternative.draw_in_array(self.map_shape)
             label_as_array = np.full_like(map_array, label)
-            if self.normalized:
             map_array = map_array / V.WHITE
             trajectory = trajectory / V.WHITE
             alternative = alternative / V.WHITE
             if self.mode == 'generator_all_in_map':
                 return np.concatenate((map_array, trajectory, label_as_array)), alternative
+            elif self.mode in ['vae_no_label_in_map', 'ae_no_label_in_map']:
+                return np.sum((map_array, trajectory, alternative), axis=0), 0
             elif self.mode in ['generator_alt_no_label_in_map', 'generator_hom_no_label_in_map']:
                 return np.concatenate((map_array, trajectory)), alternative
             elif self.mode == 'classifier_all_in_map':
                 return np.concatenate((map_array, trajectory, alternative)), label
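The preserve_equal_samples loop above is rejection sampling: a freshly drawn pair is discarded whenever its label repeats the previous one, so consecutive samples alternate between the homotopic and alternative classes. A stripped-down sketch of that loop (draw_fn is a hypothetical stand-in for the trajectory/alternative generation):

from random import choice

def balanced_stream(draw_fn, labels=(0, 1)):
    last_label = choice(labels)
    while True:
        sample, label = draw_fn()
        if label == last_label:
            continue                    # reject: same class as the previous sample
        last_label = label
        yield sample, label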
@@ -119,13 +134,13 @@ class TrajDataset(Dataset):
 class TrajData(object):

     @property
     def map_shapes(self):
-        return [dataset.map_shape for dataset in self._train_dataset.datasets]
+        return [dataset.map_shape for dataset in self.train_dataset.datasets]

     @property
     def map_shapes_max(self):
         shapes = self.map_shapes
         shape_list = list(map(max, zip(*shapes)))
-        if '_all_in_map' in self.mode:
+        if '_all_in_map' in self.mode and not self.preprocessed:
             shape_list[0] += 2
         return shape_list
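map_shapes_max reduces the per-map shapes element-wise and, for the *_all_in_map modes, reserves two extra channels for the trajectory and label planes stacked onto the map. In miniature (the shapes are hypothetical):

shapes = [(1, 64, 64), (1, 80, 48)]       # per-map (C, H, W)
max_shape = list(map(max, zip(*shapes)))  # -> [1, 80, 64]
max_shape[0] += 2                         # room for trajectory + label channels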
@@ -139,14 +154,13 @@ class TrajData(object):
         self.mode = mode
         self.maps_root = Path(map_root)
         self.length = length
-        self._test_dataset = self._load_datasets('train')
-        self._val_dataset = self._load_datasets('val')
-        self._train_dataset = self._load_datasets('test')
+        self.test_dataset = self._load_datasets('test')
+        self.val_dataset = self._load_datasets('val')
+        self.train_dataset = self._load_datasets('train')

     def _load_datasets(self, dataset_type=''):
         map_files = list(self.maps_root.glob('*.bmp'))
-        equal_split = int(self.length // len(map_files)) or 1
         # find max image size among available maps:
         max_map_size = (1, ) + tuple(reversed(tuple(map(max, *[Image.open(map_file).size for map_file in map_files]))))
@@ -156,10 +170,11 @@ class TrajData(object):
         preprocessed_map_names = [p.name for p in preprocessed_map_files]
         datasets = []
         for map_file in map_files:
-            new_pik_name = f'{dataset_type}_{str(map_file.name)[:-3]}.pik'
+            equal_split = int(self.length // len(map_files)) or 5
+            new_pik_name = f'{self.mode}_{map_file.name[:-4]}_{dataset_type}.pik'
             if dataset_type != 'train':
-                equal_split *= 0.01
-            if not [f'{new_pik_name[:-3]}.bmp' in preprocessed_map_names]:
+                equal_split = max(int(equal_split * 0.01), 10)
+            if not new_pik_name in preprocessed_map_names:
                 traj_dataset = TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                            mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                            preserve_equal_samples=True)
@@ -168,6 +183,9 @@ class TrajData(object):
+            dataset = TrajDataShelve(map_file.parent / new_pik_name)
+            datasets.append(dataset)
+        return ConcatDataset(datasets)

         # Set the equal split so that all maps are visited with the same frequency
         equal_split = int(self.length // len(map_files)) or 5
         return ConcatDataset([TrajDataset(maps_root=self.maps_root, mapname=map_file.name, length=equal_split,
                                           mode=self.mode, embedding_size=max_map_size, normalized=self.normalized,
                                           preserve_equal_samples=True)
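_load_datasets now follows a cache-or-build pattern: samples for each map are generated once, dumped into a per-mode .pik shelve, and served from disk afterwards. The skeleton of that decision (build_dataset and dump are hypothetical stand-ins for the TrajDataset construction and dump_n):

from pathlib import Path

def cached_dataset(pik_path: Path, build_dataset, dump):
    if not pik_path.exists():           # expensive sample generation runs once
        dataset = build_dataset()
        dump(pik_path, dataset)         # persist samples into the shelve file
    return TrajDataShelve(pik_path)     # cheap reads from disk from then on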
@@ -185,29 +203,14 @@ class TrajData(object):
     def dump_n(self, file_path, traj_dataset: TrajDataset, n=100000):
         assert str(file_path).endswith('.pik')
         processes = mp.cpu_count() - 1
         mutex = mp.Lock()
         with mp.Pool(processes) as pool:
             async_results = [pool.apply_async(traj_dataset.__getitem__, kwds=dict(item=i)) for i in range(n)]

-            for i in tqdm(range(n), total=n, desc=f'Generating {n} Samples'):
-                sample = traj_dataset[i]
-                mutex.acquire()
-                write_to_shelve(file_path, sample)
-                mutex.release()
+            for result_obj in tqdm(async_results, total=n, desc=f'Generating {n} Samples'):
+                sample = result_obj.get()
+                mutex.acquire()
+                write_to_shelve(file_path, sample)
+                mutex.release()

-        print(f'{n} samples sucessfully dumped to "{file_path}"!')
-
-    @property
-    def train_dataset(self):
-        return self._train_dataset
-
-    @property
-    def val_dataset(self):
-        return self._val_dataset
-
-    @property
-    def test_dataset(self):
-        return self._test_dataset
+        print(f'{n} samples successfully dumped to "{file_path}"!')
     def get_datasets(self):
         return self.train_dataset, self.val_dataset, self.test_dataset
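The reworked dump_n in miniature: fan sample generation out to a worker pool with apply_async, then drain the results in submission order while a lock serializes the shelve writes. A self-contained sketch (work() and demo.pik are hypothetical; note that apply_async pickles its target, so the callable must be defined at module level):

import multiprocessing as mp
import shelve

def work(i):
    return i * i                        # stand-in for traj_dataset[i]

if __name__ == '__main__':
    mutex = mp.Lock()
    with mp.Pool(mp.cpu_count() - 1) as pool:
        results = [pool.apply_async(work, kwds=dict(i=i)) for i in range(100)]
        with shelve.open('demo.pik') as db:
            for n, res in enumerate(results):
                with mutex:             # one writer at a time, as in dump_n
                    db[str(n)] = res.get()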