147 lines
5.8 KiB
Python
147 lines
5.8 KiB
Python
from keras.layers import (Input, TimeDistributed, Dense, LSTM, UpSampling2D, RepeatVector, MaxPooling2D,
|
|
Convolution2D, Deconvolution2D, Flatten, Reshape, Lambda)
|
|
|
|
from keras.models import Model, Sequential
|
|
from keras import backend as K
|
|
from keras.metrics import binary_crossentropy
|
|
|
|
|
|
import numpy as np
|
|
import pickle
|
|
from math import sqrt
|
|
|
|
from Trainer import Trainer
|
|
|
|
|
|
def get_batch(X, size):
    """Return a random mini-batch of ``size`` distinct rows of ``X``.

    Row indices are drawn uniformly *without* replacement from NumPy's global
    random state, so ``size`` may not exceed ``len(X)``. ``X`` must support
    NumPy fancy indexing with an integer index array (e.g. an ndarray).
    """
    row_count = len(X)
    chosen_rows = np.random.choice(row_count, size, replace=False)
    batch = X[chosen_rows]
    return batch
|
|
|
|
|
|
def load_preprocesseed_data(filename):
    """Unpickle and return the object stored in ``filename``.

    Only files whose name ends in ``.pik`` are accepted.

    Raises:
        TypeError: if ``filename`` does not end with ``.pik``.

    Security note: ``pickle.load`` can execute arbitrary code, so only load
    files you produced yourself or otherwise trust.
    """
    has_pickle_suffix = filename.endswith('.pik')
    if has_pickle_suffix:
        with open(filename, mode='rb') as pickle_file:
            return pickle.load(pickle_file)
    raise TypeError('input File needs to be a Pickle object ".pik"!')
