CNN Model Body
This commit is contained in:
2
.idea/dictionaries/steffen.xml
generated
2
.idea/dictionaries/steffen.xml
generated
@ -2,6 +2,8 @@
|
||||
<dictionary name="steffen">
|
||||
<words>
|
||||
<w>conv</w>
|
||||
<w>homotopic</w>
|
||||
<w>hyperparamter</w>
|
||||
<w>numlayers</w>
|
||||
</words>
|
||||
</dictionary>
|
||||
|
15
.idea/webResources.xml
generated
Normal file
15
.idea/webResources.xml
generated
Normal file
@ -0,0 +1,15 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="WebResourcesPaths">
|
||||
<contentEntries>
|
||||
<entry url="file://$PROJECT_DIR$">
|
||||
<entryData>
|
||||
<resourceRoots>
|
||||
<path value="file://$PROJECT_DIR$/res" />
|
||||
<path value="file://$PROJECT_DIR$/data" />
|
||||
</resourceRoots>
|
||||
</entryData>
|
||||
</entry>
|
||||
</contentEntries>
|
||||
</component>
|
||||
</project>
|
@ -0,0 +1,21 @@
|
||||
from PIL import ImageDraw
|
||||
from PIL import Image
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def are_homotopic(map_array, trajectory, other_trajectory):
|
||||
|
||||
polyline = trajectory.vertices.copy()
|
||||
polyline.extend(reversed(other_trajectory.vertices))
|
||||
|
||||
height, width = map_array.shape
|
||||
|
||||
img = Image.new('L', (height, width), 0)
|
||||
ImageDraw.Draw(img).polygon(polyline, outline=1, fill=1)
|
||||
|
||||
a = (np.array(img) * map_array).sum()
|
||||
if a >= 1:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
@ -2,6 +2,8 @@ from pathlib import Path
|
||||
|
||||
import copy
|
||||
from math import sqrt
|
||||
from random import choice
|
||||
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
@ -13,8 +15,9 @@ from lib.objects.trajectory import Trajectory
|
||||
|
||||
class Map(object):
|
||||
|
||||
white = [1, 255]
|
||||
black = [0]
|
||||
# This setting is for Img mode "L" aka GreyScale Image; values: 0-255
|
||||
white = 255
|
||||
black = 0
|
||||
|
||||
def __copy__(self):
|
||||
return copy.deepcopy(self)
|
||||
@ -56,8 +59,12 @@ class Map(object):
|
||||
|
||||
# Check pixels for their color (determine if walkable)
|
||||
for idx, value in np.ndenumerate(self.map_array):
|
||||
if value in self.white:
|
||||
y, x = idx
|
||||
if value == self.white:
|
||||
try:
|
||||
y, x = idx
|
||||
except ValueError:
|
||||
y, x, channels = idx
|
||||
idx = (y, x)
|
||||
# IF walkable, add node
|
||||
graph.add_node((y, x), count=0)
|
||||
# Fully connect to all surrounding neighbors
|
||||
@ -74,7 +81,11 @@ class Map(object):
|
||||
@classmethod
|
||||
def from_image(cls, imagepath: Path):
|
||||
with Image.open(imagepath) as image:
|
||||
return cls(name=imagepath.name, array_like_map_representation=np.array(image))
|
||||
# Turn the image to single Channel Greyscale
|
||||
if image.mode != 'L':
|
||||
image = image.convert('L')
|
||||
map_array = np.array(image)
|
||||
return cls(name=imagepath.name, array_like_map_representation=map_array)
|
||||
|
||||
def simple_trajectory_between(self, start, dest):
|
||||
vertices = list(nx.shortest_path(self._G, start, dest))
|
||||
@ -82,12 +93,7 @@ class Map(object):
|
||||
return trajectory
|
||||
|
||||
def get_valid_position(self):
|
||||
not_found, valid_position = True, (-9999, -9999)
|
||||
while not_found:
|
||||
valid_position = int(np.random.choice(self.height, 1)), int(np.random.choice(self.width, 1))
|
||||
if self._G.has_node(valid_position):
|
||||
not_found = False
|
||||
pass
|
||||
valid_position = choice(list(self._G.nodes))
|
||||
return valid_position
|
||||
|
||||
def get_trajectory_from_vertices(self, *args):
|
||||
@ -108,9 +114,10 @@ class Map(object):
|
||||
polyline.extend(reversed(other_trajectory.vertices))
|
||||
|
||||
img = Image.new('L', (self.height, self.width), 0)
|
||||
ImageDraw.Draw(img).polygon(polyline, outline=1, fill=1)
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.polygon(polyline, outline=255, fill=255)
|
||||
|
||||
a = (np.array(img) * self.map_array).sum()
|
||||
a = (np.array(img) * np.where(self.map_array == self.white, 0, 1)).sum()
|
||||
if a >= 1:
|
||||
return False
|
||||
else:
|
||||
|
@ -9,7 +9,7 @@ from typing import Union
|
||||
from tqdm import trange
|
||||
|
||||
from lib.objects.map import Map
|
||||
|
||||
from lib.utils.parallel import run_n_in_parallel
|
||||
|
||||
class Generator:
|
||||
|
||||
@ -26,7 +26,12 @@ class Generator:
|
||||
for _ in trange(n, desc='Processing Trajectories'):
|
||||
trajectory = self.map.get_random_trajectory()
|
||||
alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
|
||||
trajectories_with_alternatives.append(dict(trajectory=trajectory, alternatives=alternatives, labels=labels))
|
||||
if not alternatives or labels:
|
||||
continue
|
||||
else:
|
||||
trajectories_with_alternatives.append(
|
||||
dict(trajectory=trajectory, alternatives=alternatives, labels=labels)
|
||||
)
|
||||
return trajectories_with_alternatives
|
||||
|
||||
def generate_alternatives(self, trajectory, output: Union[mp.
|
||||
@ -43,29 +48,22 @@ class Generator:
|
||||
return alternative
|
||||
|
||||
def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
|
||||
mode='one_patching', equal_samples=True):
|
||||
mode='one_patching', equal_samples=True, binary_check=True):
|
||||
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
||||
# Define an output queue
|
||||
output = mp.Queue()
|
||||
# Setup a list of processes that we want to run
|
||||
processes = [mp.Process(target=self.generate_alternatives,
|
||||
kwargs=dict(trajectory=trajectory, output=output, mode=mode))
|
||||
for _ in range(n)]
|
||||
# Run processes
|
||||
for p in processes:
|
||||
p.start()
|
||||
# Exit the completed processes
|
||||
for p in processes:
|
||||
p.join()
|
||||
#output = mp.Queue()
|
||||
|
||||
results = run_n_in_parallel(self.generate_alternatives, n, trajectory=trajectory, mode=mode) # , output=output)
|
||||
|
||||
# Get process results from the output queue
|
||||
results = [output.get() for _ in processes]
|
||||
#results = [output.get() for _ in range(n)]
|
||||
|
||||
# label per homotopic class
|
||||
homotopy_classes = defaultdict(list)
|
||||
homotopy_classes[0].append(trajectory)
|
||||
for i in range(len(results)):
|
||||
alternative = results[i]
|
||||
class_not_found, label = True, None
|
||||
class_not_found = True
|
||||
# check for homotopy class
|
||||
for label in homotopy_classes.keys():
|
||||
if self.map.are_homotopic(homotopy_classes[label][0], alternative):
|
||||
@ -73,17 +71,19 @@ class Generator:
|
||||
class_not_found = False
|
||||
break
|
||||
if class_not_found:
|
||||
label = len(homotopy_classes)
|
||||
label = 1 if binary_check else len(homotopy_classes)
|
||||
homotopy_classes[label].append(alternative)
|
||||
|
||||
# There should be as much homotopic samples as non-homotopic samples
|
||||
if equal_samples:
|
||||
homotopy_classes = self._remove_unequal(homotopy_classes)
|
||||
if not homotopy_classes:
|
||||
return None, None
|
||||
|
||||
# Compose lists of alternatives with labels
|
||||
alternatives, labels = list(), list()
|
||||
for key in homotopy_classes.keys():
|
||||
alternatives.extend([homotopy_classes[key]])
|
||||
alternatives.extend(homotopy_classes[key])
|
||||
labels.extend([key] * len(homotopy_classes[key]))
|
||||
|
||||
# Write to disk
|
||||
@ -106,11 +106,14 @@ class Generator:
|
||||
|
||||
@staticmethod
|
||||
def _remove_unequal(hom_dict):
|
||||
# We argue, that there will always be more non-homotopic routes than homotopic alternatives.
|
||||
# TODO: Otherwise introduce a second condition / loop
|
||||
hom_dict = hom_dict.copy()
|
||||
|
||||
if len(hom_dict[0]) <= 1:
|
||||
return None
|
||||
counter = len(hom_dict)
|
||||
while sum([len(hom_dict[class_id]) for class_id in range(len(hom_dict))]) > len(hom_dict[0]):
|
||||
if counter > len(hom_dict):
|
||||
while sum([len(hom_dict[class_id]) for class_id in range(1, len(hom_dict))]) > len(hom_dict[0]):
|
||||
if counter == 0:
|
||||
counter = len(hom_dict)
|
||||
if counter in hom_dict:
|
||||
if len(hom_dict[counter]) == 0:
|
||||
|
@ -5,6 +5,8 @@ from collections import defaultdict
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
|
||||
from lib.utils.model_io import ModelParameters
|
||||
|
||||
|
||||
def is_jsonable(x):
|
||||
import json
|
||||
@ -43,6 +45,10 @@ class Config(ConfigParser):
|
||||
return self._get_namespace_for_section('project')
|
||||
###################################################
|
||||
|
||||
@property
|
||||
def model_paramters(self):
|
||||
return ModelParameters(self.model, self.train, self.data)
|
||||
|
||||
@property
|
||||
def tags(self, ):
|
||||
return [f'{key}: {val}' for key, val in self.serializable.items()]
|
||||
|
@ -50,7 +50,7 @@ class Logger(LightningLoggerBase):
|
||||
self.debug = debug
|
||||
self.config = config
|
||||
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||
self._neptune_kwargs = dict(offline_mode=not self.debug,
|
||||
self._neptune_kwargs = dict(offline_mode= self.debug,
|
||||
api_key=self.config.project.neptune_key,
|
||||
project_name=self.project_name,
|
||||
name=self.name,
|
||||
|
23
lib/utils/parallel.py
Normal file
23
lib/utils/parallel.py
Normal file
@ -0,0 +1,23 @@
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
|
||||
|
||||
def run_n_in_parallel(f, n, **kwargs):
|
||||
output = mp.Queue()
|
||||
kwargs.update(output=output)
|
||||
# Setup a list of processes that we want to run
|
||||
processes = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
|
||||
# Run processes
|
||||
results = []
|
||||
for p in processes:
|
||||
p.start()
|
||||
while len(results) != n:
|
||||
time.sleep(1)
|
||||
# Get process results from the output queue
|
||||
results.extend([output.get() for _ in processes])
|
||||
|
||||
# Exit the completed processes
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
return results
|
23
main.py
23
main.py
@ -61,11 +61,28 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
|
||||
args = main_arg_parser.parse_args()
|
||||
config = Config.read_namespace(args)
|
||||
|
||||
# Trainer loading
|
||||
################
|
||||
# TESTING ONLY #
|
||||
# =============================================================================
|
||||
trainer = Trainer(logger=Logger(config, debug=True))
|
||||
hparams = config.model_paramters
|
||||
dataset = TrajData('data', mapname='tate', alternatives=100, trajectories=10000)
|
||||
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
|
||||
batch_size=hparams.data_param.batchsize,
|
||||
num_workers=hparams.data_param.worker)
|
||||
|
||||
# Logger
|
||||
# =============================================================================
|
||||
logger = Logger(config, debug=True)
|
||||
|
||||
# Trainer
|
||||
# =============================================================================
|
||||
trainer = Trainer(logger=logger)
|
||||
|
||||
# Model
|
||||
# =============================================================================
|
||||
model = None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(next(iter(train_dataloader)))
|
||||
next(iter(dataloader))
|
||||
pass
|
||||
|
Reference in New Issue
Block a user