import numpy as np
from sklearn.cluster import DBSCAN

import open3d as o3d

from pyod.models.lof import LOF

from torch_geometric.data import Data

from utils.project_settings import Classes


def polytopes_to_planes(pc):
    """Relabel every Box and Polytope point as Plane, in place.

    pc: 2-D array whose column 6 holds the class id as a float (it is compared
        against ``float(Classes.X)``).

    Returns the same array object after modification.
    """
    labels = pc[:, 6]
    # Element-wise OR: Python's ``or`` would try to collapse a whole boolean array
    # into a single bool and raise ValueError for more than one row.
    mask = (labels == float(Classes.Box)) | (labels == float(Classes.Polytope))
    pc[mask, 6] = float(Classes.Plane)
    return pc


def mini_color_table(index, norm=True):
    """Return an RGB triple from a small fixed palette of 13 colors.

    index: any integer; it wraps around modulo the palette size, so negative
        values pick from the end. Entry 0 is black.
    norm: when True the components are in [0, 1]; when False they are scaled by 255.

    Returns a freshly built list of three floats.
    """
    palette = (
        (0., 0., 0.),
        (0.5000, 0.5400, 0.5300), (0.8900, 0.1500, 0.2100), (0.6400, 0.5800, 0.5000),
        (1.0000, 0.3800, 0.0100), (1.0000, 0.6600, 0.1400), (0.4980, 1.0000, 0.0000),
        (0.4980, 1.0000, 0.8314), (0.9412, 0.9725, 1.0000), (0.5412, 0.1686, 0.8863),
        (0.5765, 0.4392, 0.8588), (0.3600, 0.1400, 0.4300), (0.5600, 0.3700, 0.6000),
    )

    rgb = list(palette[index % len(palette)])
    if norm:
        return rgb
    return [channel * 255 for channel in rgb]


def cluster2Color(cluster, cluster_idx):
    """Return a (len(cluster), 3) float array with every row set to the palette color of *cluster_idx*.

    The color comes from mini_color_table(cluster_idx), normalized to [0, 1].
    """
    colors = np.zeros(shape=(len(cluster), 3))
    if colors.shape[0]:
        # All points of one cluster share a single color: look it up once and broadcast.
        colors[:] = mini_color_table(cluster_idx)
    return colors


def label2color(labels):
    '''
    Map integer labels to RGB colors taken from mini_color_table (normalized to [0, 1]).

    labels: np.ndarray with shape (n, ) of integer labels
    colors(return): np.ndarray with shape (n, 3); an empty (0, 3) array when labels is empty
    '''
    colors = np.zeros((labels.shape[0], 3))

    # Visit only the labels that actually occur, instead of every integer in [min, max].
    # This also avoids np.min / np.max raising on an empty array.
    for label in np.unique(labels):
        colors[labels == label, :] = mini_color_table(label)

    return colors


def read_pointcloud(path, delimiter=' ', hasHeader=True):
    """Read an ASCII point cloud and return (at most) its first six columns.

    path: text file with one point per line, values separated by *delimiter*
        (pass ``None`` to split on any run of whitespace).
    hasHeader: when True the first line is skipped.

    Blank lines (including a trailing empty line) are ignored.

    Returns a float array of shape (n, min(6, ncols)), or an empty (0, 6) array
    when the file contains no data lines.
    Raises ValueError if a value cannot be parsed as a float.
    """
    rows = []
    with open(path, 'r') as f:
        if hasHeader:
            # Get rid of the header line.
            f.readline()
        for line in f:
            stripped = line.strip()
            # Lines read from a file keep their '\n', so an "empty" line is '\n', not ''.
            if not stripped:
                continue
            rows.append([float(x) for x in stripped.split(delimiter)])

    if not rows:
        # np.asarray([]) is 1-D and cannot be sliced with [:, :6].
        return np.empty((0, 6))

    return np.asarray(rows)[:, :6]


def write_pointcloud(file, pc, numCols=6):
    """Write *pc* as ASCII text preceded by a "<num points> <num cols>" header line.

    file: path or open text handle, as accepted by np.savetxt.
    pc: 2-D array; only its first *numCols* columns are written.
    """
    cols = pc[:, :numCols]
    # Report the number of columns actually written: pc may have fewer than numCols.
    np.savetxt(file, cols, header=str(len(cols)) + ' ' + str(cols.shape[1]), comments='')


