136 lines
5.4 KiB
Python
136 lines
5.4 KiB
Python
from keras.layers import (Input, TimeDistributed, Dense, LSTM, UpSampling2D, RepeatVector, MaxPooling2D,
|
|
Convolution2D, Deconvolution2D, Flatten, Reshape, Lambda)
|
|
|
|
from keras.models import Model, Sequential
|
|
|
|
from keras import backend as K
|
|
|
|
from keras.metrics import binary_crossentropy
|
|
from keras.activations import softmax
|
|
|
|
|
|
import numpy as np
|
|
import pickle
|
|
from math import sqrt
|
|
|
|
from Trainer import Trainer
|
|
|
|
|
|
def get_batch(X, size):
    """Return ``size`` distinct rows of ``X`` chosen uniformly at random.

    Row indices are drawn without replacement from NumPy's global random
    state, so NumPy rejects a ``size`` larger than ``len(X)``. ``X`` must
    support fancy indexing with an integer array (e.g. a NumPy array).
    """
    population = len(X)
    picks = np.random.choice(population, size, replace=False)
    batch = X[picks]
    return batch
|
|
|
|
|
|
def load_preprocesseed_data(filename):
    """Unpickle and return the object stored in ``filename``.

    Only paths ending in ``.pik`` are accepted; any other name raises
    ``TypeError`` before the file is opened.

    Note: ``pickle`` can execute arbitrary code while loading, so only pass
    files that come from a trusted source.
    """
    if filename.endswith('.pik'):
        with open(filename, 'rb') as pickle_file:
            return pickle.load(pickle_file)
    raise TypeError('input File needs to be a Pickle object ".pik"!')
