import numpy as np
import open3d as o3d
from sklearn.cluster import DBSCAN


def mini_color_table(index, norm=True):
    """Pick an RGB colour from a fixed 12-entry palette.

    Args:
        index: Any integer; it wraps around the palette via modulo, so
            negative or large values are accepted.
        norm: When True the channels are in [0, 1]; when False each channel
            is multiplied by 255.

    Returns:
        A new list ``[r, g, b]`` of floats.
    """
    palette = (
        (0.5000, 0.5400, 0.5300), (0.8900, 0.1500, 0.2100), (0.6400, 0.5800, 0.5000),
        (1.0000, 0.3800, 0.0100), (1.0000, 0.6600, 0.1400), (0.4980, 1.0000, 0.0000),
        (0.4980, 1.0000, 0.8314), (0.9412, 0.9725, 1.0000), (0.5412, 0.1686, 0.8863),
        (0.5765, 0.4392, 0.8588), (0.3600, 0.1400, 0.4300), (0.5600, 0.3700, 0.6000),
    )

    rgb = palette[index % len(palette)]
    if norm:
        return list(rgb)
    return [channel * 255 for channel in rgb]


def clusterToColor(cluster, cluster_idx):
    """Build a per-point colour array where every point gets the cluster's colour.

    Args:
        cluster: Sequence/array of points; only its length is used.
        cluster_idx: Index passed to ``mini_color_table`` to choose the colour.

    Returns:
        float64 array of shape ``(len(cluster), 3)``; every row is the same colour.
    """
    # Look the colour up once instead of once per point, then replicate it.
    rgb = np.asarray(mini_color_table(cluster_idx), dtype=float)
    return np.tile(rgb, (len(cluster), 1))


def read_pointcloud(path):
    """Read a whitespace-separated point cloud text file.

    The first line is a header whose first token is the number of points
    (e.g. as produced by ``write_pointcloud``). Each of the following
    ``num_points`` lines is parsed as one row of floats.

    Args:
        path: Path of the file to read.

    Returns:
        ``np.ndarray`` built from the parsed rows (2-D when all rows have the
        same number of values).
    """
    # `with` guarantees the handle is closed even if parsing raises.
    with open(path) as file:
        header = file.readline()
        num_points = int(header.split()[0])
        pc = [[float(s) for s in file.readline().split()] for _ in range(num_points)]

    return np.array(pc)


def write_pointcloud(file, pc, numCols=6):
    """Write the first ``numCols`` columns of ``pc`` as text, preceded by a header.

    The header line is ``"<num_points> <num_columns>"``. ``read_pointcloud``
    reads the first of these tokens.

    Args:
        file: Filename or open text file handle accepted by ``np.savetxt``.
        pc: 2-D array of points.
        numCols: Maximum number of leading columns to write.
    """
    columns = pc[:, :numCols]
    # Report the number of columns actually written: slicing silently yields
    # fewer than numCols when pc is narrower.
    header = str(len(columns)) + ' ' + str(columns.shape[1])
    np.savetxt(file, columns, header=header, comments='')


def farthest_point_sampling(pts, K, rng=None):
    """Select ``K`` rows of ``pts`` by farthest point sampling on the first 3 columns.

    The first row is chosen at random. Each following row is the one whose
    squared distance (over columns 0-2) to the nearest already-selected row
    is largest.

    Args:
        pts: 2-D array of points; only columns 0-2 are used for distances,
            but whole rows are returned.
        K: Number of rows to select.
        rng: Optional ``np.random.Generator`` for the initial choice. When
            None the global ``np.random`` state is used (original behaviour).

    Returns:
        ``pts`` itself if it has fewer than ``K`` rows; an empty
        ``(0, pts.shape[1])`` array if ``K <= 0``; otherwise a new float
        array of shape ``(K, pts.shape[1])``.
    """
    if pts.shape[0] < K:
        return pts

    if K <= 0:
        # Without this guard, indexing farthest_pts[0] below raises IndexError.
        return np.zeros((0, pts.shape[1]))

    def calc_distances(p0, points):
        return ((p0[:3] - points[:, :3]) ** 2).sum(axis=1)

    n = len(pts)
    start = np.random.randint(n) if rng is None else int(rng.integers(n))

    farthest_pts = np.zeros((K, pts.shape[1]))
    farthest_pts[0] = pts[start]
    # distances[j] = squared distance from pts[j] to its nearest selected point.
    distances = calc_distances(farthest_pts[0], pts)
    for i in range(1, K):
        farthest_pts[i] = pts[np.argmax(distances)]
        distances = np.minimum(distances, calc_distances(farthest_pts[i], pts))

    return farthest_pts


def cluster_per_column(pc, column):
    """Split ``pc`` into groups of rows by the integer label stored in one column.

    Args:
        pc: 2-D array of points.
        column: Index of the column holding the label.

    Returns:
        List whose entry ``i`` holds the rows of ``pc`` with label ``i``, for
        ``i`` in ``0 .. int(max label)`` inclusive. Labels absent from the data
        give empty arrays. Negative labels are not returned. An empty ``pc``
        gives ``[]``.
    """
    if len(pc) == 0:
        return []

    labels = pc[:, column]
    # +1 so the cluster whose label equals the maximum is included.
    return [pc[labels == i, :] for i in range(0, int(np.max(labels)) + 1)]


