Visuals
This commit is contained in:
@ -3,7 +3,9 @@ import copy
|
||||
import numpy as np
|
||||
|
||||
from keras.models import Sequential
|
||||
from keras.callbacks import Callback
|
||||
from keras.layers import SimpleRNN, Dense
|
||||
import keras.backend as K
|
||||
|
||||
from util import *
|
||||
from experiment import *
|
||||
@ -12,6 +14,20 @@ from experiment import *
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||||
|
||||
|
||||
class SaveStateCallback(Callback):
|
||||
def __init__(self, net, epoch=0):
|
||||
super(SaveStateCallback, self).__init__()
|
||||
self.net = net
|
||||
self.init_epoch = epoch
|
||||
|
||||
def on_epoch_end(self, epoch, logs={}):
|
||||
description = dict(time=epoch+self.init_epoch)
|
||||
description['action'] = 'train_self'
|
||||
description['counterpart'] = None
|
||||
self.net.save_state(**description)
|
||||
return
|
||||
|
||||
|
||||
class NeuralNetwork(PrintingObject):
|
||||
|
||||
@staticmethod
|
||||
@ -64,6 +80,7 @@ class NeuralNetwork(PrintingObject):
|
||||
self.params = dict(epsilon=0.00000000000001)
|
||||
self.params.update(params)
|
||||
self.keras_params = dict(activation='linear', use_bias=False)
|
||||
self.states = []
|
||||
|
||||
def get_model(self):
|
||||
raise NotImplementedError
|
||||
@ -147,6 +164,23 @@ class NeuralNetwork(PrintingObject):
|
||||
def print_weights(self, weights=None):
|
||||
print(self.repr_weights(weights))
|
||||
|
||||
def make_state(self, **kwargs):
|
||||
weights = self.get_weights_flat()
|
||||
state = {'class': self.__class__.__name__, 'weights': weights}
|
||||
if any(np.isinf(weights)):
|
||||
return None
|
||||
state.update(kwargs)
|
||||
return state
|
||||
|
||||
def save_state(self, **kwargs):
|
||||
state = self.make_state(**kwargs)
|
||||
if state is not None:
|
||||
self.states += [state]
|
||||
else:
|
||||
pass
|
||||
def get_states(self):
|
||||
return self.states
|
||||
|
||||
|
||||
class WeightwiseNeuralNetwork(NeuralNetwork):
|
||||
|
||||
@ -600,10 +634,11 @@ class TrainingNeuralNetworkDecorator():
|
||||
self.model_compiled = True
|
||||
return self
|
||||
|
||||
def train(self, batchsize=1):
|
||||
def train(self, batchsize=1, store_states=True, epoch=0):
|
||||
self.compiled()
|
||||
x, y = self.net.compute_samples()
|
||||
history = self.net.model.fit(x=x, y=y, verbose=0, batch_size=batchsize)
|
||||
savestatecallback = SaveStateCallback(net=self.net, epoch=epoch) if store_states else None
|
||||
history = self.net.model.fit(x=x, y=y, verbose=0, batch_size=batchsize, callbacks=[savestatecallback])
|
||||
return history.history['loss'][-1]
|
||||
|
||||
def train_other(self, other_network, batchsize=1):
|
||||
@ -611,6 +646,7 @@ class TrainingNeuralNetworkDecorator():
|
||||
other_network.compiled()
|
||||
x, y = other_network.net.compute_samples()
|
||||
history = self.net.model.fit(x=x, y=y, verbose=0, batch_size=batchsize)
|
||||
|
||||
return history.history['loss'][-1]
|
||||
|
||||
|
||||
@ -648,17 +684,21 @@ if __name__ == '__main__':
|
||||
if True:
|
||||
# ok so this works quite realiably
|
||||
with FixpointExperiment() as exp:
|
||||
run_count = 1000
|
||||
net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2))\
|
||||
.with_params(epsilon=0.0001).with_keras_params(optimizer='sgd')
|
||||
for run_id in tqdm(range(run_count+1)):
|
||||
loss = net.compiled().train()
|
||||
if run_id % 100 == 0:
|
||||
net.print_weights()
|
||||
# print(net.apply_to_network(net))
|
||||
print("Fixpoint? " + str(net.is_fixpoint()))
|
||||
print("Loss " + str(loss))
|
||||
print()
|
||||
for i in range(10):
|
||||
|
||||
run_count = 1000
|
||||
net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2))\
|
||||
.with_params(epsilon=0.0001).with_keras_params(optimizer='sgd')
|
||||
for run_id in tqdm(range(run_count+1)):
|
||||
loss = net.compiled().train(epoch=run_id)
|
||||
if run_id % 100 == 0:
|
||||
net.print_weights()
|
||||
# print(net.apply_to_network(net))
|
||||
print("Fixpoint? " + str(net.is_fixpoint()))
|
||||
print("Loss " + str(loss))
|
||||
print()
|
||||
exp.historical_particles[i] = net
|
||||
K.clear_session()
|
||||
if False:
|
||||
# this does not work as the aggregation function screws over the fixpoint computation....
|
||||
# TODO: check for fixpoint in aggregated space...
|
||||
|
Reference in New Issue
Block a user