import shelve
from pathlib import Path
from collections import UserDict


import copy
from math import sqrt
from random import choice

import numpy as np

from PIL import Image, ImageDraw
import networkx as nx
from matplotlib import pyplot as plt

from lib.objects.trajectory import Trajectory


class Map(object):
    """Walkable 2-D map backed by a pixel array and a graph of its white pixels.

    Every pixel whose value equals ``Map.white`` becomes a graph node keyed by
    its ``(y, x)`` array index; neighbouring walkable pixels are connected by
    edges. The graph is rebuilt automatically whenever ``map_array`` is
    assigned.
    """

    # This setting is for Img mode "L" aka GreyScale Image; values: 0-255
    white = 255
    black = 0

    def __copy__(self):
        """Return a deep copy, so ``copy.copy(a_map)`` shares neither array nor graph."""
        return copy.deepcopy(self)

    @property
    def shape(self):
        """Shape of the underlying array (requires a map array to be set)."""
        return self.map_array.shape

    @property
    def width(self):
        """Size of the first array axis (``shape[0]``)."""
        return self.shape[0]

    @property
    def height(self):
        """Size of the second array axis (``shape[1]``)."""
        return self.shape[1]

    @property
    def as_graph(self):
        """The graph of walkable pixels, or ``None`` while no map array is set."""
        return self._G

    @property
    def as_array(self):
        """The raw map array as passed in / loaded from the image."""
        return self.map_array

    def __init__(self, name='', array_like_map_representation=None):
        """
        :param name: Free-form label of the map.
        :param array_like_map_representation: Pixel array of the map, or
            ``None`` to create an empty map that is filled in later.
        """
        self.map_array: np.ndarray = array_like_map_representation
        self.name = name

    def __setattr__(self, key, value):
        """Store the attribute and keep the graph in sync with ``map_array``."""
        super(Map, self).__setattr__(key, value)
        if key == 'map_array':
            # Without an array there is no graph; resetting to None must not
            # leave a stale graph of the previous array behind.
            self._G = None if value is None else self._build_graph()

    def _build_graph(self, full_neighbors=True):
        """Build the graph of walkable (white) pixels from ``map_array``.

        :param full_neighbors: ``True`` connects all 8 surrounding pixels,
            ``False`` only the 4 axis-aligned ones.
        :return: ``nx.Graph`` with ``(y, x)`` nodes (attribute ``count=0``) and
            edges weighted 1 (straight) or sqrt(2) (diagonal).
        """
        graph = nx.Graph()
        # Offsets are (xdif, ydif, weight): up - left - upperLeft - upperRight.
        # np.ndenumerate walks the array row by row, so only pixels of the
        # previous row and pixels further left in the current row can already
        # be nodes. Probing exactly those four is enough for full
        # 8-connectivity, because edges are undirected.
        neighbors = [(0, -1, 1), (-1, 0, 1), (-1, -1, sqrt(2)), (1, -1, sqrt(2))]
        # Differentiate between 8 and 4 neighbors
        offsets = neighbors if full_neighbors else neighbors[:2]

        # Check pixels for their color (determine if walkable)
        for idx, value in np.ndenumerate(self.map_array):
            if value != self.white:
                continue
            # Multi-channel arrays carry a trailing channel index; only the
            # pixel position identifies the node.
            y, x = idx[:2]
            node = (y, x)
            graph.add_node(node, count=0)
            for xdif, ydif, weight in offsets:
                query_node = (y + ydif, x + xdif)
                if graph.has_node(query_node):
                    graph.add_edge(node, query_node, weight=weight)
        return graph

    @classmethod
    def from_image(cls, imagepath: Path):
        """Create a map from an image file; the map is named after the file.

        :param imagepath: Path of the image to load.
        :return: A new instance built from the greyscale pixel array.
        """
        with Image.open(imagepath) as image:
            # Turn the image to single Channel Greyscale
            if image.mode != 'L':
                image = image.convert('L')
            map_array = np.array(image)
            return cls(name=imagepath.name, array_like_map_representation=map_array)

    def simple_trajectory_between(self, start, dest):
        """Return the trajectory along a shortest path from ``start`` to ``dest``.

        NOTE(review): the search does not pass ``weight=``, so paths are
        shortest by hop count and the stored edge weights are ignored.

        :param start: ``(y, x)`` node to start from.
        :param dest: ``(y, x)`` node to reach.
        """
        vertices = list(nx.shortest_path(self._G, start, dest))
        trajectory = Trajectory(vertices)
        return trajectory

    def get_valid_position(self):
        """Return a random walkable ``(y, x)`` position of the map."""
        valid_position = choice(list(self._G.nodes))
        return valid_position

    def get_trajectory_from_vertices(self, *args):
        """Chain shortest paths through the given waypoints into one trajectory.

        Consecutive waypoints are joined pairwise; every intermediate waypoint
        therefore appears twice in a row (as the end of one segment and the
        start of the next).

        :param args: Waypoints as ``(y, x)`` nodes, in visiting order.
        """
        coords = list()
        for start, dest in zip(args[:-1], args[1:]):
            coords.extend(nx.shortest_path(self._G, start, dest))
        return Trajectory(coords)

    def get_random_trajectory(self):
        """Return a shortest-path trajectory between two random walkable positions."""
        start = self.get_valid_position()
        dest = self.get_valid_position()
        return self.simple_trajectory_between(start, dest)

    def generate_alternative(self, trajectory, mode='one_patching'):
        """Generate another trajectory that shares the endpoints of ``trajectory``.

        :param trajectory: Trajectory whose ``endpoints`` are reused.
        :param mode: Only ``'one_patching'`` is implemented: detour through
            one random walkable position.
        :raises RuntimeError: For any other ``mode``.
        """
        start, dest = trajectory.endpoints
        if mode != 'one_patching':
            raise RuntimeError('mode checking went wrong...')

        patch = self.get_valid_position()
        return self.get_trajectory_from_vertices(start, patch, dest)

    def are_homotopic(self, trajectory, other_trajectory):
        """Check whether the area enclosed by two trajectories is obstacle free.

        Both trajectories are joined into one closed polygon, which is
        rasterised and overlaid with the map.

        :return: ``True`` if the filled polygon covers no non-white pixel,
            ``False`` otherwise.
        :raises TypeError: If either argument is not a ``Trajectory``.
        """
        if not all(isinstance(x, Trajectory) for x in [trajectory, other_trajectory]):
            raise TypeError('Both arguments must be Trajectory instances.')
        # Walk the first trajectory forth and the second one back to close the loop.
        polyline = trajectory.vertices.copy()
        polyline.extend(reversed(other_trajectory.vertices))

        # PIL sizes are (width, height); with height == shape[1] the rasterised
        # image converts back to an array of the same shape as map_array.
        img = Image.new('L', (self.height, self.width), 0)
        draw = ImageDraw.Draw(img)
        # NOTE(review): PIL expects (x, y) points while graph nodes are (y, x);
        # confirm that Trajectory.vertices delivers the order needed here.
        draw.polygon(polyline, outline=255, fill=255)

        # Sum of polygon pixels that lie on non-walkable map pixels.
        a = (np.array(img) * np.where(self.map_array == self.white, 0, 1)).sum()

        return bool(a < 1)

    def draw(self):
        """Draw the map onto the current matplotlib axes.

        :return: dict with the created image (``img``) plus the ``fig`` and
            ``ax`` it was drawn on.
        """
        fig, ax = plt.gcf(), plt.gca()
        # The standard colormaps also all have reversed versions.
        # They have the same names with _r tacked on to the end.
        # https://matplotlib.org/api/pyplot_summary.html?highlight=colormaps
        img = ax.imshow(self.as_array, cmap='Greys_r')
        return dict(img=img, fig=fig, ax=ax)


