This commit is contained in:
Si11ium
2019-03-07 12:05:58 +01:00
parent bae997feab
commit 95c2ff4200
4 changed files with 128 additions and 48 deletions


@ -3,7 +3,9 @@ import copy
import numpy as np
from keras.models import Sequential
from keras.callbacks import Callback
from keras.layers import SimpleRNN, Dense
import keras.backend as K
from util import *
from experiment import *
@ -12,6 +14,20 @@ from experiment import *
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
class SaveStateCallback(Callback):
    """Keras callback that snapshots the attached network's state after every epoch."""

    def __init__(self, net, epoch=0):
        super(SaveStateCallback, self).__init__()
        self.net = net
        self.init_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        # Offset by the epoch the callback was created at, so repeated
        # single-epoch fit() calls yield a continuous time axis.
        description = dict(time=epoch + self.init_epoch)
        description['action'] = 'train_self'
        description['counterpart'] = None
        self.net.save_state(**description)
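# A minimal usage sketch for the callback above, assuming `net` is any
# NeuralNetwork whose Keras model is already compiled and which provides
# compute_samples() as elsewhere in this file (epochs value is illustrative):
#
#   x, y = net.compute_samples()
#   net.model.fit(x=x, y=y, verbose=0, epochs=10,
#                 callbacks=[SaveStateCallback(net=net, epoch=0)])
#   # net.get_states() now holds one state dict per finished epoch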
class NeuralNetwork(PrintingObject):
    @staticmethod
@ -64,6 +80,7 @@ class NeuralNetwork(PrintingObject):
        self.params = dict(epsilon=1e-14)
        self.params.update(params)
        self.keras_params = dict(activation='linear', use_bias=False)
        self.states = []

    def get_model(self):
        raise NotImplementedError
@ -147,6 +164,23 @@ class NeuralNetwork(PrintingObject):
    def print_weights(self, weights=None):
        print(self.repr_weights(weights))
    def make_state(self, **kwargs):
        weights = self.get_weights_flat()
        # Refuse to record a state once any weight has diverged to infinity.
        if np.any(np.isinf(weights)):
            return None
        state = {'class': self.__class__.__name__, 'weights': weights}
        state.update(kwargs)
        return state
    def save_state(self, **kwargs):
        # make_state() returns None for diverged weights; such states are skipped.
        state = self.make_state(**kwargs)
        if state is not None:
            self.states.append(state)

    def get_states(self):
        return self.states
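    # The recorded states form a weight trajectory over training time; a small
    # sketch of how they could be consumed later (assumes at least one state
    # was recorded and that 'weights' is the flat numpy vector stored above):
    #
    #   states = net.get_states()
    #   trajectory = np.stack([s['weights'] for s in states])
    #   times = [s['time'] for s in states]
    #   # trajectory[t] is the flattened weight vector after epoch times[t]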
class WeightwiseNeuralNetwork(NeuralNetwork):
@ -600,10 +634,11 @@ class TrainingNeuralNetworkDecorator():
            self.model_compiled = True
        return self
    def train(self, batchsize=1):
    def train(self, batchsize=1, store_states=True, epoch=0):
        self.compiled()
        x, y = self.net.compute_samples()
        history = self.net.model.fit(x=x, y=y, verbose=0, batch_size=batchsize)
        # Only attach the callback when state recording is requested; a list
        # containing None would break Keras' callback handling.
        savestatecallback = [SaveStateCallback(net=self.net, epoch=epoch)] if store_states else []
        history = self.net.model.fit(x=x, y=y, verbose=0, batch_size=batchsize, callbacks=savestatecallback)
        return history.history['loss'][-1]
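    # Example calls under the new signature (hypothetical values): recording
    # can be disabled when only the loss matters, e.g. inside tight search loops.
    #
    #   loss = net.compiled().train(batchsize=1, store_states=False)
    #   loss = net.compiled().train(epoch=run_id)  # records one state per epoch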
    def train_other(self, other_network, batchsize=1):
@ -611,6 +646,7 @@ class TrainingNeuralNetworkDecorator():
        other_network.compiled()
        x, y = other_network.net.compute_samples()
        history = self.net.model.fit(x=x, y=y, verbose=0, batch_size=batchsize)

        return history.history['loss'][-1]
@ -648,17 +684,21 @@ if __name__ == '__main__':
    if True:
        # ok so this works quite reliably
        with FixpointExperiment() as exp:
            run_count = 1000
            net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2))\
                .with_params(epsilon=0.0001).with_keras_params(optimizer='sgd')
            for run_id in tqdm(range(run_count+1)):
                loss = net.compiled().train()
                if run_id % 100 == 0:
                    net.print_weights()
                    # print(net.apply_to_network(net))
                    print("Fixpoint? " + str(net.is_fixpoint()))
                    print("Loss " + str(loss))
                    print()
            for i in range(10):
                run_count = 1000
                net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2))\
                    .with_params(epsilon=0.0001).with_keras_params(optimizer='sgd')
                for run_id in tqdm(range(run_count+1)):
                    loss = net.compiled().train(epoch=run_id)
                    if run_id % 100 == 0:
                        net.print_weights()
                        # print(net.apply_to_network(net))
                        print("Fixpoint? " + str(net.is_fixpoint()))
                        print("Loss " + str(loss))
                        print()
                exp.historical_particles[i] = net
                # Free the finished model's TF graph so ten runs don't accumulate memory.
                K.clear_session()
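            # Sketch of inspecting what the loop above collected, assuming
            # historical_particles is a plain dict and each net recorded its
            # states through SaveStateCallback during train():
            #
            #   for idx, particle in exp.historical_particles.items():
            #       states = particle.get_states()
            #       print(idx, len(states), states[-1]['time'] if states else None)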
    if False:
        # this does not work as the aggregation function screws over the fixpoint computation...
        # TODO: check for fixpoint in aggregated space...
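        # One conceivable shape for that TODO, with entirely hypothetical
        # accessors (get_aggregates() and self_application() do not exist in
        # this file): declare a fixpoint when one self-application moves the
        # aggregated weight vector by less than epsilon.
        #
        #   before = np.asarray(net.get_aggregates())
        #   after = np.asarray(net.self_application().get_aggregates())
        #   is_aggregated_fixpoint = np.allclose(before, after, atol=net.params['epsilon'])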