2020-02-21 09:44:09 +01:00

116 lines
4.8 KiB
Python

import multiprocessing as mp
import pickle
import shelve
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm
from lib.objects.map import Map
class Generator:
    """Generate trajectories on a map together with alternatives labelled by homotopy class.

    The source trajectory always belongs to class ``0``; alternatives that are not
    homotopic to it get other labels. Results can be persisted to ``shelve`` files
    below ``data_root``.
    """

    possible_modes = ['one_patching']
    # Filename suffixes that some ``dbm`` backends append to a shelf's base name
    # (dbm.ndbm may add '.db'; dbm.dumb uses '.dat', '.dir' and '.bak').
    _dbm_suffixes = ('.db', '.dat', '.dir', '.bak')

    def __init__(self, data_root, map_obj, binary=True):
        """
        :param data_root: directory in which shelve data files are stored.
        :param map_obj: map object used to draw trajectories, alternatives and homotopy checks.
        :param binary: stored as ``self.binary``; not read by this class itself.
        """
        self.binary: bool = binary
        self.map: Map = map_obj
        self.data_root = Path(data_root)

    def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
        """Generate ``n`` random trajectories with ``m`` alternatives each, using a process pool.

        Every usable result is appended to the shelf ``data_root / datafile_name``
        (``.pik`` is appended to the name if missing). Afterwards, per-process shelves
        named ``<datafile_name>_<suffix>`` are merged into it and deleted.

        :param n: number of trajectories to generate.
        :param m: number of alternatives generated per trajectory.
        :param datafile_name: name of the target shelf (str or path-like).
        :param processes: number of worker processes; ``0`` means ``cpu_count() - 1`` (at least 1).
        :param kwargs: forwarded to :meth:`generate_n_alternatives`; its ``n`` is set to ``m``.
        """
        datafile_name = str(datafile_name)
        if not datafile_name.endswith('.pik'):
            datafile_name = f'{datafile_name}.pik'
        kwargs.update(n=m)
        # Pool(0) raises ValueError, so never request fewer than one worker.
        processes = processes or max(1, mp.cpu_count() - 1)
        # NOTE(review): with the 'fork' start method all workers inherit the same RNG
        # state; if Map.get_random_trajectory relies on a global RNG the workers may
        # produce duplicate trajectories -- confirm against the Map implementation.
        with mp.Pool(processes) as pool:
            async_results = [pool.apply_async(self.generate_n_alternatives, kwds=kwargs) for _ in range(n)]
            for result_obj in tqdm(async_results, total=n, desc='Producing trajectories with Alternatives'):
                trajectory, alternatives, labels = result_obj.get()
                # (None, None, None) signals that the sample could not be balanced;
                # writing it would store an empty record.
                if alternatives is None:
                    continue
                # Results are consumed sequentially in this parent process, so the
                # writes cannot race and need no lock.
                self.write_to_disk(datafile_name, trajectory, alternatives, labels)
        self._merge_sub_files(datafile_name)

    def _merge_sub_files(self, datafile_name):
        """Merge shelves named ``<datafile_name>_<suffix>`` into ``datafile_name``, then delete them."""
        if not self.data_root.is_dir():
            return
        # The trailing underscore keeps the main shelf itself out of the match.
        prefix = f'{datafile_name}_'
        sub_paths = sorted(p for p in self.data_root.iterdir() if p.name.startswith(prefix))
        if not sub_paths:
            return
        # One logical shelf may span several files, so open each one once via its base name.
        shelf_bases = sorted({p.with_suffix('') if p.suffix in self._dbm_suffixes else p
                              for p in sub_paths})
        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
            for base in shelf_bases:
                with shelve.open(str(base), flag='r') as sub_f:
                    for key in sub_f.keys():
                        if key == 'map':
                            if 'map' not in f:
                                f['map'] = sub_f['map']
                        else:
                            # Shelve keys must be strings; renumber to the next free slot.
                            f[self._next_trajectory_key(f)] = sub_f[key]
        for path in sub_paths:
            path.unlink()

    def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None, is_sub_process=False,
                                mode='one_patching', equal_samples=True, binary_check=True):
        """Generate ``n`` alternatives for one trajectory and label them by homotopy class.

        The source trajectory itself is always the first member of class ``0``. With
        ``binary_check`` every alternative not homotopic to it gets label ``1``;
        otherwise each newly found homotopy class gets the next integer label.

        :param n: number of alternatives to generate (required).
        :param datafile_name: if non-empty, the result is also written to this shelf.
        :param trajectory: source trajectory; drawn via ``self.map.get_random_trajectory()`` if ``None``.
        :param is_sub_process: append ``_<pid>`` to ``datafile_name`` so parallel workers write
            to separate shelves.
        :param mode: generation mode passed to the map; must be in ``possible_modes``.
        :param equal_samples: balance homotopic and non-homotopic samples via :meth:`_remove_unequal`.
        :param binary_check: use 0/1 labels instead of one label per homotopy class.
        :returns: ``(trajectory, alternatives, labels)``, or ``(None, None, None)`` if balancing
            leaves no usable sample.
        """
        assert n is not None, f'n is not allowed to be None but was: {n}'
        assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
        trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
        generated = [self.map.generate_alternative(trajectory=trajectory, mode=mode) for _ in range(n)]

        # label -> members; the first member of each class serves as its representative.
        homotopy_classes = defaultdict(list)
        homotopy_classes[0].append(trajectory)
        for alternative in generated:
            # In binary mode only class 0 decides the label: anything not homotopic to the
            # source ends up in class 1 regardless, so further checks would be wasted.
            candidates = [0] if binary_check else list(homotopy_classes)
            label = next((key for key in candidates
                          if self.map.are_homotopic(homotopy_classes[key][0], alternative)), None)
            if label is None:
                label = 1 if binary_check else len(homotopy_classes)
            homotopy_classes[label].append(alternative)

        # There should be as many homotopic samples as non-homotopic samples.
        if equal_samples:
            homotopy_classes = self._remove_unequal(homotopy_classes)
            if not homotopy_classes:
                return None, None, None

        # Flatten the classes into parallel lists of alternatives and labels.
        alternatives, labels = [], []
        for label, members in homotopy_classes.items():
            alternatives.extend(members)
            labels.extend([label] * len(members))

        if datafile_name:
            if is_sub_process:
                datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
            self.write_to_disk(datafile_name, trajectory, alternatives, labels)
        return trajectory, alternatives, labels

    def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
        """Append one record to the shelf ``data_root / datafile_name``.

        The record is a dict with ``alternatives``, ``trajectory`` and ``labels`` stored under
        the next free ``trajectory_<i>`` key; the map is stored once under ``'map'``.
        """
        self.data_root.mkdir(exist_ok=True, parents=True)
        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
            f[self._next_trajectory_key(f)] = dict(alternatives=alternatives,
                                                   trajectory=trajectory,
                                                   labels=labels)
            if 'map' not in f:
                f['map'] = dict(map=self.map, name=self.map.name)

    @staticmethod
    def _next_trajectory_key(shelf):
        """Return the first unused ``trajectory_<i>`` key, not counting the ``'map'`` entry."""
        index = len(shelf) - (1 if 'map' in shelf else 0)
        # Older files may have gaps in their numbering; never overwrite an existing record.
        while f'trajectory_{index}' in shelf:
            index += 1
        return f'trajectory_{index}'

    @staticmethod
    def _remove_unequal(hom_dict):
        """Trim non-homotopic samples so there are no more of them than homotopic ones.

        Samples are removed round-robin from the highest label down to label 1, always
        dropping a class's last member; classes that become empty are removed. The input
        mapping and its lists are left untouched.

        :param hom_dict: mapping label -> list of members; label ``0`` is the homotopic class.
        :returns: a new balanced ``dict``, or ``None`` if class ``0`` has at most one member.
        """
        # We argue that there will always be more non-homotopic routes than homotopic alternatives.
        # TODO: Otherwise introduce a second condition / loop
        homotopic = hom_dict.get(0, [])
        if len(homotopic) <= 1:
            return None
        # Copy every list so trimming never mutates the caller's data; drop empty classes.
        balanced = {label: list(members) for label, members in hom_dict.items() if members or label == 0}
        excess = sum(len(members) for label, members in balanced.items() if label != 0) - len(homotopic)
        while excess > 0:
            for label in sorted((key for key in balanced if key != 0), reverse=True):
                if excess <= 0:
                    break
                members = balanced[label]
                members.pop()
                excess -= 1
                if not members:
                    del balanced[label]
        return balanced