Refactoring
This commit is contained in:
117
datasets/preprocessing/generator.py
Normal file
117
datasets/preprocessing/generator.py
Normal file
@ -0,0 +1,117 @@
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
import shelve
|
||||
from collections import defaultdict
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from lib.objects.map import Map
|
||||
|
||||
|
||||
class Generator:
|
||||
|
||||
possible_modes = ['one_patching']
|
||||
|
||||
def __init__(self, data_root, map_obj, binary=True):
|
||||
self.binary: bool = binary
|
||||
self.map: Map = map_obj
|
||||
|
||||
self.data_root = Path(data_root)
|
||||
|
||||
|
||||
|
||||
def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
|
||||
datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik'
|
||||
kwargs.update(n=m)
|
||||
processes = processes if processes else mp.cpu_count() - 1
|
||||
mutex = mp.Lock()
|
||||
with mp.Pool(processes) as pool:
|
||||
async_results = [pool.apply_async(self.generate_n_alternatives, kwds=kwargs) for _ in range(n)]
|
||||
|
||||
for result_obj in tqdm(async_results, total=n, desc='Producing trajectories with Alternatives'):
|
||||
trajectory, alternatives, labels = result_obj.get()
|
||||
mutex.acquire()
|
||||
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
|
||||
mutex.release()
|
||||
|
||||
with shelve.open(str(self.data_root / datafile_name)) as f:
|
||||
for datafile in self.data_root.glob(f'datafile_name*'):
|
||||
with shelve.open(str(datafile)) as sub_f:
|
||||
for key in sub_f.keys():
|
||||
f[len(f)] = sub_f[key]
|
||||
datafile.unlink()
|
||||
pass
|
||||
|
||||
def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None, is_sub_process=False,
|
||||
mode='one_patching', equal_samples=True, binary_check=True):
|
||||
assert n is not None, f'n is not allowed to be None but was: {n}'
|
||||
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
||||
|
||||
trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
|
||||
|
||||
results = [self.map.generate_alternative(trajectory=trajectory, mode=mode) for _ in range(n)]
|
||||
|
||||
# label per homotopic class
|
||||
homotopy_classes = defaultdict(list)
|
||||
homotopy_classes[0].append(trajectory)
|
||||
for i in range(len(results)):
|
||||
alternative = results[i]
|
||||
class_not_found = True
|
||||
# check for homotopy class
|
||||
for label in homotopy_classes.keys():
|
||||
if self.map.are_homotopic(homotopy_classes[label][0], alternative):
|
||||
homotopy_classes[label].append(alternative)
|
||||
class_not_found = False
|
||||
break
|
||||
if class_not_found:
|
||||
label = 1 if binary_check else len(homotopy_classes)
|
||||
homotopy_classes[label].append(alternative)
|
||||
|
||||
# There should be as much homotopic samples as non-homotopic samples
|
||||
if equal_samples:
|
||||
homotopy_classes = self._remove_unequal(homotopy_classes)
|
||||
if not homotopy_classes:
|
||||
return None, None, None
|
||||
|
||||
# Compose lists of alternatives with labels
|
||||
alternatives, labels = list(), list()
|
||||
for key in homotopy_classes.keys():
|
||||
alternatives.extend(homotopy_classes[key])
|
||||
labels.extend([key] * len(homotopy_classes[key]))
|
||||
if datafile_name:
|
||||
if is_sub_process:
|
||||
datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
|
||||
# Write to disk
|
||||
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
|
||||
return trajectory, alternatives, labels
|
||||
|
||||
def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
|
||||
self.data_root.mkdir(exist_ok=True, parents=True)
|
||||
with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
|
||||
new_key = len(f)
|
||||
f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
|
||||
trajectory=trajectory,
|
||||
labels=labels)
|
||||
if 'map' not in f:
|
||||
f['map'] = dict(map=self.map, name=self.map.name)
|
||||
|
||||
@staticmethod
|
||||
def _remove_unequal(hom_dict):
|
||||
# We argue, that there will always be more non-homotopic routes than homotopic alternatives.
|
||||
# TODO: Otherwise introduce a second condition / loop
|
||||
hom_dict = hom_dict.copy()
|
||||
if len(hom_dict[0]) <= 1:
|
||||
return None
|
||||
counter = len(hom_dict)
|
||||
while sum([len(hom_dict[class_id]) for class_id in range(1, len(hom_dict))]) > len(hom_dict[0]):
|
||||
if counter == 0:
|
||||
counter = len(hom_dict)
|
||||
if counter in hom_dict:
|
||||
if len(hom_dict[counter]) == 0:
|
||||
del hom_dict[counter]
|
||||
else:
|
||||
del hom_dict[counter][-1]
|
||||
counter -= 1
|
||||
return hom_dict
|
Reference in New Issue
Block a user