Switched to Pooling and local aggregation
parent 0b30e7c22c
commit db9f861d6c
.idea/deployment.xml (generated, new file, 22 lines added)
@@ -0,0 +1,22 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PublishConfigData" autoUpload="On explicit save action" serverName="traj_gen-AiMachine">
+    <serverData>
+      <paths name="ErLoWa-AiMachine">
+        <serverdata>
+          <mappings>
+            <mapping local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+      <paths name="traj_gen-AiMachine">
+        <serverdata>
+          <mappings>
+            <mapping deploy="/hom_traj_gen" local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+    </serverData>
+    <option name="myAutoUpload" value="ON_EXPLICIT_SAVE" />
+  </component>
+</project>
.idea/hom_traj_gen.iml (generated, 2 changes)
@@ -2,7 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$" />
-    <orderEntry type="jdk" jdkName="Python 3.7 (traj_gen)" jdkType="Python SDK" />
+    <orderEntry type="jdk" jdkName="traj_gen@AiMachine" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
 </module>
.idea/misc.xml (generated, 2 changes)
@@ -3,5 +3,5 @@
   <component name="JavaScriptSettings">
     <option name="languageLevel" value="ES6" />
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (traj_gen)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="traj_gen@AiMachine" project-jdk-type="Python SDK" />
 </project>
data/tate.pik (BIN)
Binary file not shown.
@@ -2,6 +2,7 @@ import multiprocessing as mp
 import pickle
 import shelve
 from collections import defaultdict
 from functools import partial
 
 from pathlib import Path
 from typing import Union
@@ -11,6 +12,7 @@ from tqdm import trange
 from lib.objects.map import Map
 from lib.utils.parallel import run_n_in_parallel
 
 
 class Generator:
 
     possible_modes = ['one_patching']
@@ -21,21 +23,28 @@ class Generator:
 
         self.data_root = Path(data_root)
 
-    def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
-        trajectories_with_alternatives = list()
-        for _ in trange(n, desc='Processing Trajectories'):
-            trajectory = self.map.get_random_trajectory()
-            alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
-            if not alternatives or labels:
-                continue
-            else:
-                trajectories_with_alternatives.append(
-                    dict(trajectory=trajectory, alternatives=alternatives, labels=labels)
-                )
-        return trajectories_with_alternatives
+    def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
+        datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik'
+        kwargs.update(datafile_name=datafile_name, n=m)
+        processes = processes if processes else mp.cpu_count()
+        with mp.Pool(processes) as pool:
+            async_results = [pool.apply_async(self.write_n_alternatives, kwds=kwargs) for _ in range(n)]
+
+            # for _ in trange(n, desc='Processing Trajectories'):
+            #     self.write_n_alternatives(m, dataset_name, **kwargs)
+
+            # This line is for error catching only
+            results = [r.get() for r in async_results]
+
+        with shelve.open(str(self.data_root / datafile_name)) as f:
+            for datafile in self.data_root.glob(f'datafile_name*'):
+                with shelve.open(str(datafile)) as sub_f:
+                    for key in sub_f.keys():
+                        f[len(f)] = sub_f[key]
+                datafile.unlink()
+        pass
 
-    def generate_alternatives(self, trajectory, output: Union[mp.Queue, None] = None, mode='one_patching'):
+    def generate_alternatives(self, trajectory, mode='one_patching'):
         start, dest = trajectory.endpoints
         if mode == 'one_patching':
             patch = self.map.get_valid_position()
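The new generate_n_trajectories_m_alternatives fans the work out with mp.Pool.apply_async and afterwards merges the per-process files into one shelve (the "local aggregation" from the commit title). Below is a minimal, self-contained sketch of that fan-out/merge pattern; DATA_ROOT, DATAFILE_NAME and write_partial are illustrative stand-ins rather than names from this repository, and the sketch globs with an f-string pattern whereas the committed line globs the literal string 'datafile_name*'.

import multiprocessing as mp
import pickle
import shelve
from pathlib import Path

DATA_ROOT = Path('data')        # illustrative output directory
DATAFILE_NAME = 'demo.pik'      # illustrative target shelve name


def write_partial(i):
    # stand-in for write_n_alternatives: every task writes its own partial file
    DATA_ROOT.mkdir(exist_ok=True, parents=True)
    partial_file = DATA_ROOT / f'{DATAFILE_NAME}_{mp.current_process().pid}_{i}'
    with partial_file.open('wb') as f:
        pickle.dump(dict(trajectory=None, alternatives=[], labels=[]), f)
    return True


if __name__ == '__main__':
    with mp.Pool(mp.cpu_count()) as pool:
        async_results = [pool.apply_async(write_partial, (i,)) for i in range(8)]
        results = [r.get() for r in async_results]   # .get() re-raises worker exceptions

    # local aggregation: fold every partial file into one shelve, then delete it
    with shelve.open(str(DATA_ROOT / DATAFILE_NAME)) as f:
        for partial_file in DATA_ROOT.glob(f'{DATAFILE_NAME}_*'):
            with partial_file.open('rb') as sub_f:
                f[f'trajectory_{len(f)}'] = pickle.load(sub_f)
            partial_file.unlink()

Calling .get() on every AsyncResult before the pool exits is what surfaces exceptions raised inside the workers, which is what the "This line is for error catching only" comment in the diff refers to.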
@@ -43,20 +52,17 @@
         else:
             raise RuntimeError(f'mode checking went wrong...')
 
-        if output:
-            output.put(alternative)
         return alternative
 
-    def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
-                                mode='one_patching', equal_samples=True, binary_check=True):
+    def write_n_alternatives(self, n=None, datafile_name=None, trajectory=None,
+                             mode='one_patching', equal_samples=True, binary_check=True):
+        assert datafile_name.endswith('.pik'), f'datafile_name does not end with .pik but: {datafile_name}'
+        assert n is not None, f'n is not allowed to be None but was: {n}'
         assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
-        # Define an output queue
-        #output = mp.Queue()
 
-        results = run_n_in_parallel(self.generate_alternatives, n, trajectory=trajectory, mode=mode)  # , output=output)
+        trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
 
-        # Get process results from the output queue
-        #results = [output.get() for _ in range(n)]
+        results = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]
 
         # label per homotopic class
         homotopy_classes = defaultdict(list)
@@ -87,16 +93,13 @@
             labels.extend([key] * len(homotopy_classes[key]))
 
         # Write to disk
-        if dataset_name:
-            self.write_to_disk(dataset_name, trajectory, alternatives, labels)
-
-        # Return
-        return alternatives, labels
+        subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
+        self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
+        return True
 
-    def write_to_disk(self, filepath, trajectory, alternatives, labels):
-        dataset_name = filepath if filepath.endswith('.pik') else f'{filepath}.pik'
+    def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
         self.data_root.mkdir(exist_ok=True, parents=True)
-        with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
+        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
             new_key = len(f)
             f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
                                               trajectory=trajectory,
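write_to_disk appends one record per call, keyed trajectory_<n>, into a shelve under data_root. A quick way to inspect such a file is sketched below; the path data/tate.pik is only an assumed example, and whether the (truncated) record dict also carries labels is not confirmed by this hunk.

import shelve
from pathlib import Path

datafile = Path('data') / 'tate.pik'   # assumed example path, see data/tate.pik above

with shelve.open(str(datafile)) as f:
    print(f'{len(f)} records in {datafile}')
    for key in list(f.keys())[:3]:
        record = f[key]   # dict(alternatives=..., trajectory=..., and presumably labels=...)
        print(key, len(record['alternatives']))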
@@ -2,10 +2,12 @@ import multiprocessing as mp
 import time
 
 
-def run_n_in_parallel(f, n, **kwargs):
+def run_n_in_parallel(f, n, processes=0, **kwargs):
+    processes = processes if processes else mp.cpu_count()
     output = mp.Queue()
     kwargs.update(output=output)
     # Setup a list of processes that we want to run
+
     processes = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
     # Run processes
     results = []
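The hunk is cut off right after results = []. For reference, a queue-based helper of this shape typically finishes by starting the workers, draining the queue, and joining; the sketch below is an illustrative completion (run_n_in_parallel_sketch and workers are made-up names, not the remainder of the committed file) and assumes the target callable f calls output.put(...) exactly once.

import multiprocessing as mp


def run_n_in_parallel_sketch(f, n, **kwargs):
    # start n worker processes, collect one result per worker, then join
    output = mp.Queue()
    kwargs.update(output=output)
    workers = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
    for w in workers:
        w.start()
    results = [output.get() for _ in range(n)]   # blocks until every worker has delivered
    for w in workers:
        w.join()
    return results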
main.py (2 changes)
@@ -65,7 +65,7 @@ config = Config.read_namespace(args)
 # TESTING ONLY #
 # =============================================================================
 hparams = config.model_paramters
-dataset = TrajData('data', mapname='tate', alternatives=100, trajectories=10000)
+dataset = TrajData('data', mapname='tate', alternatives=10000, trajectories=100000)
 dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
                         batch_size=hparams.data_param.batchsize,
                         num_workers=hparams.data_param.worker)