Switched to Pooling and local aggregation

Si11ium 2020-02-17 16:41:14 +01:00
parent 0b30e7c22c
commit db9f861d6c
7 changed files with 62 additions and 35 deletions

.idea/deployment.xml generated Normal file

@ -0,0 +1,22 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
    <serverData>
      <paths name="ErLoWa-AiMachine">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="traj_gen-AiMachine">
        <serverdata>
          <mappings>
            <mapping deploy="/hom_traj_gen" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
    </serverData>
    <option name="myAutoUpload" value="ON_EXPLICIT_SAVE" />
  </component>
</project>


@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
-   <orderEntry type="jdk" jdkName="Python 3.7 (traj_gen)" jdkType="Python SDK" />
+   <orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>

.idea/misc.xml generated

@ -3,5 +3,5 @@
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (traj_gen)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
</project>

Binary file not shown.


@ -2,6 +2,7 @@ import multiprocessing as mp
import pickle
import shelve
from collections import defaultdict
+from functools import partial
from pathlib import Path
from typing import Union
@ -11,6 +12,7 @@ from tqdm import trange
from lib.objects.map import Map
from lib.utils.parallel import run_n_in_parallel
class Generator:
    possible_modes = ['one_patching']
@ -21,21 +23,28 @@ class Generator:
        self.data_root = Path(data_root)
-    def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
-        trajectories_with_alternatives = list()
-        for _ in trange(n, desc='Processing Trajectories'):
-            trajectory = self.map.get_random_trajectory()
-            alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
-            if not alternatives or labels:
-                continue
-            else:
-                trajectories_with_alternatives.append(
-                    dict(trajectory=trajectory, alternatives=alternatives, labels=labels)
-                )
-        return trajectories_with_alternatives
+    def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
+        datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik'
+        kwargs.update(datafile_name=datafile_name, n=m)
+        processes = processes if processes else mp.cpu_count()
+        with mp.Pool(processes) as pool:
+            async_results = [pool.apply_async(self.write_n_alternatives, kwds=kwargs) for _ in range(n)]
-    def generate_alternatives(self, trajectory, output: Union[mp.Queue, None] = None, mode='one_patching'):
+            # for _ in trange(n, desc='Processing Trajectories'):
+            # self.write_n_alternatives(m, dataset_name, **kwargs)
+            # This line is for error catching only
+            results = [r.get() for r in async_results]
+        with shelve.open(str(self.data_root / datafile_name)) as f:
+            for datafile in self.data_root.glob(f'datafile_name*'):
+                with shelve.open(str(datafile)) as sub_f:
+                    for key in sub_f.keys():
+                        f[len(f)] = sub_f[key]
+                datafile.unlink()
+        pass
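The new generation path hands each trajectory to a multiprocessing.Pool worker; the later .get() calls both synchronize and re-raise any exception from a worker, which is what the "error catching only" comment refers to. A minimal, self-contained sketch of that pattern, with a placeholder worker standing in for write_n_alternatives (not repository code):

import multiprocessing as mp


def write_chunk(n=None, datafile_name=None):
    # placeholder for Generator.write_n_alternatives: generate n alternatives
    # for one random trajectory and shelve them under a per-process file name
    return True


if __name__ == '__main__':
    kwargs = dict(n=10, datafile_name='demo.pik')
    with mp.Pool(mp.cpu_count()) as pool:
        # one task per requested trajectory; apply_async returns immediately
        async_results = [pool.apply_async(write_chunk, kwds=kwargs) for _ in range(100)]
        # .get() blocks until the task has finished and re-raises any exception
        # that occurred inside the worker, so failures surface in the parent
        results = [r.get() for r in async_results]
    print(all(results))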
+    def generate_alternatives(self, trajectory, mode='one_patching'):
        start, dest = trajectory.endpoints
        if mode == 'one_patching':
            patch = self.map.get_valid_position()
@ -43,20 +52,17 @@ class Generator:
        else:
            raise RuntimeError(f'mode checking went wrong...')
-        if output:
-            output.put(alternative)
        return alternative
-    def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
+    def write_n_alternatives(self, n=None, datafile_name=None, trajectory=None,
                               mode='one_patching', equal_samples=True, binary_check=True):
+        assert datafile_name.endswith('.pik'), f'datafile_name does not end with .pik but: {datafile_name}'
+        assert n is not None, f'n is not allowed to be None but was: {n}'
        assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
-        # Define an output queue
-        #output = mp.Queue()
-        results = run_n_in_parallel(self.generate_alternatives, n, trajectory=trajectory, mode=mode)  # , output=output)
+        trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
-        # Get process results from the output queue
-        #results = [output.get() for _ in range(n)]
+        results = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]
        # label per homotopic class
        homotopy_classes = defaultdict(list)
@ -87,16 +93,13 @@ class Generator:
            labels.extend([key] * len(homotopy_classes[key]))
        # Write to disk
-        if dataset_name:
-            self.write_to_disk(dataset_name, trajectory, alternatives, labels)
+        subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
+        self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
+        return True
-        # Return
-        return alternatives, labels
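The per-process file name is what keeps pool workers from ever writing to the same shelve concurrently; afterwards the parent merges the chunks back into one datafile (the glob/merge loop shown further up). A hedged sketch of that scheme, assuming a dbm backend that stores the shelf under the exact file name given (per_process_file and merge_chunks are illustrative names, not repository code):

import multiprocessing as mp
import shelve
from pathlib import Path


def per_process_file(data_root: Path, datafile_name: str) -> Path:
    # every worker writes to '<datafile_name>_<pid>', so no two pool workers
    # hold the same shelve file open for writing at the same time
    return data_root / f'{str(datafile_name)}_{mp.current_process().pid}'


def merge_chunks(data_root: Path, datafile_name: str) -> None:
    # the parent collects all per-pid chunks into one datafile and removes them;
    # note the pattern has to interpolate the name, i.e. f'{datafile_name}_*',
    # and shelve keys must be strings
    with shelve.open(str(data_root / datafile_name)) as target:
        for chunk in data_root.glob(f'{datafile_name}_*'):
            with shelve.open(str(chunk)) as source:
                for key in source.keys():
                    target[str(len(target))] = source[key]
            chunk.unlink()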
-    def write_to_disk(self, filepath, trajectory, alternatives, labels):
-        dataset_name = filepath if filepath.endswith('.pik') else f'{filepath}.pik'
+    def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
        self.data_root.mkdir(exist_ok=True, parents=True)
-        with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
+        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
            new_key = len(f)
            f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
                                              trajectory=trajectory,
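For completeness, a minimal sketch of reading such a shelve datafile back; the data root, the file name and the labels field are assumptions, since the hunk above is truncated:

import shelve
from pathlib import Path

data_root = Path('data')                                # assumed location
with shelve.open(str(data_root / 'tate.pik')) as f:     # assumed datafile name
    for key in f.keys():                                # keys look like 'trajectory_<i>'
        record = f[key]
        trajectory = record['trajectory']
        alternatives = record['alternatives']
        labels = record.get('labels')                   # assumed field, cut off above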


@ -2,10 +2,12 @@ import multiprocessing as mp
import time
-def run_n_in_parallel(f, n, **kwargs):
+def run_n_in_parallel(f, n, processes=0, **kwargs):
+    processes = processes if processes else mp.cpu_count()
    output = mp.Queue()
    kwargs.update(output=output)
    # Setup a list of processes that we want to run
    processes = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
    # Run processes
    results = []
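run_n_in_parallel still follows the queue-per-call pattern that the old generate_alternatives relied on: the worker put its result on a shared queue, and the parent did one get() per started process. A small, self-contained illustration of that pattern (not repository code):

import multiprocessing as mp


def worker(output=None, value=0):
    # the old generate_alternatives put its result on the queue when one was passed in
    result = value * value
    if output is not None:
        output.put(result)


if __name__ == '__main__':
    n = 4
    output = mp.Queue()
    procs = [mp.Process(target=worker, kwargs=dict(output=output, value=i)) for i in range(n)]
    for p in procs:
        p.start()
    results = [output.get() for _ in range(n)]   # one get() per started process
    for p in procs:
        p.join()
    print(results)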


@ -65,7 +65,7 @@ config = Config.read_namespace(args)
# TESTING ONLY #
# =============================================================================
hparams = config.model_paramters
-dataset = TrajData('data', mapname='tate', alternatives=100, trajectories=10000)
+dataset = TrajData('data', mapname='tate', alternatives=10000, trajectories=100000)
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
                        batch_size=hparams.data_param.batchsize,
                        num_workers=hparams.data_param.worker)