# Keras Utility Imports
from keras import backend as K
from keras.utils import plot_model

from collections import defaultdict, Counter
from random import shuffle

# Numpy
import numpy as np

# Maths
from math import sqrt
from sklearn.cluster import DBSCAN, KMeans
from sklearn.metrics import silhouette_samples, silhouette_score
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Plotting
import matplotlib.pyplot as plt
from PIL import ImageDraw, Image
import matplotlib.cm as cm
"""
UseFull Links:

Keras Timedistributed Wrapper is applysing the SAME Layer, which means same WEIGHTS:
http://machinelearningmastery.com/timedistributed-layer-for-long-short-term-memory-networks-in-python/

"""


class Trainer(object):
    def __init__(self, mode, trackCollection, classes, categorical_distribution,
                 batchSize=400, timesteps=5, filters=0, rotating=False):
        if mode.lower() not in ['gumble', 'vae', 'refined']:
            raise ValueError('Mode needs to be one of "gumble", "vae" or "refined".')

        self.mode = mode.lower()
        # Always keep a list of track collections; wrap a single collection instead of iterating it.
        self.tc = list(trackCollection) if isinstance(trackCollection, (list, tuple)) else [trackCollection]

        self.rotating = rotating
        self.timesteps = timesteps

        self.classes = classes
        self.cD = categorical_distribution
        self.batchSize = batchSize
        self.epochs = 100
        self.epsilon_std = 0.01
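        # Gumbel-Softmax temperature parameters: tau starts high (soft samples) and is
        # annealed towards min_temperature during training (see train()).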
        self.tau = K.variable(5.0, name="temperature")
        self.anneal_rate = 0.0003
        self.min_temperature = 0.5
        _, _, height, width, _ = self.tc[0].as_n_sample_4D(
            self.timesteps, for_track=list(self.tc[0].keys())[0]).shape

        self.height = height
        self.width = width
        self.original_dim = self.timesteps * self.width * self.height * 1  # e.g. 5 * 30 * 30 = 4500 px
        self.filters = int(sqrt(self.width ** 2 + self.height ** 2)) // 2 if not filters else filters

        self.trained = False
        self.model = None
        self.encoder = None
        self.generator = None

    def set_model(self, model, loss, optimizer='adagrad'):
        self.model = model
        self.model.compile(optimizer=optimizer, loss=loss)
        self.model.summary()

    def set_generator(self, generator):
        self.generator = generator

    def set_encoder(self, encoder):
        self.encoder = encoder

    def load_weights(self, fileName):
        self.model.load_weights(fileName)
        self.trained = True

    def save_weights(self, fileName):
        self.model.save_weights(fileName)

    def train(self, fileName=None):
        for i in range(self.epochs):
            tracklists = [list(x.keys()) for x in self.tc]  # TODO: add additional shuffling!!!
            # Walk all track collections in parallel, one track key from each per step.
            for keys in zip(*tracklists):
                data = np.empty(0)
                # Accumulate samples until at least one full batch is available.
                while data.shape[0] < self.batchSize:
                    tempData = [self.tc[idx].as_n_sample_4D(self.timesteps, in_walk_dir=self.rotating, for_track=key)
                                for idx, key in enumerate(keys)]
                    tempData = np.row_stack(tempData)
                    if data.shape[0] == 0:
                        data = tempData
                    else:
                        data = np.row_stack((data, tempData))
                np.random.shuffle(data)
                # Truncate to a whole number of batches.
                smoothing = data.shape[0] // self.batchSize * self.batchSize
                if smoothing:
                    data = data[:smoothing]
                self.model.fit(data, data, shuffle=True, epochs=1, batch_size=self.batchSize)
                data = None

            # Anneal the Gumbel-Softmax temperature tau exponentially towards min_temperature.
            K.set_value(self.tau,
                        np.max([K.get_value(self.tau) * np.exp(- self.anneal_rate * i),
                                self.min_temperature]))
            if fileName:
                self.save_weights(fileName)

    def plot_model(self, filename, show_shapes=True, show_layer_names=True):
        plot_model(self.model, filename, show_shapes=show_shapes, show_layer_names=show_layer_names)

    def color_random_track(self, completeSequence=False, show=True, fileName='', primaryTC=0,
                           nClusters=0, multiPath=False, cMode='kmeans', aMode='none'):
        if not self.trained:
            raise EnvironmentError('Please train this Model first!')

        self.tc[primaryTC].map.refresh_random_clock()
        track = self.tc[primaryTC].map.return_random_path()
        if fileName:
            fileName = fileName if fileName.endswith('.tif') else '%s.tif' % fileName

        return self.color_track(track, show=show, completeSequence=completeSequence, primaryTC=primaryTC,
                                fileName=fileName, nClusters=nClusters, multiPath=multiPath, cMode=cMode, aMode=aMode)

    def color_track(self, track, completeSequence=False, show=True, fileName='', nClusters=0, multiPath=False,
                    cMode='kmeans', aMode='none', primaryTC=0):
        if not self.trained:
            raise EnvironmentError('Please train this Model first!')
        # Isovists along the track, reordered to shape (n, height, width, 1).
        isoArray = self.tc[primaryTC].map.isovists.get_items_for_track(track, dim='full', in_walk_dir=True).swapaxes(0, 2)[
            ..., None].transpose((0, 2, 1, 3))
        # Trim to a whole number of timestep windows.
        smoothing = (isoArray.shape[0] // self.timesteps) * self.timesteps
        isoArray = isoArray[:smoothing]
        # Overlapping sequences of length `timesteps` (moving window).
        sequenceArray = np.array([isoArray[i:i+self.timesteps] for i in range(len(isoArray)-self.timesteps)])
        dummyData = self.tc[primaryTC].as_n_sample_4D(self.timesteps).astype(int)
        dummyData = dummyData.reshape((-1, self.timesteps, self.height, self.width, 1))
        testdata = np.row_stack((dummyData, sequenceArray))[-1000:]

        # Each sequence is colored at the track position at the center of its window.
        keys = [track[i+self.timesteps//2] for i in range(len(sequenceArray))]

        tsneArray = self.reduce_and_color(testData=testdata, nClusters=nClusters,
                                          primaryTC=primaryTC, aMode=aMode,
                                          cMode=cMode)[-len(sequenceArray):] + 1  # This is for color correction
        npmax = np.max(tsneArray[:, -1])+1
        figure = np.where(self.tc[primaryTC].map.imgArray > 0, npmax, 0)

        for i in range(len(sequenceArray)):
            figure[keys[i]] = tsneArray[i, -1]

        if multiPath:
            return keys, tsneArray[-1]
        else:
            self.print_n_show(figure, 'img', npmax, fileName=fileName, show=show)
            self.print_n_show(tsneArray, 'scatter', npmax, fileName=fileName, show=show)

        if completeSequence:
            if fileName:
                fileName = '%s_sequence.tif' % fileName[:fileName.find('.')]
            self.__colored_sequence(tsneArray, isoArray, maxVal=npmax, show=show, fileName=fileName)

    def __colored_sequence(self, tsneArray, isoArray, maxVal=0, show=True, fileName=''):
        """
        Returns all the Isovist sequences for a Track, next to its class color.
        :param tsneArray:
        :type tsneArray:
        :param isoArray:
        :type isoArray:
        :param maxVal:
        :type maxVal:
        :param show:
        :type show:
        :param fileName:
        :type fileName:
        :return:
        :rtype:
        """
        if not self.trained:
            raise EnvironmentError('Please train this Model first!')
        if maxVal == 0:
            maxVal = np.max(tsneArray[:, -1])

        spacing = 2
        figure = np.full(((spacing + len(tsneArray) * (self.height + spacing)), (self.timesteps + 1) * self.width), 2)

        # Stack one row per sequence: the isovist sequence on the left, a class-color swatch on the right.
        for i in range(len(tsneArray)):
            backGround = np.full((self.height, self.width * (self.timesteps + 1)), tsneArray[i, -1])

            sequence = isoArray[i:i + self.timesteps].swapaxes(0, 1).reshape((
                self.height, self.timesteps * self.width))
            sequence = np.where(sequence > 0, maxVal, 0)
            backGround[:, 0:-self.width] = sequence
            figure[i * self.height + i * spacing: (i + 1)*self.height + i*spacing, :] = backGround
        if fileName:
            fileName = fileName if fileName.endswith('.tif') else '%s.tif' % fileName
        self.print_n_show(figure, 'img', maxVal, show=show, fileName=fileName)

    @staticmethod
    def print_n_show(x, mode, maxValue, show=True, fileName=''):
        fig, ax = plt.subplots()
        # Render x according to the requested mode.
        if mode == 'img' or mode == 'fig':
            pic = ax.imshow(x, cmap='gist_ncar', vmin=0, vmax=maxValue)
            cb = plt.colorbar(pic, spacing='proportional', ticks=np.linspace(0, maxValue, maxValue + 1))
        elif mode == 'rgba':
            pic = ax.imshow(x, cmap='gist_ncar', vmax=255)
            cb = plt.colorbar(pic, spacing='proportional', ticks=np.linspace(0, 255, maxValue + 1))
        elif mode == 'scatter':
            scat = ax.scatter(x[:, 0], x[:, 1], c=x[:, -1], )
            cb = plt.colorbar(scat, spacing='proportional', ticks=np.linspace(0, maxValue, maxValue + 1))
        elif mode == 'bars':
            objects = list(x.keys())
            y_pos = list(range(len(objects)))
            performance = [x[key] for key in objects]

            bar = ax.bar(y_pos, performance, align='center')
            # fig.xticks(y_pos, objects)
            # fig.ylabel('Usage')
            # fig.title('Programming language usage')
        else:
            raise ValueError('Mode needs to be "img", "fig", "bars", "rgba" or "scatter".')
        fig.tight_layout()
        # Save before showing; plt.show() may close the figure and leave an empty file.
        if fileName:
            plt.savefig(fileName)
        if show:
            plt.show()

        return True

    def reduce_and_color(self, testData=None, aMode='tsne', nClusters=0, cMode='kmeans', eps=5, primaryTC=0):
        """
        :param  testData:   Numpy array of shape (n, timesteps, height, width, 1); if None,
                            samples are drawn from the primary track collection.
        :type   testData:   np.ndarray
        :param  eps:        Neighbourhood radius used when cMode='dbscan'.
        :type   eps:        int
        :param  aMode:      Dimensionality reduction mode, default='tsne', others: 'pca', 'none'.
        :type   aMode:      str
        :param  cMode:      Clustering mode, default='kmeans', other: 'dbscan'.
        :type   cMode:      str
        :param  nClusters:  Number of clusters for k-means clustering; if 0, nClusters=self.classes.
        :type   nClusters:  int
        :param  primaryTC:  Index of the TrackCollection used for the base map etc.
        :type   primaryTC:  int
        :return:            Numpy array of the reduced coordinates with the cluster label
                            appended as the last column, i.e. (n, X, Y, label).
        :rtype:             np.ndarray
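
        Example (illustrative sketch only; assumes a trained encoder and a hypothetical
        `trainer` instance of this class):

            colored = trainer.reduce_and_color(aMode='pca', cMode='kmeans')
            coords, labels = colored[:, :-1], colored[:, -1]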
        """
        if isinstance(testData, np.ndarray):
            if testData.shape[1:] != (self.timesteps, self.height, self.width, 1):
                raise ValueError('Shape must be (n, timesteps, height, width, 1), but was ', testData.shape)

        if not isinstance(testData, np.ndarray):
            testData = self.tc[primaryTC].as_n_sample_4D(self.timesteps)
            testData = testData.reshape((-1, self.timesteps, self.height, self.width, 1))

        n = testData.shape[0]

        # Latent code width: cD * classes for the Gumbel model, otherwise classes.
        C = np.zeros((n, self.cD * self.classes if self.mode == 'gumble' else self.classes))

        # Encode in chunks of 100 samples to keep memory usage bounded.
        for i in range(0, n, 100):
            c = self.encoder([testData[i:i + 100]])[0]
            C[i:i + 100] = c.reshape(-1, self.cD * self.classes if self.mode == 'gumble' else self.classes)

        if aMode == 'tsne':
            array = TSNE(metric='hamming').fit_transform(C.reshape(n, -1))
        elif aMode == 'pca':
            array = PCA(n_components=self.classes).fit_transform(C.reshape(n, -1))
        elif aMode.lower() == 'none':
            array = C.reshape(n, -1)
        else:
            raise ValueError('"aMode" needs to be either "pca", "tsne" or "none".')

        if cMode.lower() == 'dbscan':
            labels = DBSCAN(eps=eps, min_samples=10).fit_predict(array)

        elif cMode.lower() in ['kmeans', 'kmean', 'k-mean', 'k-means']:
            nClusters = nClusters if nClusters > 0 else self.classes
            labels = KMeans(n_clusters=nClusters).fit_predict(array)
        else:
            raise ValueError('"cMode" needs to be either "kmeans" or "dbscan".')

        if len(np.unique(labels)) > 1:
            color = labels
        else:
            # Fallback color generation: derive a pseudo color from the reduced
            # coordinates when clustering produced only a single label.
            X = array[:, 0] + np.min(array[:, 0])
            Y = (np.min(array[:, 1]) + array[:, 1]) // (np.max(array[:, 1]) + np.min(array[:, 1]))
            color = X * Y

        return np.column_stack((array, color))

    def viz_clusters(self, aMode='pca', cMode='kmeans', testdata=None, fileName=''):
        dataArray = self.reduce_and_color(aMode=aMode, cMode=cMode, testData=testdata)

        self.print_n_show(dataArray, 'scatter', np.max(dataArray[:, -1]), fileName=fileName)
        return True

    # THIS WORKS in all modes
    def show_prediction(self, n, dataArray=None, show=True, fileName='', startI=0, primaryTC=0):
        if not fileName and not show:
            raise ValueError('Either provide a fileName or set show=True.')
        if self.mode in ['gumble', 'vae', 'refined']:
            if not isinstance(dataArray, np.ndarray):
                dataArray = self.tc[primaryTC].as_n_sample_4D(self.timesteps)
            seqWidth = self.width * self.timesteps
            spacing = 1
            sqrtDim = int(sqrt(n)) + 1
            fullwidth = (seqWidth + spacing) * sqrtDim
            fullheight = (self.height*2 + spacing) * sqrtDim

            figure = np.zeros((fullheight, fullwidth))
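            # Each grid cell stacks the original sequence (top) over the model's
            # reconstruction (bottom); cells are laid out left to right, top to bottom.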
            for i in range(n):

                array = dataArray[i+startI]
                arr_h = self.model.predict(array.reshape((1, self.timesteps, self.height, self.width, 1)))
                f = np.ones((self.height*2, seqWidth))
                f[:self.height, :seqWidth] = array.reshape(seqWidth, self.height).swapaxes(0, 1)
                f[self.height:self.height*2, : seqWidth] = arr_h.reshape(seqWidth, self.height).swapaxes(0, 1)

                try:
                    y, x = divmod(i, sqrtDim)
                except ZeroDivisionError:
                    x, y = 0, 0

                figure[y*self.height*2 + y*spacing: (y+1)*self.height*2 + y*spacing,
                       x*seqWidth + x*spacing: (x+1)*seqWidth + x*spacing] = f
            if fileName:
                fileName = fileName if fileName.endswith('.tif') else '%s.tif' % fileName
                self.print_n_show(figure, 'img', maxValue=np.max(figure), fileName=fileName)
            if show:
                self.print_n_show(figure, 'img', maxValue=np.max(figure))

            return True

    def sample_latent(self, nSamples, show=True, fileName=''):
        if self.mode not in ['gumble', 'vae', 'refined']:
            raise ValueError('Mode needs to be one of "gumble", "vae" or "refined".')
        if self.classes >= self.height:
            raise NotImplementedError('This cannot be shown, please edit the function!')

        seqWidth = self.width * self.timesteps
        spacing = 1
        sqrtDim = int(sqrt(nSamples)) + 1
        fullwidth = (seqWidth + spacing) * sqrtDim
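        # lHSpace / lWSpace center the latent sample inside the top half of each grid
        # cell; the generator's reconstruction fills the bottom half.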
        if self.mode == 'gumble':
            if self.cD >= fullwidth:
                raise NotImplementedError('This cannot be shown, please edit the function!')
            lHSpace = (self.height - self.classes) // 2
            lWSpace = (seqWidth - self.cD) // 2
            lShape = (self.classes, self.cD)

        else:
            if self.classes >= seqWidth:
                raise NotImplementedError('Too many samples, this cannot be displayed, please edit the function!')
            lHSpace = (self.height - 1) // 2
            lWSpace = (seqWidth - self.classes) // 2
            lShape = (-1, self.classes)

        fullheight = (self.height*2 + spacing) * sqrtDim
        figure = np.zeros((fullheight, fullwidth))

        for i in range(nSamples):
            f = np.ones((self.height * 2, seqWidth))

            if self.mode == 'gumble':
                # https://stackoverflow.com/a/42874726/7746808
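                # np.eye(classes)[indices] turns each random class index into a one-hot
                # row, giving a (cD, classes) matrix of one-hot samples.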
                oneHot = np.eye(self.classes)[np.random.randint(0, self.classes, self.cD)]
                sample = oneHot.reshape((-1, self.classes*self.cD))
                f[lHSpace:lHSpace + self.classes, lWSpace:lWSpace + self.cD] = sample.swapaxes(0, 1).reshape(lShape)
            else:
                if self.mode == 'vae':
                    sampleSpace = np.random.randn(nSamples, self.classes)
                elif self.mode == 'refined':
                    sampleSpace = np.random.rand(nSamples, self.classes)
                else:
                    raise ValueError('Mode needs to be one of "gumble", "vae" or "refined".')

                sample = sampleSpace[i].reshape(lShape)
                f[lHSpace:lHSpace+1, lWSpace:lWSpace + self.classes] = sample

            arr_h = self.generator.predict(sample)
            f[self.height:self.height*2, : seqWidth] = arr_h.reshape(seqWidth, self.height).swapaxes(0, 1)

            try:
                y, x = divmod(i, sqrtDim)
            except ZeroDivisionError:
                x, y = 0, 0

            figure[y * self.height * 2 + y*spacing: (y + 1) * self.height * 2 + y*spacing,
                   x * seqWidth + x*spacing: (x + 1) * seqWidth + x*spacing] = f

        if fileName:
            fileName = fileName if fileName.endswith('.tif') else '%s.tif' % fileName
            self.print_n_show(figure, 'img', maxValue=np.max(figure), fileName=fileName)
        if show:
            self.print_n_show(figure, 'img', maxValue=np.max(figure))
        return True

    def multi_path_coloring(self, nClusters, fileName='', state='', primaryTC=0, uncertainty=False, rgba=False):
        if nClusters <= 2:
            raise ValueError('More than 2 clusters are needed.')

        if fileName and state.lower() == 'load':
            import pickle
            with open(fileName, 'rb') as file:
                patchDict = pickle.load(file)
        else:
            patchDict = defaultdict(list)
            for key in self.tc[primaryTC].keys():

                tempKeys, tempSequence = self.tc[primaryTC].as_n_sample_4D(self.timesteps,
                                                                           in_walk_dir=True,
                                                                           keys=True,
                                                                           moving_window=True,
                                                                           for_track=key)

                C = np.zeros((len(tempSequence), self.cD * self.classes if self.mode == 'gumble' else self.classes))

                for i in range(0, len(tempSequence), 100):
                    c = self.encoder([tempSequence[i:i + 100]])[0]
                    C[i:i + 100] = c.reshape(-1, self.cD * self.classes if self.mode == 'gumble' else self.classes)

                for i, tempKey in enumerate(tempKeys):
                    patchDict[tempKey].append(list(C[i]))

        if fileName and state.lower() == 'dump':
            import pickle
            with open(fileName, 'wb') as f:
                pickle.dump(patchDict, f, pickle.HIGHEST_PROTOCOL)

        # Flatten patchDict into rows of [latent code ..., key_x, key_y].
        l = list()
        for x in patchDict.keys():
            for elem in patchDict[x]:
                l.append(elem + list(x))

        a = np.array(l)
        # Cluster the latent codes only; the last two columns are the map coordinates.
        k = KMeans(nClusters).fit_predict(a[:, :-2]) + 1  # Color correction

        # Append the cluster label as an additional column.
        s = np.zeros((a.shape[0], a.shape[1] + 1))
        s[:, :-1] = a
        s[:, -1] = k

        # Re-group the labels per map position (key).
        patchDict = defaultdict(list)
        for i in range(s.shape[0]):
            key = int(s[i][-3]), int(s[i][-2])
            patchDict[key].append(int(s[i][-1]))

        # Histogram: for how many map positions do 1, 2, 3, ... distinct clusters occur?
        c = Counter()
        for key in patchDict.keys():
            c[len(set(patchDict[key]))] += 1

        npmax = np.max(s[:, -1]) + 1
        # npmax = 4
        self.print_n_show(c, 'bars', npmax)

        if rgba:
            # Cast to uint8 so PIL can build a grayscale image before converting to RGBA.
            figure = Image.fromarray(np.where(self.tc[primaryTC].map.imgArray > 0, 255, 0).astype(np.uint8)).convert('RGBA')
            draw = ImageDraw.Draw(figure, 'RGBA')
            cmap = cm.get_cmap('gist_ncar', 12)  # 12 discrete colors

            for key in patchDict.keys():
                for value in patchDict[key]:
                    color = [int(x*255) for x in cmap(int(value), alpha=0.3)]
                    draw.point((key[1], key[0]), fill=tuple(color))

            self.print_n_show(figure, 'rgba', npmax)
        else:
            figure = np.where(self.tc[primaryTC].map.imgArray > 0, npmax, 0)
            for key in patchDict.keys():
                c = Counter(patchDict[key])
                figure[key] = c.most_common(1)[0][0]
            self.print_n_show(figure, 'img', npmax)

        if uncertainty:
            uncertainfig = np.where(self.tc[primaryTC].map.imgArray > 0, npmax + 1, 0)
            for key in patchDict.keys():
                uncertainfig[key] = len(set(patchDict[key])) * 2
            self.print_n_show(uncertainfig, 'img', npmax+1)
        return

    def show_silhouette_score(self, k_list, primaryTC=0):
        X = None
        for key in list(self.tc[primaryTC].keys())[:100]:
            tempSequence = self.tc[primaryTC].as_n_sample_4D(self.timesteps,
                                                             in_walk_dir=True,
                                                             keys=False,
                                                             moving_window=True,
                                                             for_track=key)
            C = np.zeros((len(tempSequence), self.cD * self.classes if self.mode == 'gumble' else self.classes))

            for i in range(0, len(tempSequence), 100):
                c = self.encoder([tempSequence[i:i + 100]])[0]
                C[i:i + 100] = c.reshape(-1, self.cD * self.classes if self.mode == 'gumble' else self.classes)
            if isinstance(X, np.ndarray):
                X = np.row_stack((X, C))
            else:
                X = C

        for n_clusters in k_list:
            # Create a subplot with 1 row and 2 columns
            fig, (ax1, ax2) = plt.subplots(1, 2)
            fig.set_size_inches(18, 7)

            # The 1st subplot is the silhouette plot
            # The silhouette coefficient can range from -1 to 1, but in this example
            # all values lie within [-0.1, 1].
            ax1.set_xlim([-0.1, 1])
            # The (n_clusters+1)*10 is for inserting blank space between silhouette
            # plots of individual clusters, to demarcate them clearly.
            ax1.set_ylim([0, len(X) + (n_clusters + 1) * 10])

            # Initialize the clusterer with n_clusters value and a random generator
            # seed of 10 for reproducibility.
            clusterer = KMeans(n_clusters=n_clusters, random_state=10)
            cluster_labels = clusterer.fit_predict(X)

            # The silhouette_score gives the average value for all the samples.
            # This gives a perspective into the density and separation of the formed
            # clusters
            silhouette_avg = silhouette_score(X, cluster_labels)
            print("For n_clusters =", n_clusters,
                  "The average silhouette_score is :", silhouette_avg)

            # Compute the silhouette scores for each sample
            sample_silhouette_values = silhouette_samples(X, cluster_labels)

            y_lower = 10
            for i in range(n_clusters):
                # Aggregate the silhouette scores for samples belonging to
                # cluster i, and sort them
                ith_cluster_silhouette_values = \
                    sample_silhouette_values[cluster_labels == i]

                ith_cluster_silhouette_values.sort()

                size_cluster_i = ith_cluster_silhouette_values.shape[0]
                y_upper = y_lower + size_cluster_i

                color = cm.nipy_spectral(float(i) / n_clusters)
                ax1.fill_betweenx(np.arange(y_lower, y_upper),
                                  0, ith_cluster_silhouette_values,
                                  facecolor=color, edgecolor=color, alpha=0.7)

                # Label the silhouette plots with their cluster numbers at the middle
                ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))

                # Compute the new y_lower for next plot
                y_lower = y_upper + 10  # 10 for the 0 samples

            ax1.set_title("The silhouette plot for the various clusters.")
            ax1.set_xlabel("The silhouette coefficient values")
            ax1.set_ylabel("Cluster label")

            # The vertical line for average silhouette score of all the values
            ax1.axvline(x=silhouette_avg, color="red", linestyle="--")

            ax1.set_yticks([])  # Clear the yaxis labels / ticks
            ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])

            # 2nd Plot showing the actual clusters formed
            colors = cm.nipy_spectral(cluster_labels.astype(float) / n_clusters)
            ax2.scatter(X[:, 0], X[:, 1], marker='.', s=30, lw=0, alpha=0.7,
                        c=colors)

            # Labeling the clusters
            centers = clusterer.cluster_centers_
            # Draw white circles at cluster centers
            ax2.scatter(centers[:, 0], centers[:, 1],
                        marker='o', c="white", alpha=1, s=200)

            for i, c in enumerate(centers):
                ax2.scatter(c[0], c[1], marker='$%d$' % i, alpha=1, s=50)

            ax2.set_title("The visualization of the clustered data.")
            ax2.set_xlabel("Feature space for the 1st feature")
            ax2.set_ylabel("Feature space for the 2nd feature")

            plt.suptitle(("Silhouette analysis for KMeans clustering on sample data "
                          "with n_clusters = %d" % n_clusters),
                         fontsize=14, fontweight='bold')

            plt.show()