def farthest_point_sampling(pts, K):
    """Greedily pick K points that are spread far apart (farthest point sampling).

    pts: 2-D array whose first three columns are xyz (any further columns are carried
        along), or a torch_geometric ``Data``, in which case only ``pts.pos`` is used.
    K: number of points to keep.

    Returns *pts* unchanged when K is not positive (a ``Data`` input is then returned
    as-is) or when there are fewer than K points. Otherwise returns a new (K, ncols)
    float array. The first pick is random (np.random); each later pick is the point
    with the largest squared xyz distance to its nearest already-picked point.
    """
    if not (K > 0):  # also covers the -1 "no sampling" value callers pass
        return pts
    if isinstance(pts, Data):
        pts = pts.pos.numpy()
    if pts.shape[0] < K:
        return pts

    def sq_dist_to(anchor, cloud):
        # Squared Euclidean distance over the xyz columns only.
        diff = anchor[:3] - cloud[:, :3]
        return (diff ** 2).sum(axis=1)

    picked = np.zeros((K, pts.shape[1]))
    picked[0] = pts[np.random.randint(len(pts))]
    nearest = sq_dist_to(picked[0], pts)
    for slot in range(1, K):
        picked[slot] = pts[np.argmax(nearest)]
        nearest = np.minimum(nearest, sq_dist_to(picked[slot], pts))

    return picked


def cluster_per_column(pc, column):
    """Split *pc* into groups of rows that share the same integer label in *column*.

    Returns a list whose entry i holds the rows whose label equals i, for every i from
    0 up to and including the largest label. Labels that do not occur yield empty
    arrays, so list position always equals label. Rows with negative labels are not
    returned. An empty *pc* yields an empty list.
    """
    if pc.shape[0] == 0:
        return []

    labels = pc[:, column]
    # +1 so the cluster carrying the largest label is included.
    return [pc[labels == i, :] for i in range(int(np.max(labels)) + 1)]


def cluster_cubes(data, cluster_dims, max_points_per_cluster=-1, min_points_per_cluster=-1):
    """Split a point cloud into the cells of a regular 3-D grid laid over its bounding box.

    data: a torch_geometric ``Data`` object or a 2-D numpy array whose first three
        columns are x, y, z. For ``Data``, its ``pos`` and ``norm`` attributes and its
        ``y`` attribute (cast to double, one column) are concatenated column-wise,
        skipping those that are None.
    cluster_dims: number of grid cells along x, y and z.
    max_points_per_cluster: when not -1, every cell is reduced with
        farthest_point_sampling to at most this many points.
    min_points_per_cluster: when not -1, cells holding fewer points are dropped.

    Returns a list of arrays (all columns kept), one per non-empty cell, ordered by the
    first row of *data* that fell into each cell. A 1x1x1 grid returns a single
    (possibly sampled) array. Empty *data* yields an empty list.
    """
    if isinstance(data, Data):
        import torch
        candidate_list = []
        # getattr with a default is defensive: some torch_geometric versions may raise
        # AttributeError for unset attributes instead of returning None.
        for attr in ('pos', 'norm'):
            value = getattr(data, attr, None)
            if value is not None:
                candidate_list.append(value)
        labels = getattr(data, 'y', None)
        if labels is not None:
            candidate_list.append(labels.double().unsqueeze(-1))

        data = torch.cat(candidate_list, dim=-1).numpy()

    if cluster_dims[0] == 1 and cluster_dims[1] == 1 and cluster_dims[2] == 1:
        print("no need to cluster.")
        return [farthest_point_sampling(data, max_points_per_cluster)]

    if data.shape[0] == 0:
        return []

    coords = data[:, :3]

    # Pad the bounding box outward by 1% of each bound's magnitude. abs() keeps the
    # padding outward for negative bounds (``max * 0.01`` would shrink the box there).
    hi = coords.max(axis=0)
    hi = hi + np.abs(hi) * 0.01
    lo = coords.min(axis=0)
    lo = lo - np.abs(lo) * 0.01

    size = hi - lo
    cluster_size = size / np.array(cluster_dims, dtype=np.float32)

    print('Min: ' + str(lo) + ' Max: ' + str(hi))
    print('Cluster Size: ' + str(cluster_size))

    # An axis that is constantly 0 has zero extent even after padding; use a unit cell
    # there (all points land in cell 0 on that axis) instead of dividing by zero.
    cluster_size = np.where(cluster_size > 0, cluster_size, 1.0)

    # Integer cell coordinates per point. Clamp them so no point can spill past the grid
    # boundary into a neighbouring flattened index.
    dims = np.asarray(cluster_dims, dtype=int)
    cell = ((coords - lo) / cluster_size).astype(int)
    cell = np.clip(cell, 0, dims - 1)
    cell_idx = dims[0] * dims[2] * cell[:, 1] + dims[0] * cell[:, 2] + cell[:, 0]

    # A stable sort groups rows by cell while keeping their original order inside each cell.
    # np.split cuts at cell changes; sorting the groups by their first row index restores
    # the order in which the cells were first encountered.
    order = np.argsort(cell_idx, kind='stable')
    boundaries = np.flatnonzero(np.diff(cell_idx[order])) + 1
    groups = sorted(np.split(order, boundaries), key=lambda rows: rows[0])

    final_clusters = []
    for rows in groups:
        c = data[rows]
        if min_points_per_cluster != -1 and c.shape[0] < min_points_per_cluster:
            continue

        # Apply farthest point sampling to each cluster.
        if max_points_per_cluster != -1:
            c = farthest_point_sampling(c, max_points_per_cluster)
        final_clusters.append(c)

    return final_clusters


