118 lines
4.4 KiB
Python
118 lines
4.4 KiB
Python
from keras.layers import TimeDistributed, Dense, LSTM, UpSampling2D, RepeatVector, \
|
|
MaxPooling2D, Convolution2D, Deconvolution2D, Flatten, Reshape, Input
|
|
from keras.models import Sequential, Model
|
|
|
|
from keras import backend as K
|
|
|
|
import numpy as np
|
|
import pickle
|
|
from math import sqrt
|
|
|
|
from Trainer import Trainer
|
|
|
|
|
|
def get_batch(X, size):
    """Draw a random batch from ``X`` without repeating any entry.

    Args:
        X: Container supporting ``len()`` and indexing with an integer
            array (e.g. a numpy array).
        size: Number of entries to draw; numpy raises ``ValueError`` when
            it exceeds ``len(X)`` because sampling is without replacement.

    Returns:
        The entries of ``X`` at ``size`` distinct, randomly chosen positions.
    """
    return X[np.random.choice(len(X), size, replace=False)]
|
|
|
|
|
|
def load_preprocesseed_data(filename):
    """Load a pickled object from a ``.pik`` file.

    Args:
        filename: Path of the pickle file as ``str`` or any ``os.PathLike``
            object (e.g. ``pathlib.Path``). The name must end in ``.pik``.

    Returns:
        The object stored in the pickle file.

    Raises:
        TypeError: If the file name does not end with ``.pik``.
    """
    import os  # function-scope import: this module does not import os at file level

    # Normalise PathLike objects so the suffix check also works for pathlib.Path.
    path = os.fspath(filename)
    if not path.endswith('.pik'):
        raise TypeError('input File needs to be a Pickle object ".pik"!')

    # SECURITY NOTE: unpickling executes arbitrary code embedded in the file;
    # only load files that come from a trusted source.
    with open(path, 'rb') as f:
        return pickle.load(f)
|
|
|
|
|
|
# https://github.com/dribnet/plat/blob/master/plat/interpolate.py#L15
|
|
# https://pdfs.semanticscholar.org/f46c/307d4d73e86412e0c57161fb52f7591e124b.pdf
|
|
def slerp(val, low, high):
    """Spherical interpolation with random scaling. val has a range of 0 to 1.

    Args:
        val: Interpolation position. Values <= 0 return ``low`` and values
            >= 1 return ``high``.
        low: Start vector (numpy array; the dot product below is written
            for 1-D vectors).
        high: End vector with the same shape as ``low``.

    Returns:
        ``low`` or ``high`` unchanged at the end points, ``low`` when both
        vectors are numerically equal, otherwise the interpolated vector
        multiplied element-wise by uniform random factors from [0, 1).

    NOTE(review): the random scaling of the result is not part of a plain
    slerp; it is kept because callers may rely on it -- confirm it is wanted.
    """
    if val <= 0:
        return low
    elif val >= 1:
        return high
    elif np.allclose(low, high):
        return low

    # Plain linear blend; used whenever the spherical formula is undefined.
    linear = (1.0 - val) * low + val * high

    low_norm = np.linalg.norm(low)
    high_norm = np.linalg.norm(high)
    if low_norm == 0 or high_norm == 0:
        # A zero vector has no direction, so no angle can be computed.
        re = linear
    else:
        # noinspection PyTypeChecker
        cos_omega = round(np.dot(low / low_norm, high / high_norm), 15)
        # Rounding alone cannot guarantee |cos| <= 1; clip so arccos never yields NaN.
        omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
        so = np.sin(omega)
        if np.isclose(so, 0.0):
            # Parallel or anti-parallel vectors: dividing by sin(omega) would give NaN/inf.
            re = linear
        else:
            re = np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high
    return re * np.random.rand(low.shape[-1])
