2020-02-17 16:41:14 +01:00

128 lines
5.3 KiB
Python

import multiprocessing as mp
import pickle
import shelve
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Union
from tqdm import trange
from lib.objects.map import Map
from lib.utils.parallel import run_n_in_parallel
class Generator:
    """Generate trajectories plus alternative routes and store them in shelve files.

    For a reference trajectory a number of alternatives between the same
    endpoints is created, every alternative is labelled by comparing it with
    ``map.are_homotopic`` and the result is written to a ``shelve`` database
    below ``data_root``.
    """

    # Modes accepted by ``generate_alternatives`` / ``write_n_alternatives``.
    possible_modes = ['one_patching']

    def __init__(self, data_root, map_obj, binary=True):
        """Store the map object and the output directory.

        :param data_root: Directory (str or Path) in which all data files are created.
        :param map_obj: Map object used to sample and compare trajectories.
        :param binary: Stored as ``self.binary``; not read by any method of this class.
        """
        self.binary: bool = binary
        self.map: Map = map_obj
        self.data_root = Path(data_root)

    def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
        """Create ``n`` datasets of ``m`` alternatives each in a process pool and merge them.

        Every pool task calls ``write_n_alternatives`` and writes into a shelf
        named ``<datafile_name>_<pid>``. After all tasks finished, these
        per-process shelves are merged into ``<data_root>/<datafile_name>`` and removed.

        :param n: Number of ``write_n_alternatives`` tasks to run.
        :param m: Number of alternatives per task (passed on as ``n``).
        :param datafile_name: Name of the merged file; ``.pik`` is appended if missing.
        :param processes: Pool size; any falsy value selects ``mp.cpu_count()``.
        :param kwargs: Further keyword arguments for ``write_n_alternatives``.
        :raises Exception: Whatever a worker task raised (re-raised by ``AsyncResult.get``).
        """
        datafile_name = str(datafile_name)
        if not datafile_name.endswith('.pik'):
            datafile_name = f'{datafile_name}.pik'
        kwargs.update(datafile_name=datafile_name, n=m)
        processes = processes or mp.cpu_count()
        with mp.Pool(processes) as pool:
            async_results = [pool.apply_async(self.write_n_alternatives, kwds=kwargs) for _ in range(n)]
            # The return values are not needed; get() is called so that an
            # exception raised inside a worker surfaces in the parent process.
            for async_result in async_results:
                async_result.get()

        self._merge_subprocess_files(datafile_name)

    @staticmethod
    def _shelf_base_name(file_name, prefix):
        """Return ``<prefix><pid>`` if ``file_name`` belongs to a per-process shelf, else None.

        dbm backends may append their own suffixes (e.g. ``.dat``, ``.db``) to the
        name given to ``shelve.open``; everything after the digits is ignored.
        """
        if not file_name.startswith(prefix):
            return None
        remainder = file_name[len(prefix):]
        pid_digits = remainder[:len(remainder) - len(remainder.lstrip('0123456789'))]
        return f'{prefix}{pid_digits}' if pid_digits else None

    def _merge_subprocess_files(self, datafile_name):
        """Merge all ``<datafile_name>_<pid>`` shelves into ``<datafile_name>`` and delete them.

        Trajectory records are stored under fresh ``trajectory_<i>`` keys (shelve
        keys must be strings); the ``map`` record is copied only once.
        """
        self.data_root.mkdir(exist_ok=True, parents=True)
        prefix = f'{datafile_name}_'

        # Group on-disk files by the shelf they belong to, because one shelf may
        # consist of several files depending on the dbm backend.
        files_per_shelf = defaultdict(list)
        for path in self.data_root.iterdir():
            base_name = self._shelf_base_name(path.name, prefix)
            if base_name is not None:
                files_per_shelf[base_name].append(path)

        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as merged:
            next_index = len(merged)
            for base_name in sorted(files_per_shelf):
                with shelve.open(str(self.data_root / base_name), flag='r') as sub_shelf:
                    for key in sub_shelf:
                        if key == 'map':
                            if 'map' not in merged:
                                merged['map'] = sub_shelf[key]
                            continue
                        # Never overwrite a record of a previous run.
                        while f'trajectory_{next_index}' in merged:
                            next_index += 1
                        merged[f'trajectory_{next_index}'] = sub_shelf[key]
                        next_index += 1
                for path in files_per_shelf[base_name]:
                    try:
                        path.unlink()
                    except FileNotFoundError:
                        # Already gone; nothing left to clean up for this file.
                        pass

    def generate_alternatives(self, trajectory, mode='one_patching'):
        """Create one alternative trajectory between the endpoints of ``trajectory``.

        In mode ``'one_patching'`` a position from ``map.get_valid_position()`` is
        inserted between start and destination and the map builds the route
        through these three vertices.

        :param trajectory: Object exposing ``endpoints`` as a ``(start, dest)`` pair.
        :param mode: Generation mode; only ``'one_patching'`` is implemented.
        :return: The trajectory returned by ``map.get_trajectory_from_vertices``.
        :raises RuntimeError: If ``mode`` is not implemented.
        """
        start, dest = trajectory.endpoints
        if mode == 'one_patching':
            patch = self.map.get_valid_position()
            return self.map.get_trajectory_from_vertices(start, patch, dest)
        raise RuntimeError('mode checking went wrong...')

    def write_n_alternatives(self, n=None, datafile_name=None, trajectory=None,
                             mode='one_patching', equal_samples=True, binary_check=True):
        """Generate ``n`` alternatives for one trajectory, label them and write them to disk.

        Label ``0`` is the class of the reference trajectory. With
        ``binary_check`` every alternative that matches no existing class gets
        label ``1``; otherwise each such alternative opens a new class.

        :param n: Number of alternatives to generate (required).
        :param datafile_name: Target file name, must end with ``.pik`` (required).
        :param trajectory: Reference trajectory; ``map.get_random_trajectory()`` if None.
        :param mode: One of ``possible_modes``.
        :param equal_samples: Balance class 0 against all other classes via ``_remove_unequal``.
        :param binary_check: Use only the labels 0 and 1.
        :return: ``True`` after writing; ``(None, None)`` if balancing left nothing to write.
        :raises AssertionError: If ``datafile_name``, ``n`` or ``mode`` is invalid.
        """
        assert str(datafile_name).endswith('.pik'), f'datafile_name does not end with .pik but: {datafile_name}'
        assert n is not None, f'n is not allowed to be None but was: {n}'
        assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'

        if trajectory is None:
            trajectory = self.map.get_random_trajectory()
        candidates = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]

        # Label per homotopy class; the first member of a class is its representative.
        homotopy_classes = defaultdict(list)
        homotopy_classes[0].append(trajectory)
        for alternative in candidates:
            for members in homotopy_classes.values():
                if self.map.are_homotopic(members[0], alternative):
                    members.append(alternative)
                    break
            else:
                # No existing class matched.
                label = 1 if binary_check else len(homotopy_classes)
                homotopy_classes[label].append(alternative)

        # There should be as many homotopic samples as non-homotopic samples.
        if equal_samples:
            homotopy_classes = self._remove_unequal(homotopy_classes)
            if not homotopy_classes:
                return None, None

        # Flatten the classes into parallel lists of alternatives and labels.
        # Note: the reference trajectory itself is the first entry of class 0.
        alternatives, labels = [], []
        for label, members in homotopy_classes.items():
            alternatives.extend(members)
            labels.extend([label] * len(members))

        # One shelf per worker process, so that processes never share a file.
        subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
        self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
        return True

    def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
        """Append one record to the shelf ``<data_root>/<datafile_name>``.

        The record is stored under ``trajectory_<len(shelf)>``; the map is stored
        once under the key ``map``.

        :param datafile_name: Name of the shelf inside ``data_root``.
        :param trajectory: Reference trajectory of the record.
        :param alternatives: Alternatives belonging to ``trajectory``.
        :param labels: One label per entry of ``alternatives``.
        """
        self.data_root.mkdir(exist_ok=True, parents=True)
        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
            new_key = len(f)
            f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
                                              trajectory=trajectory,
                                              labels=labels)
            if 'map' not in f:
                f['map'] = dict(map=self.map, name=f'map_{self.map.name}')

    @staticmethod
    def _remove_unequal(hom_dict):
        """Trim the classes with label != 0 until they hold at most ``len(hom_dict[0])`` items in total.

        Items are removed from the end of the classes in round-robin order,
        starting with the highest label. The input mapping and its lists are not
        modified. Classes that end up empty are dropped from the result.

        We argue that there will always be more non-homotopic routes than
        homotopic alternatives; if there are fewer, nothing is removed.
        TODO: Otherwise introduce a second condition / loop

        :param hom_dict: Mapping ``label -> list of trajectories``.
        :return: Balanced copy as ``defaultdict(list)``, or ``None`` if class 0
                 has at most one member.
        """
        if len(hom_dict.get(0, ())) <= 1:
            return None

        # Copy the lists as well, so that trimming does not touch the caller's data.
        balanced = defaultdict(list, {label: list(members) for label, members in hom_dict.items()})
        other_labels = sorted((label for label in balanced if label != 0), reverse=True)
        surplus = sum(len(balanced[label]) for label in other_labels) - len(balanced[0])

        # surplus > 0 implies that at least one class is non-empty, so every
        # pass removes at least one item and the loop terminates.
        while surplus > 0:
            for label in other_labels:
                if surplus <= 0:
                    break
                if balanced[label]:
                    balanced[label].pop()
                    surplus -= 1

        for label in other_labels:
            if not balanced[label]:
                del balanced[label]
        return balanced