def cluster_dbscan(data, selected_indices, eps, min_samples=5, metric='euclidean', algo='auto'):
    """Cluster the rows of *data* with DBSCAN, using only the columns in *selected_indices*.

    eps, min_samples and metric are forwarded to sklearn's DBSCAN, algo as its
    ``algorithm`` argument. n_jobs=-1 is always used.

    Returns a list of arrays, one per cluster, each holding all columns of its member
    rows. Rows labelled -1 (noise) are dropped. Clusters are ordered by the first row in
    which their label appears.
    """
    # Column layout noted by the original author:
    # 0,1,2 :   pos
    # 3,4,5 :   normal
    # 6:        type index
    # 7,8,9,10: type index one hot encoded
    # 11,12:    normal as angles
    # NOTE(review): append_onehotencoded_type in this file emits 8 one-hot columns,
    # not 4 -- confirm which layout callers actually pass.

    labels = DBSCAN(eps=eps, metric=metric, n_jobs=-1, min_samples=min_samples, algorithm=algo).fit(
        data[:, selected_indices]).labels_

    # np.unique returns labels sorted; reorder them by first occurrence so the cluster
    # order matches the order in which labels are encountered row by row.
    unique_labels, first_seen = np.unique(labels, return_index=True)
    ordered_labels = unique_labels[np.argsort(first_seen)]

    return [data[labels == label] for label in ordered_labels if label != -1]


def draw_clusters(clusters):
    """Display all clusters together in one Open3D window.

    Each cluster's first three columns are its points. The whole cluster is tinted with
    the palette color matching its position in *clusters*.
    """
    def to_cloud(position, points):
        cloud = o3d.PointCloud()
        cloud.points = o3d.Vector3dVector(points[:, :3])
        cloud.colors = o3d.Vector3dVector(cluster2Color(points, position))
        return cloud

    o3d.draw_geometries([to_cloud(position, points) for position, points in enumerate(clusters)])


def write_clusters(path, clusters, type_column=6):
    """Write *clusters* to *path* in a plain-text format.

    Layout:
      - first line: number of clusters;
      - per cluster: one line with its distinct type ids separated by ';';
      - then a "<num points> 6" header line;
      - then the cluster's first six columns, one point per line.

    Type ids read from *type_column* (truncated to int) are remapped 0->2, 1->1 and
    3->4; any other id is written unchanged. NOTE(review): the meaning of the target
    ids is not visible here -- confirm against the consumer of this file.
    Prints a short summary per cluster to stdout.
    """
    type_mapping = {0: 2, 1: 1, 3: 4}

    # ``with`` guarantees the file is flushed and closed even if writing fails.
    with open(path, "w") as out:
        out.write(str(len(clusters)) + "\n")

        for cluster in clusters:
            raw_types = np.unique(cluster[:, type_column], axis=0).astype(int)
            types = np.array([type_mapping.get(t, t) for t in raw_types])
            print("Types: {}, Points: {}".format(types, cluster.shape[0]))

            np.savetxt(out, types.reshape(1, types.shape[0]), delimiter=';', header='', comments='', fmt='%i')
            np.savetxt(out, cluster[:, :6], header=str(len(cluster)) + ' ' + str(6), comments='')


