Steffen Illium 91ecf157d6 initial
2020-02-13 20:28:20 +01:00

122 lines
4.9 KiB
Python

import multiprocessing as mp
import pickle
import shelve
from collections import defaultdict
from pathlib import Path
from typing import Union
from tqdm import trange
from lib.objects.map import Map
class Generator:
    """Generate alternative trajectories on a map and label them by homotopy class.

    All geometric work is delegated to the map object passed to the
    constructor; this class only orchestrates sampling, grouping of the
    samples into homotopy classes, optional balancing and persistence.
    """

    # Strategies accepted by ``generate_alternatives`` / ``generate_n_alternatives``.
    possible_modes = ['one_patching']

    def __init__(self, data_root, map_obj, binary=True):
        """Store the configuration.

        Args:
            data_root: Directory (``str`` or ``Path``) that datasets are written to.
            map_obj: Map object used to sample and compare trajectories.
            binary: Stored on the instance; not read by any method of this class.
        """
        self.binary: bool = binary
        self.map: Map = map_obj
        self.data_root = Path(data_root)

    def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
        """Draw ``n`` random trajectories and generate ``m`` alternatives for each.

        Args:
            n: Number of trajectories requested from ``map.get_random_trajectory()``.
            m: Number of alternatives generated per trajectory.
            dataset_name: Forwarded to ``generate_n_alternatives``; when non-empty
                every result is also written to disk.
            **kwargs: Forwarded unchanged to ``generate_n_alternatives``.

        Returns:
            A list with one dict per trajectory holding the keys
            ``trajectory``, ``alternatives`` and ``labels``.
        """
        trajectories_with_alternatives = []
        for _ in trange(n, desc='Processing Trajectories'):
            trajectory = self.map.get_random_trajectory()
            alternatives, labels = self.generate_n_alternatives(
                trajectory, m, dataset_name=dataset_name, **kwargs)
            trajectories_with_alternatives.append(
                dict(trajectory=trajectory, alternatives=alternatives, labels=labels))
        return trajectories_with_alternatives

    def generate_alternatives(self, trajectory, output: Union[mp.Queue, None] = None, mode='one_patching'):
        """Build a single alternative that shares the endpoints of ``trajectory``.

        Args:
            trajectory: Object exposing ``endpoints`` as a ``(start, dest)`` pair.
            output: Optional queue; when given, the alternative is also put on it
                (this is how worker processes hand their result back).
            mode: Generation strategy; only ``'one_patching'`` is implemented.

        Returns:
            The alternative trajectory produced by the map.

        Raises:
            RuntimeError: If ``mode`` is not an implemented strategy.
        """
        start, dest = trajectory.endpoints
        if mode == 'one_patching':
            # Route start -> patch -> dest, where patch comes from map.get_valid_position().
            patch = self.map.get_valid_position()
            alternative = self.map.get_trajectory_from_vertices(start, patch, dest)
        else:
            raise RuntimeError('mode checking went wrong...')
        # Explicit None check instead of relying on the truthiness of a queue object.
        if output is not None:
            output.put(alternative)
        return alternative

    def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
                                mode='one_patching', equal_samples=True):
        """Generate ``n`` alternatives in parallel and label them by homotopy class.

        Label ``0`` is the class of ``trajectory`` itself; every alternative that
        is not homotopic to the first member of an existing class opens a new
        class with the next free integer label.

        Args:
            trajectory: Reference trajectory.
            n: Number of alternatives (one worker process is started per alternative).
            dataset_name: When non-empty, the result is written via ``write_to_disk``.
            mode: Generation strategy, must be one of ``possible_modes``.
            equal_samples: When true, the classes are balanced with ``_remove_unequal``.

        Returns:
            A tuple ``(alternatives, labels)`` of two parallel flat lists.
            The reference trajectory is the first entry, labelled ``0``.

        Raises:
            AssertionError: If ``mode`` is not one of ``possible_modes``.
        """
        # Raised explicitly (rather than via ``assert``) so the check survives ``python -O``.
        if mode not in self.possible_modes:
            raise AssertionError(
                f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.')

        # Queue the worker processes report their result to.
        output = mp.Queue()
        processes = [mp.Process(target=self.generate_alternatives,
                                kwargs=dict(trajectory=trajectory, output=output, mode=mode))
                     for _ in range(n)]
        for process in processes:
            process.start()
        # Drain the queue BEFORE joining: a process that has put items on a queue
        # does not terminate until they are flushed to the pipe, so joining first
        # can deadlock.
        results = [output.get() for _ in processes]
        for process in processes:
            process.join()

        # Group the results; the first member of each class is its representative.
        homotopy_classes = defaultdict(list)
        homotopy_classes[0].append(trajectory)
        for alternative in results:
            for members in homotopy_classes.values():
                if self.map.are_homotopic(members[0], alternative):
                    members.append(alternative)
                    break
            else:
                # No existing class matched: open a new one.
                homotopy_classes[len(homotopy_classes)].append(alternative)

        # There should be as many homotopic samples as non-homotopic samples.
        if equal_samples:
            homotopy_classes = self._remove_unequal(homotopy_classes)

        # Flatten into two parallel lists (one label per alternative).
        alternatives, labels = [], []
        for label, members in homotopy_classes.items():
            alternatives.extend(members)
            labels.extend([label] * len(members))

        if dataset_name:
            self.write_to_disk(dataset_name, trajectory, alternatives, labels)
        return alternatives, labels

    def write_to_disk(self, filepath, trajectory, alternatives, labels):
        """Append one labelled trajectory record to a shelve file below ``data_root``.

        Args:
            filepath: Dataset file name (``str`` or ``Path``); ``.pik`` is appended
                when missing.
            trajectory: Reference trajectory of the record.
            alternatives: Alternatives belonging to ``trajectory``.
            labels: Labels belonging to ``alternatives``.
        """
        # Accept Path objects as well; ``Path`` has no ``endswith``.
        filepath = str(filepath)
        dataset_name = filepath if filepath.endswith('.pik') else f'{filepath}.pik'
        self.data_root.mkdir(exist_ok=True, parents=True)
        with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
            # NOTE: len(f) also counts the 'map' entry, so the indices are unique
            # but not contiguous. Kept as is to stay compatible with existing files.
            new_key = len(f)
            f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
                                              trajectory=trajectory,
                                              labels=labels)
            # The map is stored once per dataset file.
            if 'map' not in f:
                f['map'] = dict(map=self.map, name=f'map_{self.map.name}')

    @staticmethod
    def _remove_unequal(hom_dict):
        """Trim non-homotopic classes so they do not outnumber class ``0``.

        Samples are removed one at a time from the end of the classes with a
        label other than ``0``, cycling from the highest label downwards, until
        the total number of such samples is at most ``len(hom_dict[0])``.
        Class ``0`` is never modified and classes left empty are dropped.

        Args:
            hom_dict: Mapping ``label -> list of samples``; it is not modified.

        Returns:
            A new ``defaultdict(list)`` holding the balanced classes.
        """
        # Copy the lists too, so trimming never leaks into the caller's data.
        balanced = defaultdict(list, {label: list(members) for label, members in hom_dict.items()})
        other_labels = sorted((label for label in balanced if label != 0), reverse=True)
        surplus = sum(len(balanced[label]) for label in other_labels) - len(balanced.get(0, ()))

        # While surplus > 0 at least one other class is non-empty, so each pass
        # removes at least one sample and the loop terminates.
        while surplus > 0:
            for label in other_labels:
                if surplus <= 0:
                    break
                if balanced[label]:
                    balanced[label].pop()
                    surplus -= 1

        for label in other_labels:
            if not balanced[label]:
                del balanced[label]
        return balanced