Switched to Pooling and local aggregation
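The change replaces the serial generate-and-collect loop with a multiprocessing.Pool: each pooled call builds its alternatives and shelves them into a per-process datafile tagged with the worker's pid, and the parent process afterwards merges those partial shelves into one datafile and deletes them ("local aggregation"). The sketch below illustrates that pattern in isolation; it is not the repository's code, and names such as write_samples, DATA_ROOT and samples.pik are placeholders.

import multiprocessing as mp
import shelve
from pathlib import Path

DATA_ROOT = Path('data')
DATAFILE = 'samples.pik'

def write_samples(n, datafile_name=DATAFILE):
    # Every worker appends its n results to a shelve file of its own,
    # tagged with the worker pid, so no two processes share a file.
    worker_file = f'{datafile_name}_{mp.current_process().pid}'
    with shelve.open(str(DATA_ROOT / worker_file)) as f:
        for _ in range(n):
            f[f'sample_{len(f)}'] = dict(payload=None)  # stand-in for the real work
    return True

if __name__ == '__main__':
    DATA_ROOT.mkdir(exist_ok=True, parents=True)
    with mp.Pool(mp.cpu_count()) as pool:
        async_results = [pool.apply_async(write_samples, (10,)) for _ in range(8)]
        _ = [r.get() for r in async_results]  # re-raises any worker exception here
    # Local aggregation: merge the per-process shelves into one datafile, then drop them.
    # Assumes a dbm backend that stores the shelf under the exact filename (e.g. gdbm);
    # backends that append their own suffix (.db, .dat) would need that suffix stripped.
    with shelve.open(str(DATA_ROOT / DATAFILE)) as merged:
        for part in sorted(DATA_ROOT.glob(f'{DATAFILE}_*')):
            with shelve.open(str(part)) as sub_f:
                for key in sub_f:
                    merged[f'sample_{len(merged)}'] = sub_f[key]
            part.unlink()

The diff below applies the same idea inside Generator.generate_n_trajectories_m_alternatives, write_n_alternatives and write_to_disk.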
@@ -2,6 +2,7 @@ import multiprocessing as mp
 import pickle
 import shelve
 from collections import defaultdict
+from functools import partial

 from pathlib import Path
 from typing import Union
@@ -11,6 +12,7 @@ from tqdm import trange
 from lib.objects.map import Map
 from lib.utils.parallel import run_n_in_parallel


 class Generator:

     possible_modes = ['one_patching']
@@ -21,21 +23,28 @@ class Generator:

         self.data_root = Path(data_root)

-    def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
-        trajectories_with_alternatives = list()
-        for _ in trange(n, desc='Processing Trajectories'):
-            trajectory = self.map.get_random_trajectory()
-            alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
-            if not alternatives or labels:
-                continue
-            else:
-                trajectories_with_alternatives.append(
-                    dict(trajectory=trajectory, alternatives=alternatives, labels=labels)
-                )
-        return trajectories_with_alternatives
+    def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
+        datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik'
+        kwargs.update(datafile_name=datafile_name, n=m)
+        processes = processes if processes else mp.cpu_count()
+        with mp.Pool(processes) as pool:
+            async_results = [pool.apply_async(self.write_n_alternatives, kwds=kwargs) for _ in range(n)]

-    def generate_alternatives(self, trajectory, output: Union[mp.Queue, None] = None, mode='one_patching'):
+            # for _ in trange(n, desc='Processing Trajectories'):
+            #     self.write_n_alternatives(m, dataset_name, **kwargs)

+            # This line is for error catching only
+            results = [r.get() for r in async_results]

+        with shelve.open(str(self.data_root / datafile_name)) as f:
+            for datafile in self.data_root.glob(f'datafile_name*'):
+                with shelve.open(str(datafile)) as sub_f:
+                    for key in sub_f.keys():
+                        f[len(f)] = sub_f[key]
+                datafile.unlink()
+        pass

+    def generate_alternatives(self, trajectory, mode='one_patching'):
         start, dest = trajectory.endpoints
         if mode == 'one_patching':
             patch = self.map.get_valid_position()
@@ -43,20 +52,17 @@ class Generator:
         else:
             raise RuntimeError(f'mode checking went wrong...')

-        if output:
-            output.put(alternative)
         return alternative

-    def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
-                                mode='one_patching', equal_samples=True, binary_check=True):
+    def write_n_alternatives(self, n=None, datafile_name=None, trajectory=None,
+                             mode='one_patching', equal_samples=True, binary_check=True):
+        assert datafile_name.endswith('.pik'), f'datafile_name does not end with .pik but: {datafile_name}'
+        assert n is not None, f'n is not allowed to be None but was: {n}'
         assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
-        # Define an output queue
-        #output = mp.Queue()

-        results = run_n_in_parallel(self.generate_alternatives, n, trajectory=trajectory, mode=mode)  # , output=output)
+        trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()

-        # Get process results from the output queue
-        #results = [output.get() for _ in range(n)]
+        results = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]

         # label per homotopic class
         homotopy_classes = defaultdict(list)
@@ -87,16 +93,13 @@ class Generator:
             labels.extend([key] * len(homotopy_classes[key]))

         # Write to disk
-        if dataset_name:
-            self.write_to_disk(dataset_name, trajectory, alternatives, labels)
+        subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
+        self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
+        return True

-        # Return
-        return alternatives, labels

-    def write_to_disk(self, filepath, trajectory, alternatives, labels):
-        dataset_name = filepath if filepath.endswith('.pik') else f'{filepath}.pik'
+    def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
         self.data_root.mkdir(exist_ok=True, parents=True)
-        with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
+        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
             new_key = len(f)
             f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
                                               trajectory=trajectory,