class MapStorage(object):
    """Lazy, caching loader for maps stored as shelve files in one directory.

    ``storage[name]`` reads ``<map_root>/<name>.pik`` on first access and
    serves the cached object from ``self.data`` afterwards.
    """

    def __init__(self, map_root, load_all=False):
        """
        :param map_root: Directory that holds the ``*.pik`` shelves.
        :param load_all: If ``True``, eagerly load an entry for every ``*.bmp``
            file found in ``map_root``.
        """
        self.data = dict()
        self.map_root = Path(map_root)
        if load_all:
            # NOTE(review): the lookup key is the full file name (e.g.
            # 'a.bmp'), so the shelf opened is 'a.bmp.pik' - confirm this
            # matches how the shelves are written.
            for map_file in self.map_root.glob('*.bmp'):
                _ = self[map_file.name]

    def __getitem__(self, item):
        """Return the map stored under ``item``, loading it on first access.

        :param item: Name of the shelf file without the ``.pik`` suffix.
        :return: The object stored at ``['map']['map']`` inside the shelf.
        :raises KeyError: If the shelf lacks the ``'map'`` entries.
        :raises dbm.error: If the shelf file cannot be opened for reading.
        """
        try:
            return self.data[item]
        except KeyError:
            pass

        shelf_path = self.map_root / f'{item}.pik'
        # str(): shelve/dbm accept path-like objects only from Python 3.11 on.
        # NOTE: shelves are pickle based - only open files from trusted sources.
        with shelve.open(str(shelf_path), flag='r') as d:
            loaded = d['map']['map']
        # Cache in the dict rather than as an instance attribute, so map names
        # can never shadow 'data' or 'map_root'.
        self.data[item] = loaded
        return loaded