|
|
|
|
|
|
if __name__ == '__main__':
    '''HERE IS THE TRAINING!!!!!'''
    # Script entry point: builds a convolutional LSTM autoencoder with a softmax
    # bottleneck, registers it with the project-local Trainer and runs whichever
    # of the manually toggled steps at the bottom are enabled.
    trackCollection = load_preprocesseed_data('test_track.pik')
    # 'tf' ordering; the inputs below carry a trailing channel axis of size 1.
    # NOTE(review): this looks like the Keras 1 backend call; newer releases use
    # K.set_image_data_format('channels_last') -- confirm the installed version.
    K.set_image_dim_ordering('tf')
    # Trainer is project-local; presumably it exposes timesteps/width/height/
    # filters/classes derived from these arguments -- verify against Trainer.
    T = Trainer('refined', trackCollection, 10, categorical_distribution=0, batchSize=400, filters=10)

    # PreStage 1: Encoder Input
    enc_input = Input(shape=(T.timesteps, T.width, T.height, 1), name='Main_Input')

    # Stage 1: Encoder. Every convolution/pooling layer is wrapped in
    # TimeDistributed so it is applied to each timestep separately.
    enc = Sequential()
    # 3x3 convolution with T.filters filters, stride 1, ReLU.
    enc.add(TimeDistributed(Convolution2D(activation='relu', filters=T.filters, kernel_size=(3, 3), strides=1,
                                          name='Conv1'), input_shape=(T.timesteps, T.width, T.height, 1)))
    enc.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2, name='MaxPool1')))

    # 5x5 convolution with T.filters filters, stride 1, ReLU.
    enc.add(TimeDistributed(Convolution2D(activation='relu', filters=T.filters, kernel_size=(5, 5), strides=1,
                                          name='Conv2')))
    enc.add(TimeDistributed(MaxPooling2D(pool_size=2, strides=2, name='MaxPool2')))

    enc.add(TimeDistributed(Flatten(name='Flatten')))
    # LSTM width equals the flattened feature size of the layer added just above;
    # return_sequences=False keeps only the final output of the sequence.
    enc.add(LSTM(int(enc.layers[-1].output_shape[-1]), return_sequences=False, name='LSTM_Encode'))

    encoding = enc(enc_input)

    # Stage 2: Bottleneck
    # Softmax layer with T.classes units on top of the encoder output.
    z = Dense(T.classes, activation='softmax', name='Clustering')(encoding)

    # Stage 3: Decoder. Takes a vector of length T.classes as input.
    dec = Sequential()
    # Repeat the bottleneck vector T.timesteps times to feed the decoder LSTM.
    dec.add(RepeatVector(T.timesteps, name='TimeRepeater', input_shape=(T.classes,)))
    # enc.layers[-2] is the TimeDistributed Flatten layer (the encoder LSTM is
    # layers[-1]), so the decoder LSTM has the same width as the encoder LSTM.
    dec.add(LSTM(enc.layers[-2].output_shape[-1], return_sequences=True, name='LSTM_Decode'))

    # Side length used to reshape each LSTM output into a feature map with
    # T.filters channels; assumes that map is square -- TODO confirm.
    reValue = int(sqrt(dec.layers[-1].output_shape[-1]//T.filters))
    dec.add(TimeDistributed(Reshape((reValue, reValue, T.filters))))

    dec.add(TimeDistributed(UpSampling2D(2, name='Up1')))

    # 4x4 transposed convolution with half as many filters as the encoder, ReLU.
    dec.add(TimeDistributed(Deconvolution2D(activation='relu', filters=T.filters//2, kernel_size=(4, 4), strides=1,
                                            name='DeConv1')))
    dec.add(TimeDistributed(UpSampling2D(2, name='Up2')))

    # 5x5 transposed convolution producing a single output channel, ReLU.
    dec.add(TimeDistributed(Deconvolution2D(activation='relu', filters=1, kernel_size=(5, 5), strides=1,
                                            name='DeConv2')))

    dec_output = dec(z)

    # Full autoencoder: main input -> encoder -> bottleneck -> decoder.
    T.set_model(Model(inputs=enc_input, outputs=dec_output), optimizer='adagrad', loss='binary_crossentropy')

    # Stand-alone model around the same decoder instance, fed directly with a
    # vector of length T.classes.
    decoder_input = Input(shape=(T.classes,))
    decoded = dec(decoder_input)
    T.set_generator(Model(inputs=decoder_input, outputs=decoded))

    # Backend function mapping the main input to the bottleneck output z.
    encoder = K.Function([enc_input], [z])
    T.set_encoder(encoder)

    # Manual switches: flip a literal to True/False to enable or disable a step.
    # As written, only load_weights and show_silhouette_score are executed.
    if False:
        T.train('refinedWeights')
        T.save_weights('refinedWeights')
    if True:
        T.load_weights('refinedWeights')
    if False:
        T.plot_model('refined.png', show_shapes=True, show_layer_names=True)
    if False:
        # T.color_track(trackCollection[list(trackCollection.keys())[2200]]) # 2600
        T.color_track(trackCollection[list(trackCollection.keys())[2200]], nClusters=4) # 2600
        # T.color_random_track(completeSequence=True)
    if False:
        T.show_prediction(200)
    if False:
        T.sample_latent(200)
    if False:
        T.multi_path_coloring(10)
    if True:
        T.show_silhouette_score([120])
|