def draw_sample_data(sample_data, colored_normals=False):
    """Display a single point cloud in an Open3D window.

    Points come from columns 0-2. When *colored_normals* is True, columns 3-5 are used
    directly as RGB. Otherwise colors come from label2color applied to column 6 cast
    to int.
    """
    cloud = o3d.PointCloud()
    cloud.points = o3d.Vector3dVector(sample_data[:, :3])

    if colored_normals:
        rgb = sample_data[:, 3:6]
    else:
        rgb = label2color(sample_data[:, 6].astype(int))
    cloud.colors = o3d.Vector3dVector(rgb)

    o3d.draw_geometries([cloud])


def normalize_pointcloud(pc, factor=1.0):
    """Scale positions and unit-normalize normals of *pc*, in place.

    pc: 2-D float array; columns 0-2 are positions and columns 3-5 are normals.
    factor: positions are divided by (largest xyz extent * factor).

    Degenerate inputs are left finite instead of producing inf/NaN:
      - if the xyz extent is zero (e.g. a single point), positions are not scaled;
      - zero-length normals stay zero.
    Returns the same array object. An empty array is returned unchanged.
    """
    if pc.shape[0] == 0:
        return pc

    xyz = pc[:, 0:3]
    extent = np.abs(xyz.max(axis=0) - xyz.min(axis=0))
    scale = np.max(extent) * factor
    if scale != 0:
        pc[:, 0:3] /= scale

    norms = np.linalg.norm(pc[:, 3:6], ord=2, axis=1, keepdims=True)
    # Avoid 0/0 for zero-length normals; dividing a zero vector by 1 leaves it unchanged.
    norms[norms == 0] = 1.0
    pc[:, 3:6] /= norms

    return pc


def hierarchical_clustering(data, selected_indices_0, selected_indices_1, eps, min_samples=5, metric='euclidean',
                            algo='auto'):
    """Two-stage DBSCAN clustering.

    First cluster *data* on columns *selected_indices_0*, then re-cluster every
    resulting cluster on columns *selected_indices_1*. Both stages share eps,
    min_samples, metric and algo. Returns the flat list of second-stage clusters,
    grouped by first-stage parent in order.
    """
    def refine(parent):
        return cluster_dbscan(parent, selected_indices_1, eps, min_samples, metric=metric, algo=algo)

    coarse_clusters = cluster_dbscan(data, selected_indices_0, eps, min_samples, metric=metric, algo=algo)
    return [fine for parent in coarse_clusters for fine in refine(parent)]


def filter_clusters(clusters, filter):
    """Return, as a new list and in their original order, the clusters for which the predicate *filter* is truthy."""
    return [cluster for cluster in clusters if filter(cluster)]


def split_outliers(pc, columns):
    """Split *pc* into two row subsets with pyod's LOF detector fitted on *columns*.

    Returns (rows with detector label 0, rows with detector label 1). Presumably this is
    (inliers, outliers) per pyod's ``labels_`` convention -- confirm.
    """
    detector = LOF()
    detector.fit(pc[:, columns])

    flags = detector.labels_
    inliers = pc[flags == 0]
    outliers = pc[flags == 1]
    return inliers, outliers


def append_onehotencoded_type(data, factor=1.0, num_types=8):
    """Append a one-hot encoding of the type index in column 6 as extra columns.

    data: 2-D array; column 6 holds the type index (truncated to int).
    factor: value written into the "hot" column (all other new columns are 0).
    num_types: width of the encoding (default 8). Indices >= num_types raise IndexError.
        NOTE: negative indices are not rejected -- numpy indexing wraps them to the last
        columns.

    Returns a new array: the columns of *data* followed by *num_types* one-hot columns.
    """
    types = data[:, 6].astype(int)
    one_hot = np.zeros((len(types), num_types))
    one_hot[np.arange(len(types)), types] = factor

    return np.column_stack((data, one_hot))