CNN Classifier

This commit is contained in:
Si11ium
2020-02-21 09:44:09 +01:00
parent 537e5371c9
commit 7b3f781d19
12 changed files with 247 additions and 109 deletions

View File

@ -5,6 +5,8 @@ from collections import defaultdict
from pathlib import Path
from tqdm import tqdm
from lib.objects.map import Map
@ -26,11 +28,7 @@ class Generator:
with mp.Pool(processes) as pool:
async_results = [pool.apply_async(self.generate_n_alternatives, kwds=kwargs) for _ in range(n)]
# for _ in trange(n, desc='Processing Trajectories'):
# self.write_n_alternatives(m, dataset_name, **kwargs)
# This line is for error catching only
for result_obj in async_results:
for result_obj in tqdm(async_results, total=n, desc='Producing trajectories with Alternatives'):
trajectory, alternatives, labels = result_obj.get()
mutex.acquire()
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
@ -44,24 +42,14 @@ class Generator:
datafile.unlink()
pass
def generate_alternatives(self, trajectory, mode='one_patching'):
start, dest = trajectory.endpoints
if mode == 'one_patching':
patch = self.map.get_valid_position()
alternative = self.map.get_trajectory_from_vertices(start, patch, dest)
else:
raise RuntimeError(f'mode checking went wrong...')
return alternative
def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None,
mode='one_patching', equal_samples=True, binary_check=True):
def generate_n_alternatives(self, n=None, datafile_name='', trajectory=None, is_sub_process=False,
mode='one_patching', equal_samples=True, binary_check=True):
assert n is not None, f'n is not allowed to be None but was: {n}'
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
trajectory = trajectory if trajectory is not None else self.map.get_random_trajectory()
results = [self.generate_alternatives(trajectory=trajectory, mode=mode) for _ in range(n)]
results = [self.map.generate_alternative(trajectory=trajectory, mode=mode) for _ in range(n)]
# label per homotopic class
homotopy_classes = defaultdict(list)
@ -91,9 +79,10 @@ class Generator:
alternatives.extend(homotopy_classes[key])
labels.extend([key] * len(homotopy_classes[key]))
if datafile_name:
if is_sub_process:
datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
# Write to disk
subprocess_datafile_name = f'{str(datafile_name)}_{mp.current_process().pid}'
self.write_to_disk(subprocess_datafile_name, trajectory, alternatives, labels)
self.write_to_disk(datafile_name, trajectory, alternatives, labels)
return trajectory, alternatives, labels
def write_to_disk(self, datafile_name, trajectory, alternatives, labels):