CNN Model Body
This commit is contained in:
2
.idea/dictionaries/steffen.xml
generated
2
.idea/dictionaries/steffen.xml
generated
@ -2,6 +2,8 @@
|
|||||||
<dictionary name="steffen">
|
<dictionary name="steffen">
|
||||||
<words>
|
<words>
|
||||||
<w>conv</w>
|
<w>conv</w>
|
||||||
|
<w>homotopic</w>
|
||||||
|
<w>hyperparamter</w>
|
||||||
<w>numlayers</w>
|
<w>numlayers</w>
|
||||||
</words>
|
</words>
|
||||||
</dictionary>
|
</dictionary>
|
||||||
|
15
.idea/webResources.xml
generated
Normal file
15
.idea/webResources.xml
generated
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="WebResourcesPaths">
|
||||||
|
<contentEntries>
|
||||||
|
<entry url="file://$PROJECT_DIR$">
|
||||||
|
<entryData>
|
||||||
|
<resourceRoots>
|
||||||
|
<path value="file://$PROJECT_DIR$/res" />
|
||||||
|
<path value="file://$PROJECT_DIR$/data" />
|
||||||
|
</resourceRoots>
|
||||||
|
</entryData>
|
||||||
|
</entry>
|
||||||
|
</contentEntries>
|
||||||
|
</component>
|
||||||
|
</project>
|
@ -0,0 +1,21 @@
|
|||||||
|
from PIL import ImageDraw
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def are_homotopic(map_array, trajectory, other_trajectory):
|
||||||
|
|
||||||
|
polyline = trajectory.vertices.copy()
|
||||||
|
polyline.extend(reversed(other_trajectory.vertices))
|
||||||
|
|
||||||
|
height, width = map_array.shape
|
||||||
|
|
||||||
|
img = Image.new('L', (height, width), 0)
|
||||||
|
ImageDraw.Draw(img).polygon(polyline, outline=1, fill=1)
|
||||||
|
|
||||||
|
a = (np.array(img) * map_array).sum()
|
||||||
|
if a >= 1:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
@ -2,6 +2,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
|
from random import choice
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
@ -13,8 +15,9 @@ from lib.objects.trajectory import Trajectory
|
|||||||
|
|
||||||
class Map(object):
|
class Map(object):
|
||||||
|
|
||||||
white = [1, 255]
|
# This setting is for Img mode "L" aka GreyScale Image; values: 0-255
|
||||||
black = [0]
|
white = 255
|
||||||
|
black = 0
|
||||||
|
|
||||||
def __copy__(self):
|
def __copy__(self):
|
||||||
return copy.deepcopy(self)
|
return copy.deepcopy(self)
|
||||||
@ -56,8 +59,12 @@ class Map(object):
|
|||||||
|
|
||||||
# Check pixels for their color (determine if walkable)
|
# Check pixels for their color (determine if walkable)
|
||||||
for idx, value in np.ndenumerate(self.map_array):
|
for idx, value in np.ndenumerate(self.map_array):
|
||||||
if value in self.white:
|
if value == self.white:
|
||||||
|
try:
|
||||||
y, x = idx
|
y, x = idx
|
||||||
|
except ValueError:
|
||||||
|
y, x, channels = idx
|
||||||
|
idx = (y, x)
|
||||||
# IF walkable, add node
|
# IF walkable, add node
|
||||||
graph.add_node((y, x), count=0)
|
graph.add_node((y, x), count=0)
|
||||||
# Fully connect to all surrounding neighbors
|
# Fully connect to all surrounding neighbors
|
||||||
@ -74,7 +81,11 @@ class Map(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_image(cls, imagepath: Path):
|
def from_image(cls, imagepath: Path):
|
||||||
with Image.open(imagepath) as image:
|
with Image.open(imagepath) as image:
|
||||||
return cls(name=imagepath.name, array_like_map_representation=np.array(image))
|
# Turn the image to single Channel Greyscale
|
||||||
|
if image.mode != 'L':
|
||||||
|
image = image.convert('L')
|
||||||
|
map_array = np.array(image)
|
||||||
|
return cls(name=imagepath.name, array_like_map_representation=map_array)
|
||||||
|
|
||||||
def simple_trajectory_between(self, start, dest):
|
def simple_trajectory_between(self, start, dest):
|
||||||
vertices = list(nx.shortest_path(self._G, start, dest))
|
vertices = list(nx.shortest_path(self._G, start, dest))
|
||||||
@ -82,12 +93,7 @@ class Map(object):
|
|||||||
return trajectory
|
return trajectory
|
||||||
|
|
||||||
def get_valid_position(self):
|
def get_valid_position(self):
|
||||||
not_found, valid_position = True, (-9999, -9999)
|
valid_position = choice(list(self._G.nodes))
|
||||||
while not_found:
|
|
||||||
valid_position = int(np.random.choice(self.height, 1)), int(np.random.choice(self.width, 1))
|
|
||||||
if self._G.has_node(valid_position):
|
|
||||||
not_found = False
|
|
||||||
pass
|
|
||||||
return valid_position
|
return valid_position
|
||||||
|
|
||||||
def get_trajectory_from_vertices(self, *args):
|
def get_trajectory_from_vertices(self, *args):
|
||||||
@ -108,9 +114,10 @@ class Map(object):
|
|||||||
polyline.extend(reversed(other_trajectory.vertices))
|
polyline.extend(reversed(other_trajectory.vertices))
|
||||||
|
|
||||||
img = Image.new('L', (self.height, self.width), 0)
|
img = Image.new('L', (self.height, self.width), 0)
|
||||||
ImageDraw.Draw(img).polygon(polyline, outline=1, fill=1)
|
draw = ImageDraw.Draw(img)
|
||||||
|
draw.polygon(polyline, outline=255, fill=255)
|
||||||
|
|
||||||
a = (np.array(img) * self.map_array).sum()
|
a = (np.array(img) * np.where(self.map_array == self.white, 0, 1)).sum()
|
||||||
if a >= 1:
|
if a >= 1:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
|
@ -9,7 +9,7 @@ from typing import Union
|
|||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
from lib.objects.map import Map
|
from lib.objects.map import Map
|
||||||
|
from lib.utils.parallel import run_n_in_parallel
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
|
|
||||||
@ -26,7 +26,12 @@ class Generator:
|
|||||||
for _ in trange(n, desc='Processing Trajectories'):
|
for _ in trange(n, desc='Processing Trajectories'):
|
||||||
trajectory = self.map.get_random_trajectory()
|
trajectory = self.map.get_random_trajectory()
|
||||||
alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
|
alternatives, labels = self.generate_n_alternatives(trajectory, m, dataset_name=dataset_name, **kwargs)
|
||||||
trajectories_with_alternatives.append(dict(trajectory=trajectory, alternatives=alternatives, labels=labels))
|
if not alternatives or labels:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
trajectories_with_alternatives.append(
|
||||||
|
dict(trajectory=trajectory, alternatives=alternatives, labels=labels)
|
||||||
|
)
|
||||||
return trajectories_with_alternatives
|
return trajectories_with_alternatives
|
||||||
|
|
||||||
def generate_alternatives(self, trajectory, output: Union[mp.
|
def generate_alternatives(self, trajectory, output: Union[mp.
|
||||||
@ -43,29 +48,22 @@ class Generator:
|
|||||||
return alternative
|
return alternative
|
||||||
|
|
||||||
def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
|
def generate_n_alternatives(self, trajectory, n, dataset_name: Union[str, Path] = '',
|
||||||
mode='one_patching', equal_samples=True):
|
mode='one_patching', equal_samples=True, binary_check=True):
|
||||||
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
assert mode in self.possible_modes, f'Parameter "mode" must be either {self.possible_modes}, but was {mode}.'
|
||||||
# Define an output queue
|
# Define an output queue
|
||||||
output = mp.Queue()
|
#output = mp.Queue()
|
||||||
# Setup a list of processes that we want to run
|
|
||||||
processes = [mp.Process(target=self.generate_alternatives,
|
results = run_n_in_parallel(self.generate_alternatives, n, trajectory=trajectory, mode=mode) # , output=output)
|
||||||
kwargs=dict(trajectory=trajectory, output=output, mode=mode))
|
|
||||||
for _ in range(n)]
|
|
||||||
# Run processes
|
|
||||||
for p in processes:
|
|
||||||
p.start()
|
|
||||||
# Exit the completed processes
|
|
||||||
for p in processes:
|
|
||||||
p.join()
|
|
||||||
# Get process results from the output queue
|
# Get process results from the output queue
|
||||||
results = [output.get() for _ in processes]
|
#results = [output.get() for _ in range(n)]
|
||||||
|
|
||||||
# label per homotopic class
|
# label per homotopic class
|
||||||
homotopy_classes = defaultdict(list)
|
homotopy_classes = defaultdict(list)
|
||||||
homotopy_classes[0].append(trajectory)
|
homotopy_classes[0].append(trajectory)
|
||||||
for i in range(len(results)):
|
for i in range(len(results)):
|
||||||
alternative = results[i]
|
alternative = results[i]
|
||||||
class_not_found, label = True, None
|
class_not_found = True
|
||||||
# check for homotopy class
|
# check for homotopy class
|
||||||
for label in homotopy_classes.keys():
|
for label in homotopy_classes.keys():
|
||||||
if self.map.are_homotopic(homotopy_classes[label][0], alternative):
|
if self.map.are_homotopic(homotopy_classes[label][0], alternative):
|
||||||
@ -73,17 +71,19 @@ class Generator:
|
|||||||
class_not_found = False
|
class_not_found = False
|
||||||
break
|
break
|
||||||
if class_not_found:
|
if class_not_found:
|
||||||
label = len(homotopy_classes)
|
label = 1 if binary_check else len(homotopy_classes)
|
||||||
homotopy_classes[label].append(alternative)
|
homotopy_classes[label].append(alternative)
|
||||||
|
|
||||||
# There should be as much homotopic samples as non-homotopic samples
|
# There should be as much homotopic samples as non-homotopic samples
|
||||||
if equal_samples:
|
if equal_samples:
|
||||||
homotopy_classes = self._remove_unequal(homotopy_classes)
|
homotopy_classes = self._remove_unequal(homotopy_classes)
|
||||||
|
if not homotopy_classes:
|
||||||
|
return None, None
|
||||||
|
|
||||||
# Compose lists of alternatives with labels
|
# Compose lists of alternatives with labels
|
||||||
alternatives, labels = list(), list()
|
alternatives, labels = list(), list()
|
||||||
for key in homotopy_classes.keys():
|
for key in homotopy_classes.keys():
|
||||||
alternatives.extend([homotopy_classes[key]])
|
alternatives.extend(homotopy_classes[key])
|
||||||
labels.extend([key] * len(homotopy_classes[key]))
|
labels.extend([key] * len(homotopy_classes[key]))
|
||||||
|
|
||||||
# Write to disk
|
# Write to disk
|
||||||
@ -106,11 +106,14 @@ class Generator:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _remove_unequal(hom_dict):
|
def _remove_unequal(hom_dict):
|
||||||
|
# We argue, that there will always be more non-homotopic routes than homotopic alternatives.
|
||||||
|
# TODO: Otherwise introduce a second condition / loop
|
||||||
hom_dict = hom_dict.copy()
|
hom_dict = hom_dict.copy()
|
||||||
|
if len(hom_dict[0]) <= 1:
|
||||||
|
return None
|
||||||
counter = len(hom_dict)
|
counter = len(hom_dict)
|
||||||
while sum([len(hom_dict[class_id]) for class_id in range(len(hom_dict))]) > len(hom_dict[0]):
|
while sum([len(hom_dict[class_id]) for class_id in range(1, len(hom_dict))]) > len(hom_dict[0]):
|
||||||
if counter > len(hom_dict):
|
if counter == 0:
|
||||||
counter = len(hom_dict)
|
counter = len(hom_dict)
|
||||||
if counter in hom_dict:
|
if counter in hom_dict:
|
||||||
if len(hom_dict[counter]) == 0:
|
if len(hom_dict[counter]) == 0:
|
||||||
|
@ -5,6 +5,8 @@ from collections import defaultdict
|
|||||||
from configparser import ConfigParser
|
from configparser import ConfigParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from lib.utils.model_io import ModelParameters
|
||||||
|
|
||||||
|
|
||||||
def is_jsonable(x):
|
def is_jsonable(x):
|
||||||
import json
|
import json
|
||||||
@ -43,6 +45,10 @@ class Config(ConfigParser):
|
|||||||
return self._get_namespace_for_section('project')
|
return self._get_namespace_for_section('project')
|
||||||
###################################################
|
###################################################
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_paramters(self):
|
||||||
|
return ModelParameters(self.model, self.train, self.data)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tags(self, ):
|
def tags(self, ):
|
||||||
return [f'{key}: {val}' for key, val in self.serializable.items()]
|
return [f'{key}: {val}' for key, val in self.serializable.items()]
|
||||||
|
@ -50,7 +50,7 @@ class Logger(LightningLoggerBase):
|
|||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.config = config
|
self.config = config
|
||||||
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
self._testtube_kwargs = dict(save_dir=self.outpath, version=self.version, name=self.name)
|
||||||
self._neptune_kwargs = dict(offline_mode=not self.debug,
|
self._neptune_kwargs = dict(offline_mode= self.debug,
|
||||||
api_key=self.config.project.neptune_key,
|
api_key=self.config.project.neptune_key,
|
||||||
project_name=self.project_name,
|
project_name=self.project_name,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
|
23
lib/utils/parallel.py
Normal file
23
lib/utils/parallel.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import multiprocessing as mp
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def run_n_in_parallel(f, n, **kwargs):
|
||||||
|
output = mp.Queue()
|
||||||
|
kwargs.update(output=output)
|
||||||
|
# Setup a list of processes that we want to run
|
||||||
|
processes = [mp.Process(target=f, kwargs=kwargs) for _ in range(n)]
|
||||||
|
# Run processes
|
||||||
|
results = []
|
||||||
|
for p in processes:
|
||||||
|
p.start()
|
||||||
|
while len(results) != n:
|
||||||
|
time.sleep(1)
|
||||||
|
# Get process results from the output queue
|
||||||
|
results.extend([output.get() for _ in processes])
|
||||||
|
|
||||||
|
# Exit the completed processes
|
||||||
|
for p in processes:
|
||||||
|
p.join()
|
||||||
|
|
||||||
|
return results
|
23
main.py
23
main.py
@ -61,11 +61,28 @@ main_arg_parser.add_argument("--project_neptune_key", type=str, default=os.geten
|
|||||||
args = main_arg_parser.parse_args()
|
args = main_arg_parser.parse_args()
|
||||||
config = Config.read_namespace(args)
|
config = Config.read_namespace(args)
|
||||||
|
|
||||||
# Trainer loading
|
################
|
||||||
|
# TESTING ONLY #
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
trainer = Trainer(logger=Logger(config, debug=True))
|
hparams = config.model_paramters
|
||||||
|
dataset = TrajData('data', mapname='tate', alternatives=100, trajectories=10000)
|
||||||
|
dataloader = DataLoader(dataset=dataset.train_dataset, shuffle=True,
|
||||||
|
batch_size=hparams.data_param.batchsize,
|
||||||
|
num_workers=hparams.data_param.worker)
|
||||||
|
|
||||||
|
# Logger
|
||||||
|
# =============================================================================
|
||||||
|
logger = Logger(config, debug=True)
|
||||||
|
|
||||||
|
# Trainer
|
||||||
|
# =============================================================================
|
||||||
|
trainer = Trainer(logger=logger)
|
||||||
|
|
||||||
|
# Model
|
||||||
|
# =============================================================================
|
||||||
|
model = None
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print(next(iter(train_dataloader)))
|
next(iter(dataloader))
|
||||||
pass
|
pass
|
||||||
|
Reference in New Issue
Block a user