2020-02-16 21:00:07 +01:00

125 lines
4.9 KiB
Python

import multiprocessing as mp
import pickle
import shelve
from collections import defaultdict
from pathlib import Path
from typing import Union
from tqdm import trange
from lib.objects.map import Map
from lib.utils.parallel import run_n_in_parallel
class Generator:
    """Generate trajectories on a map together with labelled alternative routes.

    For a given trajectory, alternative routes between the same endpoints are
    created and grouped into classes using ``self.map.are_homotopic``. Class
    ``0`` always holds the original trajectory and every alternative judged
    homotopic to it.
    """

    # Alternative-generation strategies accepted by the ``mode`` parameters.
    possible_modes = ['one_patching']

    def __init__(self, data_root, map_obj, binary=True):
        """Store the map and the directory datasets are written to.

        Args:
            data_root: Directory (str or path-like) for dataset files.
            map_obj: Map object used to sample and compare trajectories.
            binary: Stored as ``self.binary``; not read by any method of
                this class.
        """
        self.binary: bool = binary
        self.map: Map = map_obj
        self.data_root = Path(data_root)

    def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
        """Sample ``n`` random trajectories and build ``m`` alternatives for each.

        Args:
            n: Number of random trajectories to draw from the map.
            m: Number of alternatives requested per trajectory.
            dataset_name: Forwarded to ``generate_n_alternatives``; when
                non-empty each sample is also written to disk.
            **kwargs: Forwarded to ``generate_n_alternatives``.

        Returns:
            A list of dicts with the keys ``trajectory``, ``alternatives`` and
            ``labels``. Trajectories that yielded no usable sample are skipped,
            so the list may be shorter than ``n``.
        """
        trajectories_with_alternatives = []
        for _ in trange(n, desc='Processing Trajectories'):
            trajectory = self.map.get_random_trajectory()
            alternatives, labels = self.generate_n_alternatives(
                trajectory, m, dataset_name=dataset_name, **kwargs
            )
            # Both values must be present. The previous condition
            # `not alternatives or labels` parsed as `(not alternatives) or labels`
            # and therefore discarded every successfully generated sample.
            if not alternatives or not labels:
                continue
            trajectories_with_alternatives.append(
                dict(trajectory=trajectory, alternatives=alternatives, labels=labels)
            )
        return trajectories_with_alternatives

    def generate_alternatives(self, trajectory, output: 'Union[mp.Queue, None]' = None,
                              mode='one_patching'):
        """Create one alternative route between the endpoints of ``trajectory``.

        Args:
            trajectory: Object exposing ``endpoints`` as a ``(start, dest)`` pair.
            output: Optional queue; when given, the alternative is also ``put``
                on it.
            mode: Generation strategy. Only ``'one_patching'`` is implemented:
                the route is requested from the map via start, one position
                from ``self.map.get_valid_position()``, and dest.

        Returns:
            The alternative returned by ``self.map.get_trajectory_from_vertices``.

        Raises:
            RuntimeError: If ``mode`` is not an implemented strategy.
        """
        start, dest = trajectory.endpoints
        if mode != 'one_patching':
            raise RuntimeError('mode checking went wrong...')
        patch = self.map.get_valid_position()
        alternative = self.map.get_trajectory_from_vertices(start, patch, dest)
        # Compare against None explicitly instead of relying on the
        # truthiness of a queue object.
        if output is not None:
            output.put(alternative)
        return alternative

    def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
                                mode='one_patching', equal_samples=True, binary_check=True):
        """Generate ``n`` alternatives for ``trajectory`` and label them by class.

        Args:
            trajectory: Reference trajectory; it becomes the first member of
                class ``0``.
            n: Number of alternatives to request.
            dataset_name: When non-empty, the result is appended to that
                dataset via ``write_to_disk``.
            mode: Generation strategy, must be in ``possible_modes``.
            equal_samples: When true, classes are trimmed by
                ``_remove_unequal`` so the other classes together do not
                outnumber class ``0``.
            binary_check: When true, every alternative that matches no
                existing class gets label ``1``; otherwise each unmatched
                alternative opens a new class.

        Returns:
            ``(alternatives, labels)`` as two parallel lists (the reference
            trajectory itself is included with label ``0``), or
            ``(None, None)`` when balancing left nothing usable.
        """
        assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'

        # NOTE(review): presumably invokes generate_alternatives n times and
        # returns the collected results as a sequence - confirm in lib.utils.parallel.
        results = run_n_in_parallel(self.generate_alternatives, n, trajectory=trajectory, mode=mode)

        # Label per homotopy class; class 0 is anchored on the reference trajectory.
        homotopy_classes = defaultdict(list)
        homotopy_classes[0].append(trajectory)
        for alternative in results:
            # Compare against the first member (representative) of each class.
            for label in homotopy_classes:
                if self.map.are_homotopic(homotopy_classes[label][0], alternative):
                    homotopy_classes[label].append(alternative)
                    break
            else:
                # No existing class matched.
                label = 1 if binary_check else len(homotopy_classes)
                homotopy_classes[label].append(alternative)

        # There should be as many homotopic samples as non-homotopic samples.
        if equal_samples:
            homotopy_classes = self._remove_unequal(homotopy_classes)
            if not homotopy_classes:
                return None, None

        # Compose parallel lists of alternatives and labels.
        alternatives, labels = [], []
        for key, members in homotopy_classes.items():
            alternatives.extend(members)
            labels.extend([key] * len(members))

        if dataset_name:
            self.write_to_disk(dataset_name, trajectory, alternatives, labels)
        return alternatives, labels

    def write_to_disk(self, filepath, trajectory, alternatives, labels):
        """Append one sample to a shelve dataset below ``self.data_root``.

        Args:
            filepath: Dataset name as ``str`` or ``Path``; ``.pik`` is appended
                when missing.
            trajectory: Reference trajectory of the sample.
            alternatives: Alternatives belonging to the trajectory.
            labels: Labels parallel to ``alternatives``.
        """
        # Accept Path objects as well; they have no ``endswith`` method.
        filename = str(filepath)
        dataset_name = filename if filename.endswith('.pik') else f'{filename}.pik'
        self.data_root.mkdir(exist_ok=True, parents=True)
        with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
            # The index is the current number of shelf entries, which includes
            # the 'map' entry once it exists, so the numbering is not
            # contiguous (trajectory_0, trajectory_2, ...). Kept unchanged so
            # appending to existing datasets never reuses a key.
            new_key = len(f)
            f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
                                              trajectory=trajectory,
                                              labels=labels)
            # The map is stored once per dataset.
            if 'map' not in f:
                f['map'] = dict(map=self.map, name=f'map_{self.map.name}')

    @staticmethod
    def _remove_unequal(hom_dict):
        """Trim classes so the other classes together do not outnumber class 0.

        Members are removed one at a time from the end of each non-zero class,
        visiting classes from the highest label to the lowest, until the total
        number of members outside class ``0`` is at most ``len(hom_dict[0])``.
        Class ``0`` itself is never trimmed.

        Args:
            hom_dict: Mapping from integer label to list of members. It is
                not modified; the result holds copies of the lists.

        Returns:
            The balanced mapping (same mapping type as the input, emptied
            classes removed), or ``None`` when class ``0`` has at most one
            member.
        """
        # We argue, that there will always be more non-homotopic routes than
        # homotopic alternatives.
        # TODO: Otherwise introduce a second condition / loop
        if len(hom_dict.get(0, ())) <= 1:
            return None

        # Copy the container and each list so the caller's data stays intact.
        balanced = hom_dict.copy()
        for label in balanced:
            balanced[label] = list(balanced[label])

        other_labels = sorted((label for label in balanced if label != 0), reverse=True)
        # Count over the real keys: indexing range(1, len(...)) missed high
        # labels once a lower class had been deleted.
        surplus = sum(len(balanced[label]) for label in other_labels) - len(balanced[0])
        while surplus > 0:
            for label in other_labels:
                if surplus <= 0:
                    break
                if balanced[label]:
                    balanced[label].pop()
                    surplus -= 1

        for label in other_labels:
            if not balanced[label]:
                del balanced[label]
        return balanced