|
|
|
|
|
|
if __name__ == '__main__':
    # Build, train and inspect a Gumbel-softmax VAE whose encoder and decoder
    # are per-frame (TimeDistributed) CNNs wrapped around an LSTM.
    #
    # Based on the Keras notebook
    # https://github.com/nzw0301/keras-examples/blob/master/gumbel_softmax_vae_MNIST.ipynb
    # and the paper "Categorical Reparameterization with Gumbel-Softmax",
    # https://arxiv.org/pdf/1611.01144.pdf

    # Switches selecting which actions at the bottom of this script run.
    TRAIN = False
    LOAD_WEIGHTS = True
    PLOT_MODEL = False
    COLOR_RANDOM_TRACK = False
    SHOW_PREDICTION = False
    SAMPLE_LATENT = False
    MULTI_PATH_COLORING = False
    SILHOUETTE_SCORE = True

    # Channels-last ("tf") ordering: the input shapes below put the single
    # channel in the last axis.
    K.set_image_dim_ordering('tf')

    # Data preprocessing: keep the batch size small because memory is limited
    # (500 should do); rerun the fitting after changing it.
    trackCollection = load_preprocesseed_data('test_track.pik')
    # NOTE(review): the meaning of the positional arguments (10, 30) is defined
    # by Trainer, which is not visible here -- confirm against Trainer.__init__.
    T = Trainer('gumble', trackCollection, 10, 30, filters=10)

    # Pre-stage 1: encoder input, a sequence of single-channel frames.
    enc_input = Input(shape=(T.timesteps, T.width, T.height, 1), name='main_input')

    # Stage 1: encoding. Two conv + max-pool stages are applied to every time
    # step, each frame is flattened, and an LSTM summarises the sequence into
    # one vector.
    enc_seq = Sequential(name='Encoder')
    enc_seq.add(TimeDistributed(
        Convolution2D(activation='relu', filters=T.filters, kernel_size=(3, 3), strides=1),
        name='Conv1', input_shape=(T.timesteps, T.width, T.height, 1)))
    enc_seq.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2), name='MaxPool1'))
    enc_seq.add(TimeDistributed(
        Convolution2D(activation='relu', filters=T.filters, kernel_size=(5, 5), strides=1),
        name='Conv2'))
    enc_seq.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2), name='MaxPool2'))
    enc_seq.add(TimeDistributed(Flatten(), name='Flatten'))

    # Size of one flattened per-frame feature map. It is used as the LSTM width
    # in both encoder and decoder so the decoder can reshape it back into a map.
    flat_dim = int(enc_seq.layers[-1].output_shape[-1])
    enc_seq.add(LSTM(flat_dim, return_sequences=False, name='LSTM_Encode'))

    encoding = enc_seq(enc_input)

    # Stage 2: bottleneck -- unnormalised logits for T.cD categorical variables
    # with T.classes categories each. No activation='softmax' here on purpose:
    # the softmax is applied inside sampling() and gumbel_loss().
    logits_y = Dense(T.classes * T.cD)(encoding)

    def sampling(logits):
        """Draw a relaxed (Gumbel-softmax) one-hot sample per categorical variable.

        Gumbel(0, 1) noise is added to ``logits``. A softmax with temperature
        ``T.tau`` is then taken over the ``T.classes`` axis of each of the
        ``T.cD`` variables. The result is flattened back to
        ``(-1, T.cD * T.classes)``.
        """
        U = K.random_uniform(K.shape(logits), 0, 1)
        # -log(-log(U)) is Gumbel(0, 1) noise; the 1e-20 terms avoid log(0).
        y = logits - K.log(-K.log(U + 1e-20) + 1e-20)
        y = softmax(K.reshape(y, (-1, T.cD, T.classes)) / T.tau)
        return K.reshape(y, (-1, T.cD * T.classes))

    z = Lambda(sampling)(logits_y)

    # Stage 3: decoding. The latent code is repeated once per time step and run
    # through an LSTM. Each step is reshaped to a (side, side, T.filters) map,
    # then upsampled/deconvolved back to a single-channel frame.
    dec_seq = Sequential(name='Decoder')
    dec_seq.add(RepeatVector(T.timesteps, name='TimeRepeater', input_shape=(T.classes * T.cD,)))
    dec_seq.add(LSTM(flat_dim, return_sequences=True, name='LSTM_Decode'))

    # The LSTM output has to fold back into a square map with T.filters
    # channels; fail early with a readable message if it cannot.
    side = int(sqrt(flat_dim // T.filters))
    if side * side * T.filters != flat_dim:
        raise ValueError('Cannot reshape decoder LSTM output of size %d into a square '
                         '(side, side, %d) feature map.' % (flat_dim, T.filters))
    dec_seq.add(TimeDistributed(Reshape((side, side, T.filters)), name='ReShape'))

    dec_seq.add(TimeDistributed(UpSampling2D(2), name='Up1'))
    dec_seq.add(TimeDistributed(
        Deconvolution2D(activation='relu', filters=T.filters // 2, kernel_size=(4, 4), strides=1),
        name='DeConv1'))
    dec_seq.add(TimeDistributed(UpSampling2D(2), name='Up2'))
    dec_seq.add(TimeDistributed(
        Deconvolution2D(activation='hard_sigmoid', filters=1, kernel_size=(5, 5), strides=1),
        name='DeConv2'))

    dec_output = dec_seq(z)

    def gumbel_loss(x, x_hat):
        """Return the negative ELBO, the quantity minimised during training.

        Reconstruction term: ``T.original_dim`` times ``binary_crossentropy``
        between the flattened ``x`` and ``x_hat``.

        Regulariser: the KL divergence between q(y|x) = softmax(logits_y) and
        a uniform prior over ``T.classes`` categories, summed over the
        ``T.cD`` variables. ``logits_y`` is read from the enclosing scope.
        """
        q_y = softmax(K.reshape(logits_y, (-1, T.cD, T.classes)))
        log_q_y = K.log(q_y + 1e-20)
        kl = K.sum(q_y * (log_q_y - K.log(1.0 / T.classes)), axis=(1, 2))
        # NOTE(review): this assumes T.original_dim covers a whole sequence
        # (timesteps * width * height). Otherwise the per-row reconstruction
        # term will not line up with the per-sequence KL term -- confirm in
        # Trainer.
        x = K.reshape(x, (-1, T.original_dim))
        x_hat = K.reshape(x_hat, (-1, T.original_dim))
        # -ELBO = -E[log p(x|z)] + KL(q || p). The KL term must be added;
        # subtracting it would reward moving away from the prior.
        return T.original_dim * binary_crossentropy(x, x_hat) + kl

    T.set_model(Model(inputs=enc_input, outputs=dec_output), gumbel_loss, optimizer='adagrad')

    # Generator: maps a latent code straight to the output space, reusing the
    # decoder's layers (and therefore its weights).
    decoder_input = Input(shape=(T.classes * T.cD,))
    decoder_output = dec_seq(decoder_input)
    T.set_generator(Model(inputs=decoder_input, outputs=decoder_output))

    # Separate encoder from input to latent space. It returns a boolean one-hot
    # code per categorical variable by marking the maximum logit (ties mark
    # every tied class).
    logits_3d = K.reshape(logits_y, (-1, T.cD, T.classes))
    argmax_y = K.equal(logits_3d, K.max(logits_3d, axis=-1, keepdims=True))
    encoder = K.function([enc_input], [argmax_y])
    T.set_encoder(encoder)

    if TRAIN:
        T.train('GumbleLSTMWeights')
        T.save_weights('GumbleLSTMWeights')
    if LOAD_WEIGHTS:
        # NOTE(review): training saves 'GumbleLSTMWeights' but this loads
        # 'GumbleWeights' -- confirm which file matches this architecture.
        T.load_weights('GumbleWeights')
    if PLOT_MODEL:
        T.plot_model('GumbleLSTM.png', show_shapes=True, show_layer_names=True)
    if COLOR_RANDOM_TRACK:
        # Alternative: color one specific track instead of a random one:
        # T.color_track(trackCollection[list(trackCollection.keys())[2200]], nClusters=4)  # 2600
        T.color_random_track(completeSequence=False, nClusters=4)
    if SHOW_PREDICTION:
        T.show_prediction(200, startI=200)
    if SAMPLE_LATENT:
        T.sample_latent(200)
    if MULTI_PATH_COLORING:
        T.multi_path_coloring(10)
    if SILHOUETTE_SCORE:
        T.show_silhouette_score([2, 4, 6, 8, 10, 12, 14])
|
|
|