122 lines
4.9 KiB
Python
122 lines
4.9 KiB
Python
import multiprocessing as mp
|
|
import pickle
|
|
import shelve
|
|
from collections import defaultdict
|
|
|
|
from pathlib import Path
|
|
from typing import Union
|
|
|
|
from tqdm import trange
|
|
|
|
from lib.objects.map import Map
|
|
|
|
|
|
class Generator:
    """Samples trajectories from a map and builds homotopy-labelled alternatives for them."""

    # Generation modes understood by ``generate_alternatives``.
    possible_modes = ['one_patching']

    def __init__(self, data_root, map_obj, binary=True):
        """
        :param data_root: Directory under which ``write_to_disk`` creates shelve datasets.
        :param map_obj: Map used to sample trajectories/positions and to decide homotopy.
        :param binary: Stored as ``self.binary``; not read by any method of this class.
        """
        self.binary: bool = binary
        self.map: Map = map_obj
        self.data_root = Path(data_root)

    def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
        """Sample ``n`` random trajectories and generate ``m`` alternatives for each of them.

        :param n: Number of random trajectories drawn from the map.
        :param m: Number of alternatives generated per trajectory.
        :param dataset_name: If non-empty, every result is appended to this dataset on disk.
        :param kwargs: Forwarded to ``generate_n_alternatives`` (``mode``, ``equal_samples``).
        :return: List of dicts with keys ``trajectory``, ``alternatives`` and ``labels``.
        """
        trajectories_with_alternatives = []
        for _ in trange(n, desc='Processing Trajectories'):
            trajectory = self.map.get_random_trajectory()
            alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
            trajectories_with_alternatives.append(dict(trajectory=trajectory, alternatives=alternatives, labels=labels))
        return trajectories_with_alternatives

    def generate_alternatives(self, trajectory, output: Union[mp.Queue, None] = None, mode='one_patching'):
        """Build one alternative trajectory sharing the endpoints of ``trajectory``.

        In ``'one_patching'`` mode a valid position is sampled from the map and the alternative
        is routed ``start -> patch -> dest``.

        :param trajectory: Reference trajectory; only its ``endpoints`` are used.
        :param output: Optional queue the alternative is additionally put on (worker-process use).
        :param mode: Generation mode; only ``'one_patching'`` is implemented.
        :return: The generated alternative.
        :raises RuntimeError: If ``mode`` is not a known mode.
        """
        start, dest = trajectory.endpoints
        if mode == 'one_patching':
            patch = self.map.get_valid_position()
            alternative = self.map.get_trajectory_from_vertices(start, patch, dest)
        else:
            raise RuntimeError('mode checking went wrong...')

        # Explicit None check: "was a queue given" must not depend on the queue's truthiness.
        if output is not None:
            output.put(alternative)
        return alternative

    def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
                                mode='one_patching', equal_samples=True):
        """Generate ``n`` alternatives in parallel and group them by homotopy class.

        Each alternative is produced in its own process. Label 0 is the class of ``trajectory``
        itself, which is included in the result as its first element; an alternative that is not
        homotopic to the first member of any existing class opens a new label.

        :param trajectory: Reference trajectory.
        :param n: Number of alternatives (and worker processes) to generate.
        :param dataset_name: If non-empty, the result is appended to this dataset via ``write_to_disk``.
        :param mode: Generation mode; must be one of ``possible_modes``.
        :param equal_samples: If True, trim non-homotopic classes with ``_remove_unequal``.
        :return: ``(alternatives, labels)`` -- two flat lists where ``labels[i]`` is the class of
            ``alternatives[i]``.
        """
        assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'

        # NOTE(review): if Map samples positions from an RNG whose state is inherited by forked
        # children (e.g. numpy's global generator), workers may produce identical alternatives -- confirm.
        output = mp.Queue()
        processes = [mp.Process(target=self.generate_alternatives,
                                kwargs=dict(trajectory=trajectory, output=output, mode=mode))
                     for _ in range(n)]
        for p in processes:
            p.start()
        # Drain the queue BEFORE joining: a child that has put an item does not exit until the
        # item is flushed to the pipe, so joining first can deadlock.
        results = [output.get() for _ in processes]
        for p in processes:
            p.join()

        # Group by homotopy class; each class is represented by its first member.
        homotopy_classes = defaultdict(list)
        homotopy_classes[0].append(trajectory)
        for alternative in results:
            for members in homotopy_classes.values():
                if self.map.are_homotopic(members[0], alternative):
                    members.append(alternative)
                    break
            else:
                homotopy_classes[len(homotopy_classes)].append(alternative)

        # There should not be more non-homotopic samples than homotopic ones.
        if equal_samples:
            homotopy_classes = self._remove_unequal(homotopy_classes)

        # Flatten into index-aligned lists of samples and labels.
        alternatives, labels = [], []
        for label, members in homotopy_classes.items():
            alternatives.extend(members)
            labels.extend([label] * len(members))

        if dataset_name:
            self.write_to_disk(dataset_name, trajectory, alternatives, labels)
        return alternatives, labels

    def write_to_disk(self, filepath, trajectory, alternatives, labels):
        """Append one trajectory record to the shelve dataset ``data_root / filepath``.

        A ``.pik`` suffix is added when missing. Records are stored under ``trajectory_<k>``; the
        map is stored once under ``'map'``.

        :param filepath: Dataset name or path (``str`` or ``Path``), relative to ``data_root``.
        :param trajectory: Reference trajectory.
        :param alternatives: Alternatives belonging to ``trajectory``.
        :param labels: Homotopy labels aligned with ``alternatives``.
        """
        dataset_name = str(filepath)
        if not dataset_name.endswith('.pik'):
            dataset_name = f'{dataset_name}.pik'
        self.data_root.mkdir(exist_ok=True, parents=True)
        with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
            # Do not count the 'map' entry, so records are numbered 0, 1, 2, ...; step past any
            # index already taken (e.g. in datasets written with gaps) instead of overwriting it.
            new_key = len(f) - (1 if 'map' in f else 0)
            while f'trajectory_{new_key}' in f:
                new_key += 1
            f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
                                              trajectory=trajectory,
                                              labels=labels)
            if 'map' not in f:
                f['map'] = dict(map=self.map, name=f'map_{self.map.name}')

    @staticmethod
    def _remove_unequal(hom_dict):
        """Trim non-homotopic classes so that together they hold no more samples than class 0.

        Classes other than label 0 lose their last element in round-robin order, starting at the
        highest label and cycling downwards, until their combined size is at most the size of
        class 0. Classes left empty are dropped. Class 0 is never trimmed, so when it is the
        larger side the result stays unbalanced.

        :param hom_dict: Mapping ``label -> list of samples``; neither it nor its lists are modified.
        :return: A new ``defaultdict(list)`` with the trimmed classes, in the input's key order.
        """
        balanced = defaultdict(list, {label: list(members) for label, members in hom_dict.items()})
        budget = len(balanced[0]) if 0 in balanced else 0
        other_labels = sorted((label for label in balanced if label != 0), reverse=True)
        surplus = sum(len(balanced[label]) for label in other_labels) - budget
        # surplus > 0 implies some non-homotopic class is non-empty, so every pass removes at
        # least one element and the loop terminates.
        while surplus > 0:
            for label in other_labels:
                if surplus <= 0:
                    break
                if balanced[label]:
                    balanced[label].pop()
                    surplus -= 1
        for label in other_labels:
            if not balanced[label]:
                del balanced[label]
        return balanced
|