def cluster_cubes(data, cluster_dims):
    """Group rows of ``data`` into a regular 3-D grid of cells over their xyz bounds.

    The bounding box of columns 0-2 is padded by 1% of the magnitude of each
    bound. It is then divided into ``cluster_dims[0] x cluster_dims[1] x
    cluster_dims[2]`` cells, and each row goes to the cell containing it.

    Args:
        data: 2-D array whose first three columns are coordinates; whole rows
            are kept in the output.
        cluster_dims: Number of cells along each of the three axes.

    Returns:
        ``dict_values`` of 2-D arrays, one per non-empty cell. Cells appear in
        order of the first row that falls into them.
    """
    if len(data) == 0:
        return {}.values()

    xyz = np.asarray(data[:, :3], dtype=float)
    dims = np.asarray(cluster_dims, dtype=int)

    # Pad by |bound| so the box grows outward even for negative coordinates
    # (plain `bound * 0.01` would shrink it there).
    upper = xyz.max(axis=0)
    upper = upper + np.abs(upper) * 0.01
    lower = xyz.min(axis=0)
    lower = lower - np.abs(lower) * 0.01

    size = upper - lower
    cluster_size = size / np.array(cluster_dims, dtype=np.float32)
    # A flat axis (all coordinates equal to 0) would give a zero cell size and
    # a division by zero; any positive size maps every point to cell 0 there.
    cluster_size[cluster_size == 0] = 1.0

    print('Min: ' + str(lower) + ' Max: ' + str(upper))
    print('Cluster Size: ' + str(cluster_size))

    positions = ((xyz - lower) / cluster_size).astype(int)
    # Clamp so points on the upper boundary cannot produce an out-of-range
    # position that would alias another cell in the flattened index below.
    positions = np.clip(positions, 0, dims - 1)
    # Flatten (x, y, z): x varies fastest, then z, then y.
    indices = (dims[0] * dims[2] * positions[:, 1]
               + dims[0] * positions[:, 2]
               + positions[:, 0])

    clusters = {}
    for row, cluster_idx in zip(data, indices):
        clusters.setdefault(int(cluster_idx), []).append(row)

    for key, rows in clusters.items():
        clusters[key] = np.vstack(rows)

    return clusters.values()


def cluster_dbscan(data, selected_indices, eps, min_samples, metric='euclidean', algo='auto'):
    """Cluster rows of ``data`` with sklearn DBSCAN on a subset of its columns.

    Args:
        data: 2-D array of points; whole rows are returned in the clusters.
        selected_indices: Column indices fed to DBSCAN as features.
        eps: DBSCAN neighbourhood radius.
        min_samples: Multiplied by ``len(data)`` to get DBSCAN's
            ``min_samples``, so presumably a fraction of the point count.
        metric: Distance metric forwarded to DBSCAN.
        algo: Neighbour search algorithm forwarded to DBSCAN.

    Returns:
        List of 2-D arrays, one per DBSCAN cluster, ordered by the first row
        carrying each label. Noise points (label -1) are excluded.
    """
    # DBSCAN requires an integer count. The ceiling keeps the original float
    # semantics (a point is core iff neighbours >= min_samples * len(data)).
    min_samples = max(1, int(np.ceil(min_samples * len(data))))

    print('Clustering. Min Samples: ' + str(min_samples) + ' EPS: ' + str(eps) + " Selected Indices: " + str(selected_indices))

    # 0,1,2 :   pos
    # 3,4,5 :   normal
    # 6:        type index
    # 7,8,9,10: type index one hot encoded
    # 11,12:    normal as angles

    db_res = DBSCAN(eps=eps, metric=metric, n_jobs=-1, algorithm=algo, min_samples=min_samples).fit(data[:, selected_indices])

    labels = np.asarray(db_res.labels_)
    # Compare by value: labels are numpy ints, so `l is -1` never matched and
    # noise used to leak out as its own cluster.
    is_noise = labels == -1
    n_noise = int(np.count_nonzero(is_noise))
    n_clusters = len(set(labels.tolist())) - (1 if n_noise else 0)
    print("Noise: " + str(n_noise) + " Clusters: " + str(n_clusters))

    # Keep clusters in order of first appearance, as the original dict grouping did.
    unique_labels, first_seen = np.unique(labels[~is_noise], return_index=True)
    ordered_labels = unique_labels[np.argsort(first_seen)]

    return [data[labels == label] for label in ordered_labels]

def draw_clusters(clusters):
    """Show each cluster as a separately coloured point cloud in an Open3D window.

    Args:
        clusters: Iterable of 2-D arrays whose first three columns are xyz.
            The i-th cluster gets colour ``clusterToColor(cluster, i)``.
    """
    # Open3D >= 0.8 moved these classes/functions into submodules and removed
    # the top-level aliases; fall back to the top level for older releases.
    geometry = getattr(o3d, 'geometry', o3d)
    utility = getattr(o3d, 'utility', o3d)
    visualization = getattr(o3d, 'visualization', o3d)

    clouds = []

    for cluster_idx, cluster in enumerate(clusters):
        cloud = geometry.PointCloud()
        # Vector3dVector expects float64 (N, 3) data.
        cloud.points = utility.Vector3dVector(np.asarray(cluster[:, :3], dtype=np.float64))
        cloud.colors = utility.Vector3dVector(clusterToColor(cluster, cluster_idx))
        clouds.append(cloud)

    visualization.draw_geometries(clouds)