from math import atan2
from typing import List, Tuple, Union

from matplotlib import pyplot as plt
from ml_lib import variables as V

import numpy as np


class Trajectory(object):
    """An ordered sequence of vertices describing a path.

    Vertices are stored as tuples laid out as ``(_, y, x)``: ``ys`` reads
    tuple index 1, ``xs`` reads tuple index 2 and ``xy_vertices`` drops the
    leading component. That leading component is not interpreted by this
    class (presumably a channel/layer index -- TODO confirm against callers).
    """

    @property
    def vertices(self):
        """The stored vertex list itself (not a copy)."""
        return self._vertices

    @property
    def xy_vertices(self):
        """Vertices reduced to ``(x, y)`` pairs."""
        return [(x, y) for _, y, x in self._vertices]

    @property
    def endpoints(self):
        """The ``(start, dest)`` vertex pair."""
        return self.start, self.dest

    @property
    def start(self):
        """First vertex; raises ``IndexError`` on an empty trajectory."""
        return self._vertices[0]

    @property
    def dest(self):
        """Last vertex; raises ``IndexError`` on an empty trajectory."""
        return self._vertices[-1]

    @property
    def xs(self):
        """x coordinate (tuple index 2) of every vertex."""
        return [vertex[2] for vertex in self._vertices]

    @property
    def ys(self):
        """y coordinate (tuple index 1) of every vertex."""
        return [vertex[1] for vertex in self._vertices]

    @property
    def as_paired_list(self):
        """Consecutive ``(vertex, next_vertex)`` pairs, i.e. the path segments."""
        return list(zip(self._vertices[:-1], self._vertices[1:]))

    def draw_in_array(self, shape):
        """Rasterize the vertices into a new float32 array of the given shape.

        Each vertex tuple is used directly as an index into the array, and the
        indexed cell is set to ``V.WHITE``; all other cells stay 0.

        :param shape: Shape of the array to create.
        :return: The new ``np.float32`` array.
        """
        # Allocate directly as float32 instead of creating a float64 array
        # and copying it with ``astype``.
        trajectory_space = np.zeros(shape, dtype=np.float32)
        for index in self.vertices:
            trajectory_space[index] = V.WHITE
        return trajectory_space

    @property
    def np_vertices(self):
        """Each vertex converted to its own ``np.ndarray``."""
        return [np.array(vertice) for vertice in self._vertices]

    def __init__(self, vertices: Union[List[Tuple[int, ...]], None] = None):
        """Create a trajectory.

        :param vertices: List of vertex tuples, or ``None`` for an empty
            trajectory. The list is stored as-is (not copied).
        """
        assert vertices is None or isinstance(vertices, list)
        # Always define ``_vertices`` so that every accessor also works on a
        # trajectory created without vertices.
        self._vertices = vertices if vertices is not None else []

    def is_equal_to(self, other_trajectory):
        """Return whether both trajectories have equal vertex lists."""
        # ToDo: do further equality Checks here
        return self._vertices == other_trajectory.vertices

    def draw(self, highlights=True, label=None, **kwargs):
        """Plot the trajectory (x against y) onto the current matplotlib axes.

        :param highlights: If true, mark every vertex with an ``'o'`` marker.
        :param label: If given, sets color/legend label: ``V.HOMOTOPIC`` gives a
            red "Homotopic" line, anything else a green "Alternative" line;
            line width is set to 1. These override the same keys in ``kwargs``.
        :param kwargs: Forwarded to ``plt.plot``.
        :return: ``dict(img=<plt.plot result>, fig=<current figure>,
            ax=<current axes>)``.
        """
        if label is not None:
            kwargs.update(color='red' if label == V.HOMOTOPIC else 'green',
                          label='Homotopic' if label == V.HOMOTOPIC else 'Alternative',
                          lw=1)
        if highlights:
            kwargs.update(marker='o')
        fig, ax = plt.gcf(), plt.gca()
        img = plt.plot(self.xs, self.ys, **kwargs)
        return dict(img=img, fig=fig, ax=ax)

    @staticmethod
    def _same_direction(step_a, step_b):
        """Return whether two non-zero ``(dy, dx)`` steps point the same way."""
        (dy_a, dx_a), (dy_b, dx_b) = step_a, step_b
        # Parallel (zero cross product) and not opposite (positive dot product);
        # exact for integer coordinates, unlike comparing atan2 floats.
        return dy_a * dx_b == dx_a * dy_b and dy_a * dy_b + dx_a * dx_b > 0

    def min_vertices(self, vertices=None):
        """Return a new trajectory keeping only the start, the corners and the dest.

        A vertex is kept when the direction of travel changes at it. Zero-length
        segments (repeated consecutive vertices) carry no direction and are
        skipped. Direction is measured on the last two tuple components
        (``y``, ``x``); the leading component is ignored (presumably constant
        along a path -- TODO confirm).

        :param vertices: Unused; accepted only so existing callers that pass
            an argument keep working. This trajectory's own vertices are used.
        :return: A new instance of ``self.__class__``; empty if this
            trajectory has no vertices.
        """
        if not self._vertices:
            return self.__class__(vertices=[])
        kept, last_step = [self.start], None
        for previous, current in self.as_paired_list:
            step = (current[-2] - previous[-2], current[-1] - previous[-1])
            if step == (0, 0):
                continue
            if last_step is not None and not self._same_direction(last_step, step):
                # The turn happens at this segment's first point: that is the corner.
                kept.append(previous)
            last_step = step
        if kept[-1] != self.dest:
            kept.append(self.dest)
        return self.__class__(vertices=kept)