Steffen Illium 91ecf157d6 initial
2020-02-13 20:28:20 +01:00

126 lines
4.1 KiB
Python

from pathlib import Path
import copy
from math import sqrt
import numpy as np
from PIL import Image, ImageDraw
import networkx as nx
from matplotlib import pyplot as plt
from lib.objects.trajectory import Trajectory
class Map(object):
    """A 2-D map backed by a numpy array plus a networkx walkability graph.

    Every pixel whose value is in ``white`` becomes a graph node keyed by
    ``(row, col)``; all other pixels are not part of the graph. The graph is
    rebuilt automatically whenever ``map_array`` is assigned.
    """

    # Pixel values treated as walkable when building the graph.
    white = [1, 255]
    # Pixel values intended to denote obstacles (non-walkable pixels).
    black = [0]

    def __copy__(self):
        """Return a deep copy of this map (array and graph included)."""
        return copy.deepcopy(self)

    @property
    def shape(self):
        """Shape of the underlying ``map_array``."""
        return self.map_array.shape

    @property
    def width(self):
        """Size of axis 0 of ``map_array``, i.e. the number of rows."""
        return self.shape[0]

    @property
    def height(self):
        """Size of axis 1 of ``map_array``, i.e. the number of columns."""
        return self.shape[1]

    @property
    def as_graph(self):
        """The walkability graph (nodes are ``(row, col)`` tuples)."""
        return self._G

    @property
    def as_array(self):
        """The raw map array."""
        return self.map_array

    def __init__(self, name='', array_like_map_representation=None):
        """Create a map named ``name`` from an optional array-like.

        Anything ``np.asarray`` accepts is converted to an ndarray; assigning
        ``map_array`` (re)builds the graph via ``__setattr__``.
        """
        if array_like_map_representation is not None:
            array_like_map_representation = np.asarray(array_like_map_representation)
        self.map_array: np.ndarray = array_like_map_representation
        self.name = name

    def __setattr__(self, key, value):
        """Set an attribute, keeping the graph in sync with ``map_array``."""
        super().__setattr__(key, value)
        if key == 'map_array':
            # Reset to an empty graph for a missing map so a stale graph never
            # outlives the array it was built from.
            self._G = self._build_graph() if value is not None else nx.Graph()

    def _build_graph(self, full_neighbors=True):
        """Build the walkability graph of ``map_array``.

        Each walkable pixel becomes a node ``(row, col)`` with ``count=0``.
        Orthogonal neighbours are joined with weight 1; when ``full_neighbors``
        is true, diagonal neighbours are joined as well with weight sqrt(2).

        Pixels are visited in row-major order, so each pixel only looks *back*
        at neighbours that already exist: up, left, upper-left and upper-right.
        Together these cover every 8-neighbour pair exactly once.
        """
        graph = nx.Graph()
        # (dx, dy, weight): up, left, upper-left, upper-right.
        # The first two form the 4-neighbourhood; the last two add diagonals.
        neighbors = [(0, -1, 1), (-1, 0, 1), (-1, -1, sqrt(2)), (1, -1, sqrt(2))]
        offsets = neighbors if full_neighbors else neighbors[:2]
        for (y, x), value in np.ndenumerate(self.map_array):
            if value not in self.white:
                continue
            node = (y, x)
            graph.add_node(node, count=0)
            for dx, dy, weight in offsets:
                query_node = (y + dy, x + dx)
                # Out-of-range or non-walkable neighbours are simply absent.
                if graph.has_node(query_node):
                    graph.add_edge(node, query_node, weight=weight)
        return graph

    @classmethod
    def from_image(cls, imagepath: Path):
        """Load a map from an image file; the map is named after the file.

        ``imagepath`` may be a ``Path`` or a string path.
        """
        imagepath = Path(imagepath)
        with Image.open(imagepath) as image:
            return cls(name=imagepath.name, array_like_map_representation=np.array(image))

    def simple_trajectory_between(self, start, dest, weight=None):
        """Return a shortest-path ``Trajectory`` from ``start`` to ``dest``.

        ``weight`` is forwarded to ``nx.shortest_path``: ``None`` (default)
        minimises the hop count, ``'weight'`` uses the edge weights.
        """
        vertices = list(nx.shortest_path(self._G, start, dest, weight=weight))
        return Trajectory(vertices)

    def get_valid_position(self):
        """Return a uniformly random walkable ``(row, col)`` position.

        Uses rejection sampling over the full array extent.

        Raises:
            ValueError: if the map contains no walkable position.
        """
        if self._G.number_of_nodes() == 0:
            raise ValueError('Map has no walkable position.')
        n_rows, n_cols = self.map_array.shape[:2]
        while True:
            # Nodes are keyed (row, col): the first coordinate spans axis 0,
            # the second spans axis 1.
            position = (int(np.random.randint(n_rows)), int(np.random.randint(n_cols)))
            if self._G.has_node(position):
                return position

    def get_trajectory_from_vertices(self, *args, weight=None):
        """Chain shortest paths through the given vertices into one ``Trajectory``.

        The vertex shared by two consecutive segments appears only once.
        Fewer than two vertices yield an empty trajectory.
        ``weight`` is forwarded to ``nx.shortest_path`` (``None`` = hop count).
        """
        coords = []
        for start, dest in zip(args[:-1], args[1:]):
            segment = nx.shortest_path(self._G, start, dest, weight=weight)
            # Each segment begins where the previous one ended; skip the repeat.
            coords.extend(segment[1:] if coords else segment)
        return Trajectory(coords)

    def get_random_trajectory(self):
        """Return a shortest-path trajectory between two random walkable positions."""
        start = self.get_valid_position()
        dest = self.get_valid_position()
        return self.simple_trajectory_between(start, dest)

    def are_homotopic(self, trajectory, other_trajectory):
        """Return True if no non-walkable pixel lies inside the loop formed by both trajectories.

        The loop is ``trajectory`` followed by ``other_trajectory`` reversed,
        rasterised as a filled polygon.

        Raises:
            TypeError: if either argument is not a ``Trajectory``.
        """
        if not all(isinstance(x, Trajectory) for x in (trajectory, other_trajectory)):
            raise TypeError('Both arguments must be Trajectory instances.')
        # Graph nodes (and thus trajectory vertices) are (row, col) = (y, x),
        # while PIL expects (x, y) points.
        # NOTE(review): assumes Trajectory.vertices holds the (row, col) tuples
        # it was constructed from -- confirm against lib.objects.trajectory.
        polyline = [(x, y) for y, x in trajectory.vertices]
        polyline.extend((x, y) for y, x in reversed(other_trajectory.vertices))
        n_rows, n_cols = self.map_array.shape[:2]
        img = Image.new('L', (n_cols, n_rows), 0)  # PIL size is (width, height)
        ImageDraw.Draw(img).polygon(polyline, outline=1, fill=1)
        enclosed = np.array(img) > 0
        # Anything the graph does not treat as walkable counts as an obstacle.
        obstacles = ~np.isin(self.map_array, self.white)
        return not np.any(enclosed & obstacles)

    def draw(self):
        """Show the map on the current matplotlib axes; return the image, figure and axes."""
        fig, ax = plt.gcf(), plt.gca()
        # 'Greys_r' is the reversed 'Greys' colormap (matplotlib appends _r to
        # the name of every reversed colormap).
        img = ax.imshow(self.as_array, cmap='Greys_r')
        return dict(img=img, fig=fig, ax=ax)