import multiprocessing as mp
import pickle
import shelve
from collections import defaultdict
from pathlib import Path

from tqdm import tqdm

from lib.objects.map import Map


class Generator:
    """Generates trajectory datasets: for a base trajectory it produces
    alternative routes, labels them by homotopy class, and persists the
    results in shelve datafiles under ``data_root``.
    """

    # Supported alternative-generation modes (validated in generate_n_alternatives).
    possible_modes = ['one_patching']

    def __init__(self, data_root, map_obj, binary=True):
        """
        :param data_root: directory where shelve datafiles are written.
        :param map_obj: map providing trajectories / alternatives / homotopy checks.
        :param binary: stored flag; not read inside this class as shown here.
        """
        self.binary: bool = binary
        self.map: Map = map_obj
        self.data_root = Path(data_root)

    def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
        """Produce ``n`` trajectories with ``m`` alternatives each, in parallel,
        then merge any per-subprocess shelve files into one datafile.

        :param n: number of trajectories to generate.
        :param m: number of alternatives per trajectory (forwarded as ``n`` to
            :meth:`generate_n_alternatives`).
        :param datafile_name: target shelve file name ('.pik' appended if missing).
        :param processes: worker count; defaults to cpu_count() - 1.
        :param kwargs: forwarded to :meth:`generate_n_alternatives`.
        """
        datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik'
        kwargs.update(n=m)
        processes = processes if processes else mp.cpu_count() - 1
        mutex = mp.Lock()
        with mp.Pool(processes) as pool:
            async_results = [pool.apply_async(self.generate_n_alternatives, kwds=kwargs) for _ in range(n)]
            for result_obj in tqdm(async_results, total=n, desc='Producing trajectories with Alternatives'):
                trajectory, alternatives, labels = result_obj.get()
                # with-statement guarantees the lock is released even if
                # write_to_disk raises (the original acquire/release leaked it).
                with mutex:
                    self.write_to_disk(datafile_name, trajectory, alternatives, labels)

        # Merge per-subprocess shelve files (named '<datafile_name>_<pid>' by
        # generate_n_alternatives) into the main datafile.
        # BUG FIX: the original pattern was the literal 'datafile_name*'
        # (missing f-string braces) and therefore never matched anything.
        with shelve.open(str(self.data_root / datafile_name)) as f:
            for datafile in self.data_root.glob(f'{datafile_name}_*'):
                with shelve.open(str(datafile)) as sub_f:
                    for key in sub_f:
                        if key == 'map':
                            # The main shelf already stores the map entry.
                            continue
                        # BUG FIX: shelve keys must be strings; the original
                        # used the int len(f). Reuse write_to_disk's key scheme.
                        f[f'trajectory_{len(f)}'] = sub_f[key]
                datafile.unlink()

    def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None, is_sub_process=False,
                                mode='one_patching', equal_samples=True, binary_check=True):
        """Generate ``n`` alternatives for a trajectory and label them by
        homotopy class.

        :param n: number of alternatives to generate (required).
        :param datafile_name: if non-empty, results are also written to disk.
        :param trajectory: base trajectory; drawn randomly from the map if None.
        :param is_sub_process: append the worker PID to the datafile name so
            parallel workers do not write to the same shelve file.
        :param mode: generation mode; must be one of ``possible_modes``.
        :param equal_samples: balance homotopic vs. non-homotopic sample counts.
        :param binary_check: if True all non-homotopic alternatives share label 1;
            otherwise every new homotopy class gets its own label.
        :return: ``(trajectory, alternatives, labels)``; ``(None, None, None)``
            when balancing fails.
        """
        assert n is not None, f'n is not allowed to be None but was: {n}'
        assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'

        trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
        results = [self.map.generate_alternative(trajectory=trajectory, mode=mode) for _ in range(n)]

        # Label per homotopy class: class 0 holds the base trajectory itself.
        homotopy_classes = defaultdict(list)
        homotopy_classes[0].append(trajectory)
        for alternative in results:
            # Compare against one representative of every known class.
            for label in homotopy_classes.keys():
                if self.map.are_homotopic(homotopy_classes[label][0], alternative):
                    homotopy_classes[label].append(alternative)
                    break
            else:
                # No matching class: open a new one (binary mode collapses all
                # non-homotopic alternatives into class 1).
                label = 1 if binary_check else len(homotopy_classes)
                homotopy_classes[label].append(alternative)

        # There should be as many homotopic samples as non-homotopic samples.
        if equal_samples:
            homotopy_classes = self._remove_unequal(homotopy_classes)
            if not homotopy_classes:
                return None, None, None

        # Flatten the class dict into parallel lists of alternatives and labels.
        alternatives, labels = list(), list()
        for key in homotopy_classes.keys():
            alternatives.extend(homotopy_classes[key])
            labels.extend([key] * len(homotopy_classes[key]))

        if datafile_name:
            if is_sub_process:
                # Per-worker file name prevents concurrent shelve corruption.
                datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
            self.write_to_disk(datafile_name, trajectory, alternatives, labels)
        return trajectory, alternatives, labels

    def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
        """Append one (trajectory, alternatives, labels) record to a shelve
        file under ``data_root``; the map object is stored once under 'map'.
        """
        self.data_root.mkdir(exist_ok=True, parents=True)
        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
            # NOTE: len(f) also counts the 'map' entry, so trajectory keys may
            # be gapped (trajectory_0, trajectory_2, ...) but stay unique.
            new_key = len(f)
            f[f'trajectory_{new_key}'] = dict(alternatives=alternatives, trajectory=trajectory, labels=labels)
            if 'map' not in f:
                f['map'] = dict(map=self.map, name=self.map.name)

    @staticmethod
    def _remove_unequal(hom_dict):
        """Trim the non-homotopic classes (ids >= 1) until they no longer
        outnumber the homotopic class 0.

        We argue that there will always be more non-homotopic routes than
        homotopic alternatives.
        TODO: Otherwise introduce a second condition / loop.

        :return: balanced copy of ``hom_dict``, or None when class 0 has at
            most one sample and cannot be balanced against.
        """
        # BUG FIX: the original shallow .copy() shared the inner lists with the
        # caller, so trimming mutated the caller's data. Copy the lists too,
        # but keep the defaultdict(list) behaviour: the while-condition below
        # may probe class ids that were already deleted.
        hom_dict = defaultdict(list, {k: list(v) for k, v in hom_dict.items()})
        if len(hom_dict[0]) <= 1:
            return None
        counter = len(hom_dict)
        # Round-robin over the non-homotopic classes, dropping one sample at a
        # time (and removing emptied classes) until balanced.
        while sum(len(hom_dict[class_id]) for class_id in range(1, len(hom_dict))) > len(hom_dict[0]):
            if counter == 0:
                counter = len(hom_dict)
            if counter in hom_dict:
                if len(hom_dict[counter]) == 0:
                    del hom_dict[counter]
                else:
                    del hom_dict[counter][-1]
            counter -= 1
        return hom_dict