Multiprocess data generation with single-process writing to disk
This commit is contained in:
@ -25,16 +25,21 @@ class Generator:
|
||||
|
||||
def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
|
||||
datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik'
|
||||
kwargs.update(datafile_name=datafile_name, n=m)
|
||||
processes = processes if processes else mp.cpu_count()
|
||||
kwargs.update(n=m)
|
||||
processes = processes if processes else mp.cpu_count() - 1
|
||||
mutex = mp.Lock()
|
||||
with mp.Pool(processes) as pool:
|
||||
async_results = [pool.apply_async(self.write_n_alternatives, kwds=kwargs) for _ in range(n)]
|
||||
async_results = [pool.apply_async(self.generate_n_alternatives, kwds=kwargs) for _ in range(n)]
|
||||
|
||||
# for _ in trange(n, desc='Processing Trajectories'):
|
||||
# self.write_n_alternatives(m, dataset_name, **kwargs)
|
||||
|
||||
# This line is for error catching only
|
||||
results = [r.get() for r in async_results]
|
||||
for result_obj in async_results:
|
||||
trajectory, alternatives, labels = result_obj.get()
|
||||
mutex.acquire()
|
||||
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
|
||||
mutex.release()
|
||||
|
||||
with shelve.open(str(self.data_root / datafile_name)) as f:
|
||||
for datafile in self.data_root.glob(f'datafile_name*'):
|
||||
@ -54,9 +59,8 @@ class Generator:
|
||||
|
||||
return alternative
|
||||
|
||||
def write_n_alternatives(self, n=None, datafile_name=None, trajectory=None,
|
||||
def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None,
|
||||
mode='one_patching', equal_samples=True, binary_check=True):
|
||||
assert datafile_name.endswith('.pik'), f'datafile_name does not end with .pik but: {datafile_name}'
|
||||
assert n is not None, f'n is not allowed to be None but was: {n}'
|
||||
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
||||
|
||||
@ -84,18 +88,18 @@ class Generator:
|
||||
if equal_samples:
|
||||
homotopy_classes = self._remove_unequal(homotopy_classes)
|
||||
if not homotopy_classes:
|
||||
return None, None
|
||||
return None, None, None
|
||||
|
||||
# Compose lists of alternatives with labels
|
||||
alternatives, labels = list(), list()
|
||||
for key in homotopy_classes.keys():
|
||||
alternatives.extend(homotopy_classes[key])
|
||||
labels.extend([key] * len(homotopy_classes[key]))
|
||||
|
||||
# Write to disk
|
||||
subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
|
||||
self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
|
||||
return True
|
||||
if datafile_name:
|
||||
# Write to disk
|
||||
subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
|
||||
self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
|
||||
return trajectory, alternatives, labels
|
||||
|
||||
def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
|
||||
self.data_root.mkdir(exist_ok=True, parents=True)
|
||||
|
Reference in New Issue
Block a user