Refactoring

This commit is contained in:
Si11ium
2020-04-08 12:04:04 +02:00
parent c7971c063f
commit 25c0e8e358
17 changed files with 0 additions and 21 deletions

0
objects/__init__.py Normal file
View File

193
objects/map.py Normal file
View File

@@ -0,0 +1,193 @@
from collections import UserDict
from pathlib import Path
import copy
from math import sqrt
from random import Random
import numpy as np
from PIL import Image, ImageDraw
import networkx as nx
from matplotlib import pyplot as plt
from lib.objects.trajectory import Trajectory
import lib.variables as V
class Map(object):
def __copy__(self):
return copy.deepcopy(self)
@property
def shape(self):
return self.map_array.shape
@property
def width(self):
return self.shape[-2]
@property
def height(self):
return self.shape[-1]
@property
def as_graph(self):
return self._G
@property
def as_array(self):
return self.map_array
@property
def as_2d_array(self):
return self.map_array.squeeze()
def __init__(self, name='', array_like_map_representation=None):
if array_like_map_representation is not None:
array_like_map_representation = array_like_map_representation.astype(np.float32)
if array_like_map_representation.ndim == 2:
array_like_map_representation = np.expand_dims(array_like_map_representation, axis=0)
assert array_like_map_representation.ndim == 3
self.map_array: np.ndarray = array_like_map_representation
self.name = name
self.prng = Random()
pass
def seed(self, seed):
self.prng.seed(seed)
def __setattr__(self, key, value):
super(Map, self).__setattr__(key, value)
if key == 'map_array' and self.map_array is not None:
self._G = self._build_graph()
def _build_graph(self, full_neighbors=True):
graph = nx.Graph()
# Do checks in order: up - left - upperLeft - lowerLeft
neighbors = [(0, -1, 1), (-1, 0, 1), (-1, -1, sqrt(2)), (-1, 1, sqrt(2))]
# Check pixels for their color (determine if walkable)
for idx, value in np.ndenumerate(self.map_array):
if value != V.BLACK:
# IF walkable, add node
graph.add_node(idx, count=0)
# Fully connect to all surrounding neighbors
for n, (xdif, ydif, weight) in enumerate(neighbors):
# Differentiate between 8 and 4 neighbors
if not full_neighbors and n >= 2:
break
# ToDO: make this explicite and less ugly
query_node = idx[:1] + (idx[1] + ydif,) + (idx[2] + xdif,)
if graph.has_node(query_node):
graph.add_edge(idx, query_node, weight=weight)
return graph
@classmethod
def from_image(cls, imagepath: Path, embedding_size=None):
with Image.open(imagepath) as image:
# Turn the image to single Channel Greyscale
if image.mode != 'L':
image = image.convert('L')
map_array = np.expand_dims(np.array(image), axis=0)
if embedding_size:
assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
embedding = np.full(embedding_size, V.BLACK)
embedding[:map_array.shape[0], :map_array.shape[1], :map_array.shape[2]] = map_array
map_array = embedding
return cls(name=imagepath.name, array_like_map_representation=map_array)
def simple_trajectory_between(self, start, dest):
vertices = list(nx.shortest_path(self._G, start, dest))
trajectory = Trajectory(vertices)
return trajectory
def get_valid_position(self):
valid_position = self.prng.choice(list(self._G.nodes))
return valid_position
def get_trajectory_from_vertices(self, *args):
coords = list()
for start, dest in zip(args[:-1], args[1:]):
coords.extend(nx.shortest_path(self._G, start, dest))
return Trajectory(coords)
def get_random_trajectory(self):
simple_trajectory = None
while simple_trajectory is None:
try:
start = self.get_valid_position()
dest = self.get_valid_position()
simple_trajectory = self.simple_trajectory_between(start, dest)
except nx.exception.NetworkXNoPath:
pass
return simple_trajectory
def generate_alternative(self, trajectory, mode='one_patching'):
start, dest = trajectory.endpoints
alternative = None
while alternative is None:
try:
if mode == 'one_patching':
patch = self.get_valid_position()
alternative = self.get_trajectory_from_vertices(start, patch, dest)
else:
raise RuntimeError(f'mode checking went wrong...')
except nx.exception.NetworkXNoPath:
pass
return alternative
def are_homotopic(self, trajectory, other_trajectory):
if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
raise TypeError
polyline = trajectory.xy_vertices
polyline.extend(reversed(other_trajectory.xy_vertices))
img = Image.new('L', (self.height, self.width), color=V.WHITE)
draw = ImageDraw.Draw(img)
draw.polygon(polyline, outline=V.BLACK, fill=V.BLACK)
binary_img = np.where(np.asarray(img).squeeze() == V.BLACK, 1, 0)
binary_map = np.where(self.as_2d_array == V.BLACK, 1, 0)
a = (binary_img * binary_map).sum()
if a:
return V.ALTERNATIVE # Non-Homotoph
else:
return V.HOMOTOPIC # Homotoph
def draw(self):
fig, ax = plt.gcf(), plt.gca()
# The standard colormaps also all have reversed versions.
# They have the same names with _r tacked on to the end.
# https: // matplotlib.org / api / pyplot_summary.html?highlight = colormaps
img = ax.imshow(self.as_2d_array, cmap='Greys_r')
return dict(img=img, fig=fig, ax=ax)
class MapStorage(UserDict):
@property
def keys_list(self):
return list(super(MapStorage, self).keys())
def __init__(self, map_root, *args, **kwargs):
super(MapStorage, self).__init__(*args, **kwargs)
self.map_root = Path(map_root)
map_files = list(self.map_root.glob('*.bmp'))
self.max_map_size = (1, ) + tuple(
reversed(
tuple(
map(
max, *[Image.open(map_file).size for map_file in map_files])
)
)
)
for map_file in map_files:
current_map = Map.from_image(map_file, embedding_size=self.max_map_size)
self.__setitem__(map_file.name, current_map)

