from keras.layers import TimeDistributed, Dense, LSTM, UpSampling2D, RepeatVector, \
    MaxPooling2D, Convolution2D, Deconvolution2D, Flatten, Reshape, Input
from keras.models import Sequential, Model
from keras import backend as K
import numpy as np
import pickle
from math import sqrt
from Trainer import Trainer

def get_batch(X, size):
    """Return a random batch of `size` samples drawn from X without replacement."""
    a = np.random.choice(len(X), size, replace=False)
    return X[a]

def load_preprocessed_data(filename):
    """Load the preprocessed track collection from a pickle file."""
    if not filename.endswith('.pik'):
        raise TypeError('input file needs to be a Pickle object ".pik"!')
    with open(filename, 'rb') as f:
        data = pickle.load(f)
    return data

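# Sketch (not part of the original script): a write counterpart to the loader
# above, assuming the preprocessed data is simply pickled as-is.
def save_preprocessed_data(data, filename):
    """Write `data` to a ".pik" file compatible with load_preprocessed_data."""
    if not filename.endswith('.pik'):
        raise TypeError('output file needs to be a Pickle object ".pik"!')
    with open(filename, 'wb') as f:
        pickle.dump(data, f)
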
# https://github.com/dribnet/plat/blob/master/plat/interpolate.py#L15
# https://pdfs.semanticscholar.org/f46c/307d4d73e86412e0c57161fb52f7591e124b.pdf
def slerp(val, low, high):
    """Spherical interpolation. val has a range of 0 to 1."""
    if val <= 0:
        return low
    elif val >= 1:
        return high
    elif np.allclose(low, high):
        return low
    # noinspection PyTypeChecker
    omega = np.arccos(round(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)), 15))
    so = np.sin(omega)
    re = np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high
    # Note: unlike the referenced implementation, the interpolated vector is
    # additionally scaled element-wise by uniform random noise here.
    return re * np.random.rand(low.shape[-1])

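# Sketch (not in the original script): decoding a slerp path between two latent
# class vectors. `decode_slerp_path` and its `generator` argument are illustrative
# names; `generator` stands for the decoder Model built below via T.set_generator(...).
def decode_slerp_path(generator, z_start, z_end, steps=10):
    """Decode `steps` latent vectors interpolated between z_start and z_end."""
    vals = np.linspace(0.0, 1.0, steps)
    zs = np.stack([slerp(v, z_start, z_end) for v in vals])
    return generator.predict(zs)
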
if __name__ == '__main__':
    # Training and evaluation entry point.
    trackCollection = load_preprocessed_data('test_track.pik')
    K.set_image_dim_ordering('tf')
    T = Trainer('refined', trackCollection, 10, categorical_distribution=0, batchSize=400, filters=10)
    # PreStage 1: Encoder Input
    enc_input = Input(shape=(T.timesteps, T.width, T.height, 1), name='Main_Input')
    enc = Sequential()
    enc.add(TimeDistributed(Convolution2D(activation='relu', filters=T.filters, kernel_size=(3, 3), strides=1,
                                          name='Conv1'), input_shape=(T.timesteps, T.width, T.height, 1)))
    enc.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2, name='MaxPool1')))
    enc.add(TimeDistributed(Convolution2D(activation='relu', filters=T.filters, kernel_size=(5, 5), strides=1,
                                          name='Conv2')))
    enc.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2, name='MaxPool2')))
    enc.add(TimeDistributed(Flatten(name='Flatten')))
    enc.add(LSTM(int(enc.layers[-1].output_shape[-1]), return_sequences=False, name='LSTM_Encode'))
    encoding = enc(enc_input)
    # Stage 2: Bottleneck
    z = Dense(T.classes, activation='softmax', name='Clustering')(encoding)
    # Stage 3: Decoder
    dec = Sequential()
    dec.add(RepeatVector(T.timesteps, name='TimeRepeater', input_shape=(T.classes,)))
    dec.add(LSTM(enc.layers[-2].output_shape[-1], return_sequences=True, name='LSTM_Decode'))
    reValue = int(sqrt(dec.layers[-1].output_shape[-1] // T.filters))
    dec.add(TimeDistributed(Reshape((reValue, reValue, T.filters))))
    dec.add(TimeDistributed(UpSampling2D(2, name='Up1')))
    dec.add(TimeDistributed(Deconvolution2D(activation='relu', filters=T.filters // 2, kernel_size=(4, 4), strides=1,
                                            name='DeConv1')))
    dec.add(TimeDistributed(UpSampling2D(2, name='Up2')))
    dec.add(TimeDistributed(Deconvolution2D(activation='relu', filters=1, kernel_size=(5, 5), strides=1,
                                            name='DeConv2')))
    dec_output = dec(z)
    T.set_model(Model(inputs=enc_input, outputs=dec_output), optimizer='adagrad', loss='binary_crossentropy')
    decoder_input = Input(shape=(T.classes,))
    decoded = dec(decoder_input)
    T.set_generator(Model(inputs=decoder_input, outputs=decoded))
    encoder = K.function([enc_input], [z])
    T.set_encoder(encoder)
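    # Sketch (not in the original script): querying soft cluster assignments for a
    # dummy batch through the backend encoder function; shapes follow the Main_Input
    # definition above. Guarded with `if False:` like the blocks below.
    if False:
        dummy_batch = np.random.rand(2, T.timesteps, T.width, T.height, 1)
        soft_assignments = encoder([dummy_batch])[0]  # shape: (2, T.classes)
        print(soft_assignments.argmax(axis=-1))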
    # Toggle the blocks below to select which steps to run.
    if False:
        T.train('refinedWeights')
        T.save_weights('refinedWeights')
    if True:
        T.load_weights('refinedWeights')
    if False:
        T.plot_model('refined.png', show_shapes=True, show_layer_names=True)
    if False:
        # T.color_track(trackCollection[list(trackCollection.keys())[2200]])  # 2600
        T.color_track(trackCollection[list(trackCollection.keys())[2200]], nClusters=4)  # 2600
        # T.color_random_track(completeSequence=True)
    if False:
        T.show_prediction(200)
    if False:
        T.sample_latent(200)
    if False:
        T.multi_path_coloring(10)
    if True:
        T.show_silhouette_score([120])