2020-03-05 20:50:07 +01:00

191 lines
6.2 KiB
Python

import shelve
from pathlib import Path
import copy
from math import sqrt
from random import choice
import numpy as np
from PIL import Image, ImageDraw
import networkx as nx
from matplotlib import pyplot as plt
from lib.objects.trajectory import Trajectory
class Map(object):
    """Greyscale grid map plus a graph connecting its walkable pixels.

    The map is held in ``map_array`` as a 3-d array indexed
    ``(channel, row, column)``; a 2-d input is promoted by prepending an axis
    of size one.  Pixels equal to ``white`` are walkable, everything else is
    treated as blocked.  Assigning ``map_array`` (re)builds the graph that is
    exposed through ``as_graph``.
    """

    # This setting is for Img mode "L" aka GreyScale Image; values: 0-255
    white = 255
    black = 0

    def __init__(self, name='', array_like_map_representation=None):
        """Create a map, optionally from a 2-d or 3-d numpy array.

        :param name: Free-form label stored on the instance.
        :param array_like_map_representation: ``None`` for an empty map, or an
            array with ``ndim`` 2 or 3.  Any other rank fails the assertion.
        """
        if array_like_map_representation is not None:
            if array_like_map_representation.ndim == 2:
                array_like_map_representation = np.expand_dims(array_like_map_representation, axis=0)
            assert array_like_map_representation.ndim == 3
        # Assigning map_array triggers the graph build in __setattr__.
        self.map_array: np.ndarray = array_like_map_representation
        self.name = name

    def __copy__(self):
        """Return a deep copy, so ``copy.copy(map)`` never shares the array."""
        return copy.deepcopy(self)

    def __setattr__(self, key, value):
        """Store the attribute and keep the graph in sync with ``map_array``."""
        super(Map, self).__setattr__(key, value)
        if key == 'map_array':
            # An empty map gets an explicit ``None`` graph instead of keeping a
            # stale one (or leaving the attribute undefined).
            graph = self._build_graph() if value is not None else None
            super(Map, self).__setattr__('_G', graph)

    @property
    def shape(self):
        """Shape of the underlying array."""
        return self.map_array.shape

    @property
    def width(self):
        """Size of the second-to-last array axis."""
        return self.shape[-2]

    @property
    def height(self):
        """Size of the last array axis."""
        return self.shape[-1]

    @property
    def as_graph(self):
        """The graph of walkable pixels (``None`` for an empty map)."""
        return self._G

    @property
    def as_array(self):
        """The raw 3-d map array."""
        return self.map_array

    @property
    def as_2d_array(self):
        """The map array with all size-one axes removed."""
        return self.map_array.squeeze()

    def _walkable_neighbors(self, idx, full_neighbors=True):
        """Yield ``(neighbor_idx, weight)`` for walkable pixels preceding ``idx``.

        Only neighbours that come *before* ``idx`` in row-major order are
        inspected (up, left, upper-left, upper-right), so that every
        undirected edge is produced exactly once while scanning the array.

        :param idx: ``(channel, row, column)`` index of the pixel.
        :param full_neighbors: ``True`` for 8-connectivity, ``False`` for
            4-connectivity (diagonals are skipped).
        """
        diagonal = sqrt(2)
        # (column_diff, row_diff, weight): up - left - upperLeft - upperRight.
        # The anti-diagonal must be probed towards the upper right: the lower
        # left pixel has not been visited yet when scanning row by row.
        offsets = [(0, -1, 1), (-1, 0, 1), (-1, -1, diagonal), (1, -1, diagonal)]
        if not full_neighbors:
            offsets = offsets[:2]

        channel, row, col = idx
        n_rows, n_cols = self.map_array.shape[-2:]
        for col_diff, row_diff, weight in offsets:
            n_row, n_col = row + row_diff, col + col_diff
            # Explicit bounds check: negative indices would wrap around.
            if not (0 <= n_row < n_rows and 0 <= n_col < n_cols):
                continue
            if self.map_array[channel, n_row, n_col] == self.white:
                yield (channel, n_row, n_col), weight

    def _build_graph(self, full_neighbors=True):
        """Build the graph of walkable pixels from ``map_array``.

        Every white pixel becomes a node (with a ``count`` attribute set to 0)
        and is connected to its walkable neighbours; straight edges weigh 1,
        diagonal edges ``sqrt(2)``.

        :param full_neighbors: ``True`` for 8-connectivity, ``False`` for 4.
        :return: The populated ``networkx.Graph``.
        """
        graph = nx.Graph()
        # Check pixels for their color (determine if walkable)
        for idx, value in np.ndenumerate(self.map_array):
            if value != self.white:
                continue
            graph.add_node(idx, count=0)
            for neighbor, weight in self._walkable_neighbors(idx, full_neighbors):
                graph.add_edge(idx, neighbor, weight=weight)
        return graph

    @classmethod
    def from_image(cls, imagepath: Path, embedding_size=None):
        """Load a map from an image file.

        :param imagepath: Path of the image; its file name becomes the map name.
        :param embedding_size: Optional shape tuple.  When given, the image is
            copied into the leading corner of a zero-filled array of that
            shape (note that ``np.zeros`` yields a float array).
        :return: A new instance of ``cls``.
        """
        with Image.open(imagepath) as image:
            # Turn the image to single Channel Greyscale
            if image.mode != 'L':
                image = image.convert('L')
            map_array = np.expand_dims(np.array(image), axis=0)
        if embedding_size:
            assert isinstance(embedding_size, tuple), f'embedding_size was of type: {type(embedding_size)}'
            embedding = np.zeros(embedding_size)
            embedding[:map_array.shape[0], :map_array.shape[1], :map_array.shape[2]] = map_array
            map_array = embedding
        return cls(name=imagepath.name, array_like_map_representation=map_array)

    def simple_trajectory_between(self, start, dest):
        """Return the shortest path from ``start`` to ``dest`` as a Trajectory.

        Raises ``networkx.NetworkXNoPath`` when the nodes are not connected.
        """
        # NOTE(review): no ``weight`` argument is passed, so the search counts
        # hops and ignores the edge weights - confirm this is intended.
        vertices = list(nx.shortest_path(self._G, start, dest))
        return Trajectory(vertices)

    def get_valid_position(self):
        """Return a randomly chosen node (walkable pixel index) of the graph."""
        return choice(list(self._G.nodes))

    def get_trajectory_from_vertices(self, *args):
        """Chain shortest paths through the given waypoints into one Trajectory.

        :param args: Graph nodes to visit in order.
        """
        coords = list()
        # NOTE(review): consecutive legs share a waypoint, which therefore
        # appears twice in ``coords`` - verify Trajectory expects that.
        for start, dest in zip(args[:-1], args[1:]):
            coords.extend(nx.shortest_path(self._G, start, dest))
        return Trajectory(coords)

    def get_random_trajectory(self):
        """Return a trajectory between two random, connected positions.

        Unconnected pairs are rejected and redrawn until a path exists.
        """
        simple_trajectory = None
        while simple_trajectory is None:
            start = self.get_valid_position()
            dest = self.get_valid_position()
            try:
                simple_trajectory = self.simple_trajectory_between(start, dest)
            except nx.exception.NetworkXNoPath:
                # Positions lie in different components; draw a new pair.
                continue
        return simple_trajectory

    def generate_alternative(self, trajectory, mode='one_patching'):
        """Return another trajectory sharing the endpoints of ``trajectory``.

        :param trajectory: Object exposing ``endpoints`` as ``(start, dest)``.
        :param mode: Only ``'one_patching'`` is supported: the alternative is
            routed through one random intermediate position.
        :raises RuntimeError: For any other ``mode``.
        """
        start, dest = trajectory.endpoints
        if mode != 'one_patching':
            raise RuntimeError('mode checking went wrong...')

        alternative = None
        while alternative is None:
            patch = self.get_valid_position()
            try:
                alternative = self.get_trajectory_from_vertices(start, patch, dest)
            except nx.exception.NetworkXNoPath:
                # The patch is unreachable; try another one.
                continue
        return alternative

    def are_homotopic(self, trajectory, other_trajectory):
        """Tell whether the area enclosed by two trajectories is obstacle free.

        The two vertex lists are joined into one closed polygon, which is
        rasterised and intersected with the non-walkable pixels of the map.

        :return: ``True`` when no blocked pixel lies inside or on the polygon,
            ``False`` otherwise.
        :raises TypeError: If either argument is not a ``Trajectory``.
        """
        if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
            raise TypeError
        # Work on a copy so the trajectory's own vertex list is not extended.
        polyline = list(trajectory.xy_vertices)
        polyline.extend(reversed(other_trajectory.xy_vertices))

        # PIL sizes are (columns, rows), so the mask matches as_2d_array.
        img = Image.new('L', (self.height, self.width), 0)
        draw = ImageDraw.Draw(img)
        # Fill with the same value the mask is compared against below.
        draw.polygon(polyline, outline=self.white, fill=self.white)

        enclosed = np.asarray(img) == self.white
        # The trajectories run over walkable pixels themselves, so only
        # blocked pixels inside the polygon can separate them.
        blocked = self.as_2d_array != self.white
        return not np.logical_and(enclosed, blocked).any()

    def draw(self):
        """Render the map on the current matplotlib axes.

        :return: ``dict`` with the image artist (``img``), figure (``fig``)
            and axes (``ax``).
        """
        fig, ax = plt.gcf(), plt.gca()
        # The standard colormaps also all have reversed versions.
        # They have the same names with _r tacked on to the end.
        # https://matplotlib.org/api/pyplot_summary.html?highlight=colormaps
        img = ax.imshow(self.as_2d_array, cmap='Greys_r')
        return dict(img=img, fig=fig, ax=ax)
