Switched to Pooling and local aggregation

Si11ium 2020-02-17 16:41:14 +01:00
parent 0b30e7c22c
commit db9f861d6c
7 changed files with 62 additions and 35 deletions

.idea/deployment.xml generated Normal file

@@ -0,0 +1,22 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
+    <serverData>
+      <paths name="ErLoWa-AiMachine">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="traj_gen-AiMachine">
+        <serverdata>
+          <mappings>
+            <mapping deploy="/hom_traj_gen" local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+    </serverData>
+    <option name="myAutoUpload" value="ON_EXPLICIT_SAVE" />
+  </component>
+</project>


@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="Python 3.7 (traj_gen)" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
 </module>

.idea/misc.xml generated

@@ -3,5 +3,5 @@
   <component name="JavaScriptSettings">
     <option name="languageLevel" value="ES6" />
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (traj_gen)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
 </project>

Binary file not shown.


@@ -2,6 +2,7 @@ import multiprocessing as mp
 import pickle
 import shelve
 from collections import defaultdict
+from functools import partial
 from pathlib import Path
 from typing import Union
@@ -11,6 +12,7 @@ from tqdm import trange
 from lib.objects.map import Map
 from lib.utils.parallel import run_n_in_parallel

 class Generator:
     possible_modes = ['one_patching']
@@ -21,21 +23,28 @@ class Generator:
         self.data_root = Path(data_root)

-    def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
-        trajectories_with_alternatives = list()
-        for _ in trange(n, desc='Processing Trajectories'):
-            trajectory = self.map.get_random_trajectory()
-            alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
-            if not alternatives or labels:
-                continue
-            else:
-                trajectories_with_alternatives.append(
-                    dict(trajectory=trajectory, alternatives=alternatives, labels=labels)
-                )
-        return trajectories_with_alternatives
-
-    def generate_alternatives(self, trajectory, output: Union[mp.Queue, None] = None, mode='one_patching'):
+    def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
+        datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik'
+        kwargs.update(datafile_name=datafile_name, n=m)
+        processes = processes if processes else mp.cpu_count()
+        with mp.Pool(processes) as pool:
+            async_results = [pool.apply_async(self.write_n_alternatives, kwds=kwargs) for _ in range(n)]
+
+            # for _ in trange(n, desc='Processing Trajectories'):
+            #     self.write_n_alternatives(m, dataset_name, **kwargs)
+
+            # This line is for error catching only
+            results = [r.get() for r in async_results]
+
+        with shelve.open(str(self.data_root / datafile_name)) as f:
+            for datafile in self.data_root.glob(f'datafile_name*'):
+                with shelve.open(str(datafile)) as sub_f:
+                    for key in sub_f.keys():
+                        f[len(f)] = sub_f[key]
+                datafile.unlink()
+        pass
+
+    def generate_alternatives(self, trajectory, mode='one_patching'):
         start, dest = trajectory.endpoints
         if mode == 'one_patching':
             patch = self.map.get_valid_position()
@@ -43,20 +52,17 @@ class Generator:
         else:
             raise RuntimeError(f'mode checking went wrong...')

-        if output:
-            output.put(alternative)
         return alternative

-    def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
-                                mode='one_patching', equal_samples=True, binary_check=True):
+    def write_n_alternatives(self, n=None, datafile_name=None, trajectory=None,
+                             mode='one_patching', equal_samples=True, binary_check=True):
+        assert datafile_name.endswith('.pik'), f'datafile_name does not end with .pik but: {datafile_name}'
+        assert n is not None, f'n is not allowed to be None but was: {n}'
         assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'

-        # Define an output queue
-        #output = mp.Queue()
-        results = run_n_in_parallel(self.generate_alternatives, n, trajectory=trajectory, mode=mode)  # , output=output)
-        # Get process results from the output queue
-        #results = [output.get() for _ in range(n)]
+        trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
+        results = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]

         # label per homotopic class
         homotopy_classes = defaultdict(list)
@@ -87,16 +93,13 @@ class Generator:
             labels.extend([key] * len(homotopy_classes[key]))

         # Write to disk
-        if dataset_name:
-            self.write_to_disk(dataset_name, trajectory, alternatives, labels)
-
-        # Return
-        return alternatives, labels
-
-    def write_to_disk(self, filepath, trajectory, alternatives, labels):
-        dataset_name = filepath if filepath.endswith('.pik') else f'{filepath}.pik'
+        subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
+        self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
+        return True
+
+    def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
         self.data_root.mkdir(exist_ok=True, parents=True)
-        with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
+        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
             new_key = len(f)
             f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
                                               trajectory=trajectory,
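For orientation, the change above is the "pooling and local aggregation" named in the commit message: generate_n_trajectories_m_alternatives fans the work out over an mp.Pool, every worker process writes its trajectories into its own shelve file (the datafile name suffixed with the worker's pid), and the parent afterwards merges those partial files into a single datafile and deletes them. The sketch below shows that pattern in isolation; it is an illustration, not the project's code. The names write_local_shelve and generate are hypothetical, and the merge step collects each worker's file name from its return value instead of globbing, since the glob pattern f'datafile_name*' in the hunk has no braces and therefore matches the literal string rather than the interpolated name.

import multiprocessing as mp
import shelve
from pathlib import Path


def write_local_shelve(data_root, datafile_name, n):
    # Stand-in for write_n_alternatives(): each worker process appends its n
    # records to a shelve file named after its own pid ("local aggregation").
    local_name = f'{datafile_name}_{mp.current_process().pid}'
    with shelve.open(str(Path(data_root) / local_name)) as f:
        for _ in range(n):
            f[f'trajectory_{len(f)}'] = dict(alternatives=[], labels=[])
    return local_name


def generate(data_root, datafile_name, n_tasks, m_per_task, processes=0):
    data_root = Path(data_root)
    data_root.mkdir(exist_ok=True, parents=True)
    processes = processes if processes else mp.cpu_count()
    with mp.Pool(processes) as pool:
        async_results = [pool.apply_async(write_local_shelve,
                                          (data_root, datafile_name, m_per_task))
                         for _ in range(n_tasks)]
        # .get() blocks until each task finishes and re-raises any worker
        # exception in the parent ("for error catching only").
        local_names = [r.get() for r in async_results]
    # Aggregate the per-process shelves into one datafile, then delete the partials.
    with shelve.open(str(data_root / datafile_name)) as f:
        for local_name in set(local_names):
            with shelve.open(str(data_root / local_name)) as sub_f:
                for key in sub_f.keys():
                    f[str(len(f))] = sub_f[key]
            for leftover in data_root.glob(f'{local_name}*'):
                leftover.unlink()


if __name__ == '__main__':
    generate('data', 'trajectories.pik', n_tasks=4, m_per_task=10)

Collecting the AsyncResult objects inside the with mp.Pool(...) block matters: Pool.__exit__ calls terminate(), so the results have to be fetched (or the pool closed and joined) before that block is left.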


@@ -2,10 +2,12 @@ import multiprocessing as mp
 import time

-def run_n_in_parallel(f, n, **kwargs):
+def run_n_in_parallel(f, n, processes=0, **kwargs):
+    processes = processes if processes else mp.cpu_count()
     output = mp.Queue()
     kwargs.update(output=output)
     # Setup a list of processes that we want to run
     processes = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
     # Run processes
     results = []
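The hunk above only shows the head of run_n_in_parallel, so the collect-and-join tail in the sketch below is an assumption about how the queue-based helper completes, not code from the commit. Note also that the list comprehension reuses the name processes for the process objects, shadowing the new process-count parameter; the sketch renames that list to workers so both survive.

import multiprocessing as mp


def run_n_in_parallel(f, n, processes=0, **kwargs):
    # Run f in n separate processes; each call is expected to put exactly one
    # result into the shared queue passed in via kwargs['output'].
    processes = processes if processes else mp.cpu_count()  # computed but, as in the hunk, not used below
    output = mp.Queue()
    kwargs.update(output=output)
    workers = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
    results = []
    for w in workers:
        w.start()
    for _ in range(n):
        # Drain the queue before joining so a full queue cannot block the children.
        results.append(output.get())
    for w in workers:
        w.join()
    return results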


@@ -65,7 +65,7 @@ config = Config.read_namespace(args)
 # TESTING ONLY #
 # =============================================================================
 hparams = config.model_paramters
-dataset = TrajData('data', mapname='tate', alternatives=100, trajectories=10000)
+dataset = TrajData('data', mapname='tate', alternatives=10000, trajectories=100000)
 dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
                         batch_size=hparams.data_param.batchsize,
                         num_workers=hparams.data_param.worker)
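For reference, the DataLoader arguments exercised here (shuffle, batch_size, num_workers) behave the same on any map-style dataset; the stand-in below is self-contained and uses a dummy TensorDataset, since TrajData and the hparams values are not part of this hunk.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy dataset standing in for dataset.train_dataset.
toy_dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
toy_loader = DataLoader(dataset=toy_dataset, shuffle=True,
                        batch_size=16,
                        num_workers=0)

for features, labels in toy_loader:
    # One shuffled batch of 16 samples per iteration.
    print(features.shape, labels.shape)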