from keras.layers import (Input, TimeDistributed, Dense, GRU, UpSampling2D, RepeatVector, MaxPooling2D,
                          Convolution2D, Deconvolution2D, Flatten, Reshape, Lambda)

from keras.models import Model, Sequential
from keras import backend as K
from keras.metrics import binary_crossentropy


import numpy as np
import pickle
from math import sqrt

from Trainer import Trainer


def get_batch(X, size):
    """Draw a random batch of distinct entries from ``X``.

    :param X: collection supporting ``len()`` and indexing with an integer index array
              (e.g. a NumPy array)
    :param size: number of entries to draw; positions are sampled without replacement,
                 so NumPy raises ``ValueError`` if this exceeds ``len(X)``
    :return: ``X`` indexed by the randomly chosen positions
    """
    picked = np.random.choice(len(X), size=size, replace=False)
    return X[picked]


def load_preprocesseed_data(filename):
    """
    Load a pickled dataset from disk.

    Only the file name suffix is checked; the file content is handed to ``pickle.load``
    unchanged, so this must only be used on trusted files.

    :param filename: The filename of the pickled dataset to load; has to end in ``.pik``
    :type filename: str
    :return: The unpickled object (originally documented as a ``tools.TrackCollection``;
             not verifiable from this file)
    :raises TypeError: if ``filename`` does not end in ``.pik``
    """
    has_pickle_suffix = filename.endswith('.pik')
    if not has_pickle_suffix:
        raise TypeError('input File needs to be a Pickle object ".pik"!')
    with open(filename, 'rb') as handle:
        return pickle.load(handle)


# https://github.com/dribnet/plat/blob/master/plat/interpolate.py#L15
# https://pdfs.semanticscholar.org/f46c/307d4d73e86412e0c57161fb52f7591e124b.pdf
def slerp(val, low, high):
    """Spherical interpolation. val has a range of 0 to 1.

    Interpolates along the arc between ``low`` and ``high``. For ``val <= 0`` the ``low``
    object itself is returned, for ``val >= 1`` the ``high`` object, and ``low`` is also
    returned unchanged when both endpoints are numerically equal.

    When the arc is undefined (an endpoint has zero length, or the endpoints are parallel
    or antiparallel so that ``sin(omega)`` vanishes) the function falls back to linear
    interpolation instead of dividing by zero and producing NaN/inf.

    NOTE(review): every interpolated result is multiplied element-wise by fresh uniform
    noise from ``np.random.rand``, so interior results are random and are not a plain
    slerp. That behaviour is kept as-is; confirm it is intended.

    :param val: interpolation position
    :param low: start vector (1-D NumPy array)
    :param high: end vector, same shape as ``low``
    :return: ``low`` / ``high`` at the boundaries, otherwise the noise-scaled interpolation
    """
    if val <= 0:
        return low
    elif val >= 1:
        return high
    elif np.allclose(low, high):
        return low

    low_norm = np.linalg.norm(low)
    high_norm = np.linalg.norm(high)

    # A zero-length endpoint has no direction, so no angle can be computed for it.
    degenerate = low_norm == 0 or high_norm == 0
    if not degenerate:
        # Rounding (as before) plus clipping keeps the cosine inside arccos' domain even
        # when floating-point error pushes the dot product slightly beyond +/-1.
        # noinspection PyTypeChecker
        cos_omega = np.clip(round(np.dot(low / low_norm, high / high_norm), 15), -1.0, 1.0)
        omega = np.arccos(cos_omega)
        so = np.sin(omega)
        # sin(omega) is ~0 only for (anti)parallel endpoints; dividing by it would blow up.
        degenerate = so < 1e-10

    if degenerate:
        re = (1.0 - val) * low + val * high
    else:
        re = np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high
    return re * np.random.rand(low.shape[-1])