86
objects/trajectory.py Normal file
View File

@@ -0,0 +1,86 @@
from math import atan2
from typing import List, Tuple, Union
from matplotlib import pyplot as plt
from lib import variables as V
import numpy as np
class Trajectory(object):
@property
def vertices(self):
return self._vertices
@property
def xy_vertices(self):
return [(x, y) for _, y, x in self._vertices]
@property
def endpoints(self):
return self.start, self.dest
@property
def start(self):
return self._vertices[0]
@property
def dest(self):
return self._vertices[-1]
@property
def xs(self):
return [x[2] for x in self._vertices]
@property
def ys(self):
return [x[1] for x in self._vertices]
@property
def as_paired_list(self):
return list(zip(self._vertices[:-1], self._vertices[1:]))
def draw_in_array(self, shape):
trajectory_space = np.zeros(shape).astype(np.float32)
for index in self.vertices:
trajectory_space[index] = V.WHITE
return trajectory_space
@property
def np_vertices(self):
return [np.array(vertice) for vertice in self._vertices]
def __init__(self, vertices: Union[List[Tuple[int]], None] = None):
assert any((isinstance(vertices, list), vertices is None))
if vertices is not None:
self._vertices = vertices
pass
def is_equal_to(self, other_trajectory):
# ToDo: do further equality Checks here
return self._vertices == other_trajectory.vertices
def draw(self, highlights=True, label=None, **kwargs):
if label is not None:
kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
label='Homotopic' if label == V.HOMOTOPIC else 'Alternative',
lw=1)
if highlights:
kwargs.update(marker='o')
fig, ax = plt.gcf(), plt.gca()
img = plt.plot(self.xs, self.ys, **kwargs)
return dict(img=img, fig=fig, ax=ax)
def min_vertices(self, vertices):
vertices, last_angle = [self.start], 0
for (x1, y1), (x2, y2) in self.as_paired_list:
current_angle = atan2(x1-x2, y1-y2)
if current_angle != last_angle:
vertices.append((x2, y2))
last_angle = current_angle
else:
continue
if vertices[-1] != self.dest:
vertices.append(self.dest)
return self.__class__(vertices=vertices)