CNN Classifier
This commit is contained in:
@ -5,6 +5,8 @@ from collections import defaultdict
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from lib.objects.map import Map
|
||||
|
||||
|
||||
@ -26,11 +28,7 @@ class Generator:
|
||||
with mp.Pool(processes) as pool:
|
||||
async_results = [pool.apply_async(self.generate_n_alternatives, kwds=kwargs) for _ in range(n)]
|
||||
|
||||
# for _ in trange(n, desc='Processing Trajectories'):
|
||||
# self.write_n_alternatives(m, dataset_name, **kwargs)
|
||||
|
||||
# This line is for error catching only
|
||||
for result_obj in async_results:
|
||||
for result_obj in tqdm(async_results, total=n, desc='Producing trajectories with Alternatives'):
|
||||
trajectory, alternatives, labels = result_obj.get()
|
||||
mutex.acquire()
|
||||
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
|
||||
@ -44,24 +42,14 @@ class Generator:
|
||||
datafile.unlink()
|
||||
pass
|
||||
|
||||
def generate_alternatives(self, trajectory, mode='one_patching'):
|
||||
start, dest = trajectory.endpoints
|
||||
if mode == 'one_patching':
|
||||
patch = self.map.get_valid_position()
|
||||
alternative = self.map.get_trajectory_from_vertices(start, patch, dest)
|
||||
else:
|
||||
raise RuntimeError(f'mode checking went wrong...')
|
||||
|
||||
return alternative
|
||||
|
||||
def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None,
|
||||
mode='one_patching', equal_samples=True, binary_check=True):
|
||||
def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None, is_sub_process=False,
|
||||
mode='one_patching', equal_samples=True, binary_check=True):
|
||||
assert n is not None, f'n is not allowed to be None but was: {n}'
|
||||
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
||||
|
||||
trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
|
||||
|
||||
results = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]
|
||||
results = [self.map.generate_alternative(trajectory=trajectory, mode=mode) for _ in range(n)]
|
||||
|
||||
# label per homotopic class
|
||||
homotopy_classes = defaultdict(list)
|
||||
@ -91,9 +79,10 @@ class Generator:
|
||||
alternatives.extend(homotopy_classes[key])
|
||||
labels.extend([key] * len(homotopy_classes[key]))
|
||||
if datafile_name:
|
||||
if is_sub_process:
|
||||
datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
|
||||
# Write to disk
|
||||
subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
|
||||
self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
|
||||
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
|
||||
return trajectory, alternatives, labels
|
||||
|
||||
def write_to_disk(self, datafile_name, trajectory, alternatives, labels):
|
||||
|
Reference in New Issue
Block a user