from keras.layers import (Input, TimeDistributed, Dense, LSTM, UpSampling2D,
                          RepeatVector, MaxPooling2D, Convolution2D,
                          Deconvolution2D, Flatten, Reshape, Lambda)
from keras.models import Model, Sequential
from keras import backend as K
from keras.metrics import binary_crossentropy
from keras.activations import softmax
import numpy as np
import pickle
from math import sqrt

from Trainer import Trainer


def get_batch(X, size):
    """Draw a random batch of `size` samples without replacement."""
    a = np.random.choice(len(X), size, replace=False)
    return X[a]


def load_preprocessed_data(filename):
    """Load a pickled, preprocessed track collection."""
    if not filename.endswith('.pik'):
        raise TypeError('Input file needs to be a pickle object (".pik")!')
    with open(filename, 'rb') as f:
        data = pickle.load(f)
    return data


if __name__ == '__main__':
    K.set_image_dim_ordering('tf')

    '''HERE IS THE TRAINING!!!!!'''
    # Based on https://github.com/nzw0301/keras-examples/blob/master/gumbel_softmax_vae_MNIST.ipynb
    # Paper: https://arxiv.org/pdf/1611.01144.pdf

    # Data preprocessing; keep the batch size small because of limited memory
    # (500 should do), then rerun the fitting!
    trackCollection = load_preprocessed_data('test_track.pik')
    T = Trainer('gumble', trackCollection, 2, 5)

    # Pre-stage 1: encoder input
    enc_input = Input(shape=(T.timesteps, T.width, T.height, 1), name='main_input')

    # Stage 1: encoding - per-frame CNN feature extraction followed by an LSTM
    enc_seq = Sequential(name='Encoder')
    enc_seq.add(TimeDistributed(Convolution2D(activation='relu', filters=T.filters,
                                              kernel_size=(3, 3), strides=1),
                                name='Conv1',
                                input_shape=(T.timesteps, T.width, T.height, 1)))
    enc_seq.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2), name='MaxPool1'))
    enc_seq.add(TimeDistributed(Convolution2D(activation='relu', filters=T.filters,
                                              kernel_size=(5, 5), strides=1),
                                name='Conv2'))
    enc_seq.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2), name='MaxPool2'))
    enc_seq.add(TimeDistributed(Flatten(), name='Flatten'))
    enc_seq.add(LSTM(int(enc_seq.layers[-1].output_shape[-1]),
                     return_sequences=False, name='LSTM_Encode'))

    encoding = enc_seq(enc_input)

    # Stage 2: bottleneck - unnormalized logits for T.cD categorical
    # distributions with T.classes categories each
    logits_y = Dense(T.classes * T.cD)(encoding)  # activation='softmax'? I think not - softmax happens inside the sampling step

    # Sampling function: Gumbel-softmax reparameterization
    def sampling(logits):
        U = K.random_uniform(K.shape(logits), 0, 1)
        y = logits - K.log(-K.log(U + 1e-20) + 1e-20)  # logits + Gumbel noise
        y = softmax(K.reshape(y, (-1, T.cD, T.classes)) / T.tau)  # temperature-controlled relaxation
        y = K.reshape(y, (-1, T.cD * T.classes))
        return y

    z = Lambda(sampling)(logits_y)

    # Stage 3: decoding - the LSTM unrolls the latent code over time, then
    # per-frame upsampling/deconvolution reconstructs the input frames
    dec_seq = Sequential(name='Decoder')
    dec_seq.add(RepeatVector(T.timesteps, name='TimeRepeater',
                             input_shape=(T.classes * T.cD,)))
    dec_seq.add(LSTM(enc_seq.layers[-1].output_shape[-1],
                     return_sequences=True, name='LSTM_Decode'))
    # Side length of the square feature map the flat LSTM output is reshaped to
    reValue = int(sqrt(dec_seq.layers[-1].output_shape[-1] // T.filters))
    dec_seq.add(TimeDistributed(Reshape((reValue, reValue, T.filters)), name='ReShape'))
    dec_seq.add(TimeDistributed(UpSampling2D(2), name='Up1'))
    dec_seq.add(TimeDistributed(Deconvolution2D(activation='relu', filters=T.filters,
                                                kernel_size=(4, 4), strides=1),
                                name='DeConv1'))
    dec_seq.add(TimeDistributed(UpSampling2D(2), name='Up2'))
    dec_seq.add(TimeDistributed(Deconvolution2D(activation='hard_sigmoid', filters=1,
                                                kernel_size=(5, 5), strides=1),
                                name='DeConv2'))

    dec_output = dec_seq(z)
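    # Illustrative sketch only (not used by the model): the exact Gumbel-max
    # trick that sampling() above relaxes. argmax(logits + Gumbel noise) draws
    # exact categorical samples; softmax((logits + noise) / tau) is its
    # differentiable approximation, approaching one-hot as tau -> 0.
    # The helper name is hypothetical and exists purely for demonstration.
    def _gumbel_max_sample_demo(logits_np):
        """NumPy-only demo: exact categorical sampling via Gumbel noise."""
        u = np.random.uniform(size=logits_np.shape)
        g = -np.log(-np.log(u + 1e-20) + 1e-20)  # standard Gumbel(0, 1) noise
        return np.argmax(logits_np + g, axis=-1)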
    # Gumbel loss function
    def gumbel_loss(x, x_hat):
        # N = T.cD categorical distributions with M = T.classes categories each
        q_y = K.reshape(logits_y, (-1, T.cD, T.classes))
        q_y = softmax(q_y)
        log_q_y = K.log(q_y + 1e-20)
        # KL divergence of q(y) from the uniform categorical prior
        kl_tmp = q_y * (log_q_y - K.log(1.0 / T.classes))
        KL = K.sum(kl_tmp, axis=(1, 2))
        x = K.reshape(x, (-1, T.original_dim))          # !
        x_hat = K.reshape(x_hat, (-1, T.original_dim))  # !
        # Sign convention follows the referenced notebook
        elbo = T.original_dim * binary_crossentropy(x, x_hat) - KL
        return elbo

    T.set_model(Model(inputs=enc_input, outputs=dec_output), gumbel_loss,
                optimizer='adagrad')

    # Generator from latent space to input space
    decoder_input = Input(shape=(T.classes * T.cD,))
    decoder_output = dec_seq(decoder_input)
    T.set_generator(Model(inputs=decoder_input, outputs=decoder_output))

    # Separate encoder from input space to the (one-hot argmax) latent space
    argmax_y = K.max(K.reshape(logits_y, (-1, T.cD, T.classes)), axis=-1, keepdims=True)
    argmax_y = K.equal(K.reshape(logits_y, (-1, T.cD, T.classes)), argmax_y)
    encoder = K.function([enc_input], [argmax_y])
    T.set_encoder(encoder)

    # Manual toggles for the individual pipeline steps
    if True:
        T.load_weights('Gumble10Weights')
        T.train('Gumble10Weights')
        T.save_weights('Gumble10Weights')
    if False:
        T.load_weights('Gumble10Weights')
    if False:
        T.plot_model('Gumble10.png', show_shapes=True, show_layer_names=True)
    if False:
        # T.color_track(trackCollection[list(trackCollection.keys())[2200]], nClusters=4)  # 2600
        T.color_random_track(completeSequence=False, nClusters=4)
    if True:
        T.show_prediction(200)
    if False:
        T.sample_latent(200)
    if True:
        T.multi_path_coloring(10)
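    # Sketch (assumption, not part of the original pipeline): the Gumbel-softmax
    # paper anneals the temperature during training, e.g. tau = max(0.5, exp(-r * step)).
    # This hypothetical helper shows such a schedule; using it would require the
    # Trainer to re-assign T.tau (or a backend variable) between epochs.
    def anneal_tau(step, tau0=1.0, min_tau=0.5, rate=3e-5):
        """Illustrative exponential temperature decay for the Gumbel-softmax."""
        return max(min_tau, tau0 * np.exp(-rate * step))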