diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000..e99a9d5
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,22 @@
+<!-- PublishConfigData deployment settings; XML content lost in extraction -->
\ No newline at end of file
diff --git a/.idea/hom_traj_gen.iml b/.idea/hom_traj_gen.iml
index 6958248..4b1d9c2 100644
--- a/.idea/hom_traj_gen.iml
+++ b/.idea/hom_traj_gen.iml
@@ -2,7 +2,7 @@
-<!-- old module settings line; XML content lost in extraction -->
+<!-- new module settings line; XML content lost in extraction -->
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 4bd9d59..f164374 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,5 +3,5 @@
-<!-- old project settings line; XML content lost in extraction -->
+<!-- new project settings line; XML content lost in extraction -->
\ No newline at end of file
diff --git a/data/tate.pik b/data/tate.pik
deleted file mode 100644
index 1f3c216..0000000
Binary files a/data/tate.pik and /dev/null differ
diff --git a/lib/preprocessing/generator.py b/lib/preprocessing/generator.py
index c65d0ec..574b1a7 100644
--- a/lib/preprocessing/generator.py
+++ b/lib/preprocessing/generator.py
@@ -2,6 +2,7 @@ import multiprocessing as mp
import pickle
import shelve
from collections import defaultdict
+from functools import partial
from pathlib import Path
from typing import Union
@@ -11,6 +12,7 @@ from tqdm import trange
from lib.objects.map import Map
from lib.utils.parallel import run_n_in_parallel
+
class Generator:
possible_modes = ['one_patching']
@@ -21,21 +23,28 @@ class Generator:
self.data_root = Path(data_root)
- def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
- trajectories_with_alternatives = list()
- for _ in trange(n, desc='Processing Trajectories'):
- trajectory = self.map.get_random_trajectory()
- alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
- if not alternatives or labels:
- continue
- else:
- trajectories_with_alternatives.append(
- dict(trajectory=trajectory, alternatives=alternatives, labels=labels)
- )
- return trajectories_with_alternatives
+    def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
+        datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{datafile_name}.pik'
+        kwargs.update(datafile_name=datafile_name, n=m)
+        processes = processes if processes else mp.cpu_count()
+        with mp.Pool(processes) as pool:
+            async_results = [pool.apply_async(self.write_n_alternatives, kwds=kwargs) for _ in range(n)]
- def generate_alternatives(self, trajectory, output: Union[mp.
- Queue, None] = None, mode='one_patching'):
+        # for _ in trange(n, desc='Processing Trajectories'):
+        #     self.write_n_alternatives(m, dataset_name, **kwargs)
+
+        # .get() blocks until each job finishes and re-raises any exception
+        # thrown inside a worker process
+        results = [r.get() for r in async_results]
+
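+        # Each worker process wrote its results to its own PID-suffixed shelve;
+        # merge those partial files into the final dataset, then delete them.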
+        with shelve.open(str(self.data_root / datafile_name)) as f:
+            for datafile in self.data_root.glob(f'{datafile_name}_*'):
+                with shelve.open(str(datafile)) as sub_f:
+                    for key in sub_f.keys():
+                        f[f'trajectory_{len(f)}'] = sub_f[key]
+                datafile.unlink()
+
+    def generate_alternatives(self, trajectory, mode='one_patching'):
start, dest = trajectory.endpoints
if mode == 'one_patching':
patch = self.map.get_valid_position()
@@ -43,20 +52,17 @@ class Generator:
else:
raise RuntimeError(f'mode checking went wrong...')
- if output:
- output.put(alternative)
return alternative
- def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
- mode='one_patching', equal_samples=True, binary_check=True):
+    def write_n_alternatives(self, n=None, datafile_name=None, trajectory=None,
+                             mode='one_patching', equal_samples=True, binary_check=True):
+        assert datafile_name.endswith('.pik'), f'datafile_name does not end with .pik but: {datafile_name}'
+        assert n is not None, f'n is not allowed to be None but was: {n}'
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
- # Define an output queue
- #output = mp.Queue()
- results = run_n_in_parallel(self.generate_alternatives, n, trajectory=trajectory, mode=mode) # , output=output)
+        trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
- # Get process results from the output queue
- #results = [output.get() for _ in range(n)]
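+        # Alternatives are generated sequentially here; parallelism now lives
+        # one level up, in the Pool that fans out write_n_alternatives jobs.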
+        results = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]
# label per homotopic class
homotopy_classes = defaultdict(list)
@@ -87,16 +93,13 @@ class Generator:
labels.extend([key] * len(homotopy_classes[key]))
# Write to disk
- if dataset_name:
- self.write_to_disk(dataset_name, trajectory, alternatives, labels)
+        # One shelve file per worker process, disambiguated by PID; merged later
+        # by generate_n_trajectories_m_alternatives.
+        subprocess_datafile_name = f'{datafile_name}_{mp.current_process().pid}'
+        self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
+        # Returned so AsyncResult.get() in the parent can surface worker errors.
+        return True
- # Return
- return alternatives, labels
-
- def write_to_disk(self, filepath, trajectory, alternatives, labels):
- dataset_name = filepath if filepath.endswith('.pik') else f'{filepath}.pik'
+    def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
self.data_root.mkdir(exist_ok=True, parents=True)
- with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
+        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
new_key = len(f)
f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
trajectory=trajectory,
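
A minimal end-to-end sketch of how the reworked generator API could be driven; the `Generator` constructor signature and the `Map('tate')` loader are assumptions for illustration (only `data_root` and `self.map` are visible in this diff):

```python
import shelve
from pathlib import Path

from lib.objects.map import Map
from lib.preprocessing.generator import Generator

if __name__ == '__main__':  # guard matters on spawn-based platforms (Windows/macOS)
    # Constructor arguments are assumed, not shown in the diff.
    generator = Generator(map=Map('tate'), data_root='data')

    # Fans out 100 write_n_alternatives jobs (10 alternatives each) over
    # cpu_count() workers, then merges the PID-suffixed shelves into tate.pik.
    generator.generate_n_trajectories_m_alternatives(n=100, m=10, datafile_name='tate')

    with shelve.open(str(Path('data') / 'tate.pik')) as f:
        print(f'{len(f)} trajectories with alternatives written')
```

Because `write_n_alternatives` falls back to `self.map.get_random_trajectory()` when no trajectory is passed, each Pool job is fully described by the pickled keyword arguments alone.
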
diff --git a/lib/utils/parallel.py b/lib/utils/parallel.py
index 1e2642f..f8a06a1 100644
--- a/lib/utils/parallel.py
+++ b/lib/utils/parallel.py
@@ -2,10 +2,12 @@ import multiprocessing as mp
import time
-def run_n_in_parallel(f, n, **kwargs):
+def run_n_in_parallel(f, n, processes=0, **kwargs):
+    # NOTE: this count is immediately shadowed by the process list created
+    # below, so it does not yet cap how many processes run concurrently.
+    processes = processes if processes else mp.cpu_count()
output = mp.Queue()
kwargs.update(output=output)
# Setup a list of processes that we want to run
+
processes = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
# Run processes
results = []
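
For context, a usage sketch of this queue-based helper; the worker below is illustrative, and the truncated tail of `run_n_in_parallel` is assumed to start the processes and drain one queue entry per worker into `results`:

```python
import multiprocessing as mp

from lib.utils.parallel import run_n_in_parallel

def square(x, output: mp.Queue = None):
    # run_n_in_parallel injects the shared queue via kwargs; each spawned
    # process is expected to contribute exactly one result.
    output.put(x * x)

if __name__ == '__main__':
    results = run_n_in_parallel(square, n=4, x=3)
    print(results)  # assuming the tail collects queue items: [9, 9, 9, 9]
```

With `generator.py` now driving `mp.Pool` directly, this helper appears to remain for other call sites.
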
diff --git a/main.py b/main.py
index 8312e22..1d770fa 100644
--- a/main.py
+++ b/main.py
@@ -65,7 +65,7 @@ config = Config.read_namespace(args)
# TESTING ONLY #
# =============================================================================
hparams = config.model_paramters
-dataset = TrajData('data', mapname='tate', alternatives=100, trajectories=10000)
+dataset = TrajData('data', mapname='tate', alternatives=10000, trajectories=100000)
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
batch_size=hparams.data_param.batchsize,
num_workers=hparams.data_param.worker)