diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000..e99a9d5
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,22 @@
+[22 lines of PyCharm deployment XML; tag content lost in extraction]
\ No newline at end of file
diff --git a/.idea/hom_traj_gen.iml b/.idea/hom_traj_gen.iml
index 6958248..4b1d9c2 100644
--- a/.idea/hom_traj_gen.iml
+++ b/.idea/hom_traj_gen.iml
@@ -2,7 +2,7 @@
-[XML line; content lost in extraction]
+[XML line; content lost in extraction]
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index 4bd9d59..f164374 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,5 +3,5 @@
-[XML line; content lost in extraction]
+[XML line; content lost in extraction]
\ No newline at end of file
diff --git a/data/tate.pik b/data/tate.pik
deleted file mode 100644
index 1f3c216..0000000
Binary files a/data/tate.pik and /dev/null differ
diff --git a/lib/preprocessing/generator.py b/lib/preprocessing/generator.py
index c65d0ec..574b1a7 100644
--- a/lib/preprocessing/generator.py
+++ b/lib/preprocessing/generator.py
@@ -2,6 +2,7 @@ import multiprocessing as mp
 import pickle
 import shelve
 from collections import defaultdict
+from functools import partial
 from pathlib import Path
 from typing import Union
 
@@ -11,6 +12,7 @@ from tqdm import trange
 from lib.objects.map import Map
 from lib.utils.parallel import run_n_in_parallel
 
+
 class Generator:
     possible_modes = ['one_patching']
 
@@ -21,21 +23,28 @@ class Generator:
         self.data_root = Path(data_root)
 
-    def generate_n_trajectories_m_alternatives(self, n, m, dataset_name='', **kwargs):
-        trajectories_with_alternatives = list()
-        for _ in trange(n, desc='Processing Trajectories'):
-            trajectory = self.map.get_random_trajectory()
-            alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
-            if not alternatives or labels:
-                continue
-            else:
-                trajectories_with_alternatives.append(
-                    dict(trajectory=trajectory, alternatives=alternatives, labels=labels)
-                )
-        return trajectories_with_alternatives
+    def generate_n_trajectories_m_alternatives(self, n, m, datafile_name, processes=0, **kwargs):
+        datafile_name = datafile_name if datafile_name.endswith('.pik') else f'{str(datafile_name)}.pik'
+        kwargs.update(datafile_name=datafile_name, n=m)
+        processes = processes if processes else mp.cpu_count()
+        with mp.Pool(processes) as pool:
+            async_results = [pool.apply_async(self.write_n_alternatives, kwds=kwargs) for _ in range(n)]
+            # Fetch inside the pool context: leaving the `with` block terminates
+            # the pool, and .get() re-raises any exception from a worker.
+            results = [r.get() for r in async_results]
 
-    def generate_alternatives(self, trajectory, output: Union[mp.Queue, None] = None, mode='one_patching'):
+        # for _ in trange(n, desc='Processing Trajectories'):
+        #     self.write_n_alternatives(m, dataset_name, **kwargs)
+
+        # Merge the per-process shelve files into the final datafile, then
+        # delete them; the '_' in the glob keeps the merged file itself
+        # from matching its own pattern.
+        with shelve.open(str(self.data_root / datafile_name)) as f:
+            for datafile in self.data_root.glob(f'{datafile_name}_*'):
+                with shelve.open(str(datafile)) as sub_f:
+                    for key in sub_f.keys():
+                        f[f'trajectory_{len(f)}'] = sub_f[key]
+                datafile.unlink()
+
+    def generate_alternatives(self, trajectory, mode='one_patching'):
         start, dest = trajectory.endpoints
         if mode == 'one_patching':
             patch = self.map.get_valid_position()
@@ -43,20 +52,17 @@ class Generator:
         else:
             raise RuntimeError(f'mode checking went wrong...')
 
-        if output:
-            output.put(alternative)
         return alternative
 
-    def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
-                                mode='one_patching', equal_samples=True, binary_check=True):
+    def write_n_alternatives(self, n=None, datafile_name=None, trajectory=None,
+                             mode='one_patching', equal_samples=True, binary_check=True):
+        assert datafile_name.endswith('.pik'), f'datafile_name does not end with .pik but: {datafile_name}'
+        assert n is not None, f'n is not allowed to be None but was: {n}'
         assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
 
-        # Define an output queue
-        # output = mp.Queue()
-        results = run_n_in_parallel(self.generate_alternatives, n, trajectory=trajectory, mode=mode)  # , output=output)
+        trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
 
-        # Get process results from the output queue
-        # results = [output.get() for _ in range(n)]
+        results = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]
 
         # label per homotopic class
         homotopy_classes = defaultdict(list)
@@ -87,16 +93,13 @@
             labels.extend([key] * len(homotopy_classes[key]))
 
         # Write to disk
-        if dataset_name:
-            self.write_to_disk(dataset_name, trajectory, alternatives, labels)
+        subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
+        self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
+        return True
 
-        # Return
-        return alternatives, labels
-
-    def write_to_disk(self, filepath, trajectory, alternatives, labels):
-        dataset_name = filepath if filepath.endswith('.pik') else f'{filepath}.pik'
+    def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
         self.data_root.mkdir(exist_ok=True, parents=True)
-        with shelve.open(str(self.data_root / dataset_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
+        with shelve.open(str(self.data_root / datafile_name), protocol=pickle.HIGHEST_PROTOCOL) as f:
             new_key = len(f)
             f[f'trajectory_{new_key}'] = dict(alternatives=alternatives,
                                               trajectory=trajectory,
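Note on the generator.py rework: generation now fans out over a `multiprocessing.Pool`, each worker shelves its results under a PID-suffixed file name via `write_to_disk`, and the parent merges the partial shelves into the final datafile. A minimal, self-contained sketch of that fan-out/merge pattern; `write_chunk`, the demo file names, and the counts are stand-ins, not code from this repo:

```python
import multiprocessing as mp
import shelve
from pathlib import Path

DATA_ROOT = Path('data')  # stand-in for Generator.data_root


def write_chunk(datafile_name, n):
    """Stand-in for write_n_alternatives: every worker writes to its own
    shelve file, suffixed with its PID, so no two processes share a file."""
    own_file = f'{datafile_name}_{mp.current_process().pid}'
    with shelve.open(str(DATA_ROOT / own_file)) as f:
        for i in range(n):
            f[f'trajectory_{len(f)}'] = dict(result=i)
    return True


if __name__ == '__main__':
    DATA_ROOT.mkdir(exist_ok=True, parents=True)
    datafile_name = 'demo.pik'
    with mp.Pool(mp.cpu_count()) as pool:
        async_results = [pool.apply_async(write_chunk, (datafile_name, 10))
                         for _ in range(8)]
        # .get() re-raises any exception that occurred inside a worker.
        results = [r.get() for r in async_results]

    # Merge the per-PID partial files into one shelve, then remove them.
    with shelve.open(str(DATA_ROOT / datafile_name)) as f:
        for part in DATA_ROOT.glob(f'{datafile_name}_*'):
            with shelve.open(str(part)) as sub_f:
                for key in sub_f:
                    f[f'trajectory_{len(f)}'] = sub_f[key]
            part.unlink()
```

One caveat that applies to this sketch and to the committed merge loop alike: shelve's on-disk layout is backend-dependent. With gdbm the shelf is a single file with exactly the given name, while dbm.dumb splits it into .dat/.dir/.bak files, which the glob/open/unlink round-trip would then need to account for.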
diff --git a/lib/utils/parallel.py b/lib/utils/parallel.py
index 1e2642f..f8a06a1 100644
--- a/lib/utils/parallel.py
+++ b/lib/utils/parallel.py
@@ -2,10 +2,12 @@ import multiprocessing as mp
 import time
 
 
-def run_n_in_parallel(f, n, **kwargs):
+def run_n_in_parallel(f, n, processes=0, **kwargs):
+    processes = processes if processes else mp.cpu_count()
     output = mp.Queue()
     kwargs.update(output=output)
     # Setup a list of processes that we want to run
+
     processes = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
     # Run processes
     results = []
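As committed, the new `processes` argument of `run_n_in_parallel` is computed and then immediately shadowed by the list of `mp.Process` objects, so the function still launches all `n` processes at once. If bounding concurrency was the intent, the worker list needs its own name; a sketch of one possible bounded variant (the chunking scheme is an assumption, not part of this commit):

```python
import multiprocessing as mp


def run_n_in_parallel_bounded(f, n, processes=0, **kwargs):
    """Run f n times with at most `processes` concurrent workers,
    collecting one queued result per call of f."""
    processes = processes if processes else mp.cpu_count()
    output = mp.Queue()
    kwargs.update(output=output)
    results = []
    for chunk_start in range(0, n, processes):
        chunk = min(processes, n - chunk_start)
        workers = [mp.Process(target=f, kwargs=kwargs) for _ in range(chunk)]
        for w in workers:
            w.start()
        # Drain before joining: a child process cannot exit while the data
        # it queued still sits in the underlying pipe buffer.
        results.extend(output.get() for _ in range(chunk))
        for w in workers:
            w.join()
    return results
```

Draining the queue before `join()` is the important detail; joining first is a classic `mp.Queue` deadlock.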
diff --git a/main.py b/main.py
index 8312e22..1d770fa 100644
--- a/main.py
+++ b/main.py
@@ -65,7 +65,7 @@ config = Config.read_namespace(args)
 # TESTING ONLY #
 # =============================================================================
 hparams = config.model_paramters
-dataset = TrajData('data', mapname='tate', alternatives=100, trajectories=10000)
+dataset = TrajData('data', mapname='tate', alternatives=10000, trajectories=100000)
 dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
                         batch_size=hparams.data_param.batchsize, num_workers=hparams.data_param.worker)
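With data/tate.pik deleted above, the enlarged dataset main.py now asks for would be regenerated through the reworked Generator API. A hypothetical driver; the `Map('tate')` construction and the Generator constructor arguments are inferred from this diff, not shown in it:

```python
from lib.objects.map import Map
from lib.preprocessing.generator import Generator

if __name__ == '__main__':
    tate_map = Map('tate')  # hypothetical: Map construction is not shown in this diff
    generator = Generator('data', tate_map)
    # Matches the new main.py settings: 100000 trajectories with 10000
    # alternatives each; processes=0 falls back to mp.cpu_count().
    generator.generate_n_trajectories_m_alternatives(n=100000, m=10000,
                                                     datafile_name='tate.pik',
                                                     processes=0)
```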