class MapStorage(object):
    """Lazy, caching access to maps stored as shelve files in one directory.

    ``storage[name]`` reads ``<map_root>/<name>.pik`` on first access and
    serves the cached object from ``data`` afterwards.
    """

    def __init__(self, map_root, load_all=False):
        """Set up the storage.

        :param map_root: Directory containing the ``.pik`` shelve files.
        :param load_all: When true, eagerly load one entry per ``*.bmp`` file
            found in ``map_root``, keyed by the bitmap's file name.
        """
        self.data = dict()
        self.map_root = Path(map_root)
        if load_all:
            for map_file in self.map_root.glob('*.bmp'):
                _ = self[map_file.name]

    def __getitem__(self, item):
        """Return the map stored under ``item``, loading it if necessary.

        Errors raised by ``shelve.open`` (e.g. for a missing file) and
        ``KeyError`` for a shelf without the expected ``['map']['map']``
        entry are propagated to the caller.
        """
        try:
            return self.data[item]
        except KeyError:
            pass
        # str(): shelve only accepts path-like objects from Python 3.11 on.
        # NOTE: shelve unpickles its content - only open trusted files.
        with shelve.open(str(self.map_root / f'{item}.pik'), flag='r') as d:
            loaded = d['map']['map']
        self.data[item] = loaded
        return loaded