CNN Model Body
This commit is contained in:
@ -1,125 +1,132 @@
|
||||
from pathlib import Path
|
||||
|
||||
import copy
|
||||
from math import sqrt
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from lib.objects.trajectory import Trajectory
|
||||
|
||||
|
||||
class Map(object):
|
||||
|
||||
white = [1, 255]
|
||||
black = [0]
|
||||
|
||||
def __copy__(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.map_array.shape
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self.shape[0]
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.shape[1]
|
||||
|
||||
@property
|
||||
def as_graph(self):
|
||||
return self._G
|
||||
|
||||
@property
|
||||
def as_array(self):
|
||||
return self.map_array
|
||||
|
||||
def __init__(self, name='', array_like_map_representation=None):
|
||||
self.map_array: np.ndarray = array_like_map_representation
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
super(Map, self).__setattr__(key, value)
|
||||
if key == 'map_array' and self.map_array is not None:
|
||||
self._G = self._build_graph()
|
||||
|
||||
def _build_graph(self, full_neighbors=True):
|
||||
graph = nx.Graph()
|
||||
# Do checks in order: up - left - upperLeft - lowerLeft
|
||||
neighbors = [(0, -1, 1), (-1, 0, 1), (-1, -1, sqrt(2)), (-1, 1, sqrt(2))]
|
||||
|
||||
# Check pixels for their color (determine if walkable)
|
||||
for idx, value in np.ndenumerate(self.map_array):
|
||||
if value in self.white:
|
||||
y, x = idx
|
||||
# IF walkable, add node
|
||||
graph.add_node((y, x), count=0)
|
||||
# Fully connect to all surrounding neighbors
|
||||
for n, (xdif, ydif, weight) in enumerate(neighbors):
|
||||
# Differentiate between 8 and 4 neighbors
|
||||
if not full_neighbors and n >= 2:
|
||||
break
|
||||
|
||||
query_node = (y + ydif, x + xdif)
|
||||
if graph.has_node(query_node):
|
||||
graph.add_edge(idx, query_node, weight=weight)
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
def from_image(cls, imagepath: Path):
|
||||
with Image.open(imagepath) as image:
|
||||
return cls(name=imagepath.name, array_like_map_representation=np.array(image))
|
||||
|
||||
def simple_trajectory_between(self, start, dest):
|
||||
vertices = list(nx.shortest_path(self._G, start, dest))
|
||||
trajectory = Trajectory(vertices)
|
||||
return trajectory
|
||||
|
||||
def get_valid_position(self):
|
||||
not_found, valid_position = True, (-9999, -9999)
|
||||
while not_found:
|
||||
valid_position = int(np.random.choice(self.height, 1)), int(np.random.choice(self.width, 1))
|
||||
if self._G.has_node(valid_position):
|
||||
not_found = False
|
||||
pass
|
||||
return valid_position
|
||||
|
||||
def get_trajectory_from_vertices(self, *args):
|
||||
coords = list()
|
||||
for start, dest in zip(args[:-1], args[1:]):
|
||||
coords.extend(nx.shortest_path(self._G, start, dest))
|
||||
return Trajectory(coords)
|
||||
|
||||
def get_random_trajectory(self):
|
||||
start = self.get_valid_position()
|
||||
dest = self.get_valid_position()
|
||||
return self.simple_trajectory_between(start, dest)
|
||||
|
||||
def are_homotopic(self, trajectory, other_trajectory):
|
||||
if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
|
||||
raise TypeError
|
||||
polyline = trajectory.vertices.copy()
|
||||
polyline.extend(reversed(other_trajectory.vertices))
|
||||
|
||||
img = Image.new('L', (self.height, self.width), 0)
|
||||
ImageDraw.Draw(img).polygon(polyline, outline=1, fill=1)
|
||||
|
||||
a = (np.array(img) * self.map_array).sum()
|
||||
if a >= 1:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def draw(self):
|
||||
fig, ax = plt.gcf(), plt.gca()
|
||||
# The standard colormaps also all have reversed versions.
|
||||
# They have the same names with _r tacked on to the end.
|
||||
# https: // matplotlib.org / api / pyplot_summary.html?highlight = colormaps
|
||||
img = ax.imshow(self.as_array, cmap='Greys_r')
|
||||
return dict(img=img, fig=fig, ax=ax)
|
||||
from pathlib import Path
|
||||
|
||||
import copy
|
||||
from math import sqrt
|
||||
from random import choice
|
||||
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image, ImageDraw
|
||||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from lib.objects.trajectory import Trajectory
|
||||
|
||||
|
||||
class Map(object):
|
||||
|
||||
# This setting is for Img mode "L" aka GreyScale Image; values: 0-255
|
||||
white = 255
|
||||
black = 0
|
||||
|
||||
def __copy__(self):
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.map_array.shape
|
||||
|
||||
@property
|
||||
def width(self):
|
||||
return self.shape[0]
|
||||
|
||||
@property
|
||||
def height(self):
|
||||
return self.shape[1]
|
||||
|
||||
@property
|
||||
def as_graph(self):
|
||||
return self._G
|
||||
|
||||
@property
|
||||
def as_array(self):
|
||||
return self.map_array
|
||||
|
||||
def __init__(self, name='', array_like_map_representation=None):
|
||||
self.map_array: np.ndarray = array_like_map_representation
|
||||
self.name = name
|
||||
pass
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
super(Map, self).__setattr__(key, value)
|
||||
if key == 'map_array' and self.map_array is not None:
|
||||
self._G = self._build_graph()
|
||||
|
||||
def _build_graph(self, full_neighbors=True):
|
||||
graph = nx.Graph()
|
||||
# Do checks in order: up - left - upperLeft - lowerLeft
|
||||
neighbors = [(0, -1, 1), (-1, 0, 1), (-1, -1, sqrt(2)), (-1, 1, sqrt(2))]
|
||||
|
||||
# Check pixels for their color (determine if walkable)
|
||||
for idx, value in np.ndenumerate(self.map_array):
|
||||
if value == self.white:
|
||||
try:
|
||||
y, x = idx
|
||||
except ValueError:
|
||||
y, x, channels = idx
|
||||
idx = (y, x)
|
||||
# IF walkable, add node
|
||||
graph.add_node((y, x), count=0)
|
||||
# Fully connect to all surrounding neighbors
|
||||
for n, (xdif, ydif, weight) in enumerate(neighbors):
|
||||
# Differentiate between 8 and 4 neighbors
|
||||
if not full_neighbors and n >= 2:
|
||||
break
|
||||
|
||||
query_node = (y + ydif, x + xdif)
|
||||
if graph.has_node(query_node):
|
||||
graph.add_edge(idx, query_node, weight=weight)
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
def from_image(cls, imagepath: Path):
|
||||
with Image.open(imagepath) as image:
|
||||
# Turn the image to single Channel Greyscale
|
||||
if image.mode != 'L':
|
||||
image = image.convert('L')
|
||||
map_array = np.array(image)
|
||||
return cls(name=imagepath.name, array_like_map_representation=map_array)
|
||||
|
||||
def simple_trajectory_between(self, start, dest):
|
||||
vertices = list(nx.shortest_path(self._G, start, dest))
|
||||
trajectory = Trajectory(vertices)
|
||||
return trajectory
|
||||
|
||||
def get_valid_position(self):
|
||||
valid_position = choice(list(self._G.nodes))
|
||||
return valid_position
|
||||
|
||||
def get_trajectory_from_vertices(self, *args):
|
||||
coords = list()
|
||||
for start, dest in zip(args[:-1], args[1:]):
|
||||
coords.extend(nx.shortest_path(self._G, start, dest))
|
||||
return Trajectory(coords)
|
||||
|
||||
def get_random_trajectory(self):
|
||||
start = self.get_valid_position()
|
||||
dest = self.get_valid_position()
|
||||
return self.simple_trajectory_between(start, dest)
|
||||
|
||||
def are_homotopic(self, trajectory, other_trajectory):
|
||||
if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
|
||||
raise TypeError
|
||||
polyline = trajectory.vertices.copy()
|
||||
polyline.extend(reversed(other_trajectory.vertices))
|
||||
|
||||
img = Image.new('L', (self.height, self.width), 0)
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.polygon(polyline, outline=255, fill=255)
|
||||
|
||||
a = (np.array(img) * np.where(self.map_array == self.white, 0, 1)).sum()
|
||||
if a >= 1:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def draw(self):
|
||||
fig, ax = plt.gcf(), plt.gca()
|
||||
# The standard colormaps also all have reversed versions.
|
||||
# They have the same names with _r tacked on to the end.
|
||||
# https: // matplotlib.org / api / pyplot_summary.html?highlight = colormaps
|
||||
img = ax.imshow(self.as_array, cmap='Greys_r')
|
||||
return dict(img=img, fig=fig, ax=ax)
|
||||
|
Reference in New Issue
Block a user