|
|
|
|
|
|
# https://github.com/dribnet/plat/blob/master/plat/interpolate.py#L15
# https://pdfs.semanticscholar.org/f46c/307d4d73e86412e0c57161fb52f7591e124b.pdf
def slerp(val, low, high):
    """Spherical interpolation. val has a range of 0 to 1.

    Walks along the great-circle arc between the directions of ``low`` and
    ``high`` (1-D vectors of equal length).

    Args:
        val: Interpolation position. ``val <= 0`` returns ``low`` and
            ``val >= 1`` returns ``high`` unchanged.
        low: Start vector.
        high: End vector.

    Returns:
        The interpolated vector. If ``low`` and ``high`` are (nearly) equal,
        ``low`` is returned. If they point in the same or opposite direction,
        the arc is undefined (sin(omega) == 0), and linear interpolation is
        returned instead of NaNs.
    """
    if val <= 0:
        return low
    elif val >= 1:
        return high
    elif np.allclose(low, high):
        return low
    # Cosine of the angle between the unit vectors. Clip it so floating-point
    # round-off (e.g. 1.0000000000000002) cannot push arccos out of its domain.
    cos_omega = np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    so = np.sin(omega)
    if np.isclose(so, 0.0):
        # Parallel or anti-parallel inputs: dividing by sin(omega) ~ 0 would
        # produce inf/NaN, so fall back to plain linear interpolation.
        return (1.0 - val) * low + val * high
    # Return the interpolant itself. It is deterministic, with no random
    # rescaling, so the path stays continuous with the endpoints returned above.
    return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point. Builds a convolutional-LSTM variational autoencoder
    # over track sequences and registers it with Trainer. It then runs whichever
    # of the steps at the bottom are switched on by their `if True:` /
    # `if False:` toggles.

    # Legacy (Keras 2.0-era) backend call selecting 'tf' dimension ordering,
    # i.e. channels-last, matching the trailing channel axis of the Input below.
    K.set_image_dim_ordering('tf')

    # Data PreProcessing
    # trackCollection is a mapping: it is indexed via .keys() further down.
    trackCollection = load_preprocesseed_data('test_track.pik') # Tate_10000

    '''HERE IS THE TRAINING!!!!!'''
    # NOTE(review): the positional 10 presumably becomes T.classes (the latent
    # size used below). Confirm against Trainer's signature.
    T = Trainer('vae', trackCollection, 10, categorical_distribution=0,
                batchSize=1, timesteps=5, filters=10) # BatchSize: 400 for Training / 1 for Testing

    # PreStage 1: Encoder Input
    # One sample = T.timesteps single-channel frames of T.width x T.height.
    enc_input = Input(shape=(T.timesteps, T.width, T.height, 1), name='Main_Input')

    # Stage 1: Encoding
    # The same conv/pool stack is applied to every frame (TimeDistributed).
    # An LSTM then folds the per-frame features into one vector per sequence
    # (return_sequences=False).
    enc_seq = Sequential(name='Encoder')
    enc_seq.add(TimeDistributed(Convolution2D(activation='relu', filters=T.filters, kernel_size=(3, 3), strides=1),
                                name='Conv1', input_shape=(T.timesteps, T.width, T.height, 1)))
    enc_seq.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2), name='MaxPool1'))

    enc_seq.add(TimeDistributed(Convolution2D(activation='relu', filters=T.filters, kernel_size=(5, 5), strides=1),
                                name='Conv2'))
    enc_seq.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2), name='MaxPool2'))

    enc_seq.add(TimeDistributed(Flatten(), name='Flatten'))
    # The LSTM width equals the flattened per-frame feature size.
    enc_seq.add(LSTM(int(enc_seq.layers[-1].output_shape[-1]), return_sequences=False, name='LSTM_Encode'))

    encoding = enc_seq(enc_input)

    # Stage 2: Bottleneck
    # Two dense heads over the encoding: the latent mean and the latent
    # log-variance. The second head is named 'Std_Dev', but sampling() and the
    # KL term in vae_objective both use it as a log-variance.
    # NOTE(review): layer names are kept because saved weight files may depend
    # on them.
    out_z_mean = Dense(T.classes, name='Dense_Mean')(encoding)
    out_z_log_var = Dense(T.classes, name='Std_Dev')(encoding)

    def sampling(args, batch_size=500, classes=3, epsilon_std=0.01):
        """Draw z = z_mean + exp(z_log_var / 2) * epsilon (reparameterization).

        ``args`` is the ``[z_mean, z_log_var]`` pair supplied by the Lambda
        layer. ``epsilon`` is zero-mean normal noise with standard deviation
        ``epsilon_std``.
        NOTE(review): epsilon's shape is fixed to (batch_size, classes), so the
        graph only accepts batches of exactly ``batch_size`` samples.
        """
        z_mean, z_log_var = args
        epsilon = K.random_normal(shape=(batch_size, classes), mean=0., stddev=epsilon_std)
        return z_mean + K.exp(z_log_var / 2) * epsilon

    # The defaults of sampling() (500, 3) are overridden with Trainer's settings.
    z = Lambda(sampling, name='Sampling', arguments={'batch_size': T.batchSize,
                                                      'classes': T.classes,
                                                      'epsilon_std': 0.01})([out_z_mean, out_z_log_var])

    # Stage 3: Decoding
    # Mirrors the encoder. z is repeated once per timestep, then passed through
    # an LSTM as wide as the encoder's LSTM. Each step is reshaped back into a
    # feature map, then upsampled/deconvolved to a single-channel frame.
    dec_seq = Sequential(name='Decoder')

    dec_seq.add(RepeatVector(T.timesteps, name='TimeRepeater', input_shape=(T.classes,)))
    # enc_seq.layers[-1] is LSTM_Encode here, so both LSTMs have the same width.
    dec_seq.add(LSTM(enc_seq.layers[-1].output_shape[-1], return_sequences=True, name='LSTM_Decode'))

    # Side length used to reshape each LSTM step into a (side, side, T.filters) map.
    # NOTE(review): this assumes the encoder's last feature map was square. If
    # width // T.filters is not a perfect square, int(sqrt(...)) truncates and
    # the Reshape below will not match the LSTM width.
    reValue = int(sqrt(dec_seq.layers[-1].output_shape[-1]//T.filters))
    dec_seq.add(TimeDistributed(Reshape((reValue, reValue, T.filters)), name='ReShape'))

    dec_seq.add(TimeDistributed(UpSampling2D(2), name='Up1'))
    dec_seq.add(TimeDistributed(Deconvolution2D(activation='relu', filters=T.filters//2, kernel_size=(4, 4), strides=1),
                                name='DeConv1'))
    dec_seq.add(TimeDistributed(UpSampling2D(2), name='Up2'))
    # NOTE(review): the output frame size comes from these kernel/upsampling
    # choices, not from T.width/T.height, and vae_objective's reshape relies on
    # the two agreeing. The final activation is relu, while the loss is binary
    # cross-entropy (targets in [0, 1]); a sigmoid is the usual choice. Confirm
    # both points for the data in use.
    dec_seq.add(TimeDistributed(Deconvolution2D(activation='relu', filters=1, kernel_size=(5, 5), strides=1),
                                name='DeConv2'))

    dec_output = dec_seq(z)

    # Loss function minimized by autoencoder
    def vae_objective(true, pred):
        """Return reconstruction cross-entropy plus the KL regularizer.

        Both tensors are flattened into rows of T.original_dim values before the
        binary cross-entropy is taken. The KL term is the closed-form divergence
        of N(out_z_mean, exp(out_z_log_var)) from a standard normal. It is read
        from the enclosing graph tensors, not from the function arguments.
        NOTE(review): Keras' binary_crossentropy averages over the last axis, so
        reconstruction is not scaled by T.original_dim as in the Keras VAE
        example. Confirm the intended weighting against the KL term.
        """
        true = K.reshape(true, (-1, T.original_dim)) # !
        pred = K.reshape(pred, (-1, T.original_dim)) # !
        loss = binary_crossentropy(true, pred)
        kl_regu = -.5 * K.sum(1. + out_z_log_var - K.square(
            out_z_mean) - K.exp(out_z_log_var), axis=-1)
        return loss + kl_regu

    # Model
    # End-to-end VAE (input -> sampled z -> reconstruction) trained with vae_objective.
    T.set_model(Model(inputs=enc_input, outputs=dec_output), vae_objective)

    # Separate encoder from input to latent space
    # Maps inputs to the latent mean only, without sampling noise.
    encoder = K.function([enc_input], [out_z_mean])
    T.set_encoder(encoder)

    # Generator from latent to input space
    # Calls the same dec_seq instance on a fresh latent Input, so the generator
    # shares the decoder weights of the model above.
    decoder_input = Input(shape=(T.classes,))
    decoder_output = dec_seq(decoder_input)
    T.set_generator(Model(inputs=decoder_input, outputs=decoder_output))

    # Run toggles: flip True/False to select which steps execute.
    if False:
        T.train('VAEweights')
        T.save_weights('VAEweights')
    if True:
        T.load_weights('VAEweights')
    if False:
        T.plot_model('vae.png', show_shapes=True, show_layer_names=True)
    if False:
        T.color_track(trackCollection[list(trackCollection.keys())[2200]],
                      nClusters=10, cMode='kmeans', aMode='None') # 2600
        # T.color_random_track(completeSequence=False, nClusters=10, cMode='kmeans', aMode='None')
    if False:
        T.show_prediction(200)
    if False:
        T.sample_latent(200)
    if False:
        T.multi_path_coloring(10, fileName='', state='')
    if True:
        T.show_silhouette_score([120])
|