if __name__ == '__main__':
    # Script entry point: builds a sequence VAE (TimeDistributed conv encoder + GRU, Gaussian
    # bottleneck, GRU + deconvolution decoder), registers it on the project-local Trainer and
    # then runs whichever of the hard-coded ``if True`` / ``if False`` actions below is enabled.

    # 'tf' ordering puts the channel axis last, matching the (..., width, height, 1) input below.
    K.set_image_dim_ordering('tf')

    # Data PreProcessing
    # Each file is a pickled dataset; the unpickled objects are passed to the Trainer as-is.
    tate = load_preprocesseed_data('Tate_1000.pik')
    maze = load_preprocesseed_data('Maze_1000.pik')
    oet = load_preprocesseed_data('Oet_1000.pik')
    tum = load_preprocesseed_data('Tum_1000.pik')
    doom = load_preprocesseed_data('Doom_1000.pik')
    priz = load_preprocesseed_data('Priz_1000.pik')
    cross = load_preprocesseed_data('crossing.pik')

    '''HERE IS THE TRAINING!!!!!'''
    # NOTE(review): Trainer is project-local; the meaning of its arguments (the positional 10,
    # categorical_distribution, rotating, ...) is not visible from this file. T.filters is used
    # below as a layer width and as a divisor, so filters=0 presumably selects a Trainer-side
    # default - confirm.
    # T = Trainer('vae', [maze, tate, oet, tum, doom, priz], 10, categorical_distribution=0,
    T = Trainer('vae', [tate, maze, oet, tum, doom, priz], 10, categorical_distribution=0,
                batchSize=1, timesteps=9, filters=0, rotating=True)  # BatchSize: 1600 for Training / 1 for Testing

    # PreStage 1: Encoder Input
    # One sample is a sequence of T.timesteps single-channel frames of size T.width x T.height.
    enc_input = Input(shape=(T.timesteps, T.width, T.height, 1), name='Main_Input')

    # Stage 1: Encoding
    # Applied to every timestep independently: conv 3x3 -> max-pool 2 -> conv 5x5 -> max-pool 2
    # -> flatten. The GRU then condenses the sequence of flattened frames into a single vector.
    enc_seq = Sequential(name='Encoder')
    enc_seq.add(TimeDistributed(Convolution2D(activation='relu', filters=T.filters, kernel_size=(3, 3), strides=1),
                                name='Conv1', input_shape=(T.timesteps, T.width, T.height, 1)))
    enc_seq.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2), name='MaxPool1'))

    enc_seq.add(TimeDistributed(Convolution2D(activation='relu', filters=T.filters, kernel_size=(5, 5), strides=1),
                                name='Conv2'))
    enc_seq.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2), name='MaxPool2'))

    enc_seq.add(TimeDistributed(Flatten(), name='Flatten'))
    # GRU width equals the flattened per-timestep feature size; return_sequences=False means
    # only one output vector per sequence is produced.
    enc_seq.add(GRU(int(enc_seq.layers[-1].output_shape[-1]), return_sequences=False, name='GRU_Encode'))

    encoding = enc_seq(enc_input)

    # Stage 2: Bottleneck
    # Two parallel Dense layers of T.classes units. Despite its layer name 'Std_Dev', the second
    # one is consumed as a log-variance (exp(z_log_var / 2) in sampling, and in the KL term).
    out_z_mean = Dense(T.classes, name='Dense_Mean')(encoding)
    out_z_log_var = Dense(T.classes, name='Std_Dev')(encoding)


    def sampling(args, batch_size=500, classes=3, epsilon_std=0.01):
        """Reparameterisation step: return z_mean + exp(z_log_var / 2) * epsilon.

        ``args`` is the pair ``(z_mean, z_log_var)``. ``epsilon`` is normal noise with mean 0 and
        standard deviation ``epsilon_std``, created with the fixed shape ``(batch_size, classes)``.
        The defaults are overridden through the ``arguments`` dict of the Lambda layer below.
        """
        z_mean, z_log_var = args
        epsilon = K.random_normal(shape=(batch_size, classes), mean=0., stddev=epsilon_std)
        return z_mean + K.exp(z_log_var / 2) * epsilon


    # The noise shape is tied to T.batchSize, i.e. to the batch size the Trainer was created with.
    z = Lambda(sampling, name='Sampling', arguments={'batch_size': T.batchSize,
                                                     'classes': T.classes,
                                                     'epsilon_std': 0.01})([out_z_mean, out_z_log_var])

    # Stage 3: Decoding
    # Kept as its own Sequential so that the same layers can be reused for the generator below.
    dec_seq = Sequential(name='Decoder')

    # Repeat the latent vector once per timestep, then let a GRU (same width as the encoder GRU's
    # output) emit one feature vector per timestep.
    dec_seq.add(RepeatVector(T.timesteps, name='TimeRepeater', input_shape=(T.classes,)))
    dec_seq.add(GRU(enc_seq.layers[-1].output_shape[-1], return_sequences=True, name='GRU_Decode'))

    # Side length of the square feature map each GRU output vector is reshaped into.
    # Assumes (GRU width / T.filters) is a perfect square - TODO confirm; int() truncates otherwise
    # and the Reshape target would no longer match the vector size.
    reValue = int(sqrt(dec_seq.layers[-1].output_shape[-1]//T.filters))
    dec_seq.add(TimeDistributed(Reshape((reValue, reValue, T.filters)), name='ReShape'))

    # Per timestep: upsample x2 -> deconv 4x4 -> upsample x2 -> deconv 5x5 down to a single channel.
    dec_seq.add(TimeDistributed(UpSampling2D(2), name='Up1'))
    dec_seq.add(TimeDistributed(Deconvolution2D(activation='relu', filters=T.filters, kernel_size=(4, 4), strides=1),
                                name='DeConv1'))
    dec_seq.add(TimeDistributed(UpSampling2D(2), name='Up2'))
    dec_seq.add(TimeDistributed(Deconvolution2D(activation='relu', filters=1, kernel_size=(5, 5), strides=1),
                                name='DeConv2'))

    dec_output = dec_seq(z)

    # Loss function minimized by autoencoder
    def vae_objective(true, pred):
        """Return reconstruction loss plus KL regulariser.

        Both tensors are flattened to rows of length ``T.original_dim`` (presumably the number of
        values in one sample - confirm in Trainer) before the binary cross-entropy is taken. The KL
        term, -0.5 * sum(1 + log_var - mean^2 - exp(log_var)) over the last axis, is computed from
        the bottleneck tensors ``out_z_mean`` / ``out_z_log_var`` of the enclosing scope.
        """
        true = K.reshape(true, (-1, T.original_dim))          # !
        pred = K.reshape(pred, (-1, T.original_dim))          # !
        loss = binary_crossentropy(true, pred)
        kl_regu = -.5 * K.sum(1. + out_z_log_var - K.square(
            out_z_mean) - K.exp(out_z_log_var), axis=-1)
        return loss + kl_regu

    # Model
    # End-to-end model (input frames -> reconstructed frames) handed over together with its loss.
    T.set_model(Model(inputs=enc_input, outputs=dec_output), vae_objective)

    # Separate encoder from input to latent space
    # Maps inputs to the latent mean only; the sampling layer is not part of this function.
    encoder = K.function([enc_input], [out_z_mean])
    T.set_encoder(encoder)

    # Generator from latent to input space
    # Feeds a fresh latent Input through the same dec_seq instance, so it shares the decoder layers.
    decoder_input = Input(shape=(T.classes,))
    decoder_output = dec_seq(decoder_input)
    T.set_generator(Model(inputs=decoder_input, outputs=decoder_output))

    # Manual switches: flip the literals below to choose between training and evaluation actions.
    if False:
        # Training run: train, then store the weights under the given name.
        # T.load_weights('VAEall_1k_9t_GRU')
        T.train('VAEall_1k_9t_GRU')
        T.save_weights('VAEall_1k_9t_GRU')
    if True:
        # Evaluation run: restore the stored weights, then execute the enabled Trainer action(s).
        # NOTE(review): what each Trainer method does is defined in the Trainer module - see there.
        T.load_weights('VAEall_1k_9t_GRU')
        if False:
            T.plot_model('vae.png', show_shapes=True, show_layer_names=True)
        if False:
            # Uses the first entry of the 'crossing.pik' dataset.
            T.color_track(cross[list(cross.keys())[0]], completeSequence=True,
                          nClusters=10, cMode='kmeans', aMode='None')  # 2600 for the 10k dataset
            # T.color_random_track(completeSequence=False, nClusters=10, cMode='kmeans', aMode='None')
        if False:
            T.show_prediction(200)
        if False:
            T.sample_latent(200)
        if True:
            T.multi_path_coloring(10, fileName='all', state='dump', primaryTC=-2, uncertainty=True)