some refactoring

Thomas Gabor 2019-03-05 05:31:38 +01:00
parent 19e4ed65f9
commit 751c2480fa
2 changed files with 128 additions and 142 deletions


@@ -7,79 +7,13 @@ from tqdm import tqdm
 from keras.models import Sequential
 from keras.layers import SimpleRNN, Dense
+from util import *
 from experiment import *

 # Suppress warnings and info messages
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
-def normalize_id(value, norm):
-    if norm > 1:
-        return float(value) / float(norm)
-    else:
-        return float(value)
-
-def are_weights_diverged(network_weights):
-    for layer_id, layer in enumerate(network_weights):
-        for cell_id, cell in enumerate(layer):
-            for weight_id, weight in enumerate(cell):
-                if math.isnan(weight):
-                    return True
-                if math.isinf(weight):
-                    return True
-    return False
-
-def are_weights_within(network_weights, lower_bound, upper_bound):
-    for layer_id, layer in enumerate(network_weights):
-        for cell_id, cell in enumerate(layer):
-            for weight_id, weight in enumerate(cell):
-                if not (lower_bound <= weight <= upper_bound):
-                    return False
-    return True
-
-class PrintingObject():
-    class SilenceSignal():
-        def __init__(self, obj, value):
-            self.obj = obj
-            self.new_silent = value
-        def __enter__(self):
-            self.old_silent = self.obj.get_silence()
-            self.obj.set_silence(self.new_silent)
-        def __exit__(self, exception_type, exception_value, traceback):
-            self.obj.set_silence(self.old_silent)
-
-    def __init__(self):
-        self.silent = True
-    def is_silent(self):
-        return self.silent
-    def get_silence(self):
-        return self.is_silent()
-    def set_silence(self, value=True):
-        self.silent = value
-        return self
-    def unset_silence(self):
-        self.silent = False
-        return self
-    def with_silence(self, value=True):
-        self.set_silence(value)
-        return self
-    def silence(self, value=True):
-        return self.__class__.SilenceSignal(self, value)
-    def _print(self, *args, **kwargs):
-        if not self.silent:
-            print(*args, **kwargs)
 class NeuralNetwork(PrintingObject):

     @staticmethod
@@ -94,6 +28,38 @@ class NeuralNetwork(PrintingObject):
             s += "\n"
         return s
+    @staticmethod
+    def are_weights_diverged(network_weights):
+        for layer_id, layer in enumerate(network_weights):
+            for cell_id, cell in enumerate(layer):
+                for weight_id, weight in enumerate(cell):
+                    if math.isnan(weight):
+                        return True
+                    if math.isinf(weight):
+                        return True
+        return False
+
+    @staticmethod
+    def are_weights_within(network_weights, lower_bound, upper_bound):
+        for layer_id, layer in enumerate(network_weights):
+            for cell_id, cell in enumerate(layer):
+                for weight_id, weight in enumerate(cell):
+                    if not (lower_bound <= weight <= upper_bound):
+                        return False
+        return True
+
+    @staticmethod
+    def fill_weights(old_weights, new_weights_list):
+        new_weights = copy.deepcopy(old_weights)
+        current_weight_id = 0
+        for layer_id, layer in enumerate(new_weights):
+            for cell_id, cell in enumerate(layer):
+                for weight_id, weight in enumerate(cell):
+                    new_weight = new_weights_list[current_weight_id]
+                    new_weights[layer_id][cell_id][weight_id] = new_weight
+                    current_weight_id += 1
+        return new_weights
     def __init__(self, **params):
         super().__init__()
         self.model = Sequential()
@@ -125,8 +91,7 @@ class NeuralNetwork(PrintingObject):
         return self.get_model().set_weights(new_weights)

     def apply_to_weights(self, old_weights):
-        # placeholder, overwrite in subclass
-        return old_weights
+        raise NotImplementedError

     def apply_to_network(self, other_network):
         new_weights = self.apply_to_weights(other_network.get_weights())
@@ -150,11 +115,11 @@ class NeuralNetwork(PrintingObject):
         return new_me.self_attack(iterations)

     def is_diverged(self):
-        return are_weights_diverged(self.get_weights())
+        return NeuralNetwork.are_weights_diverged(self.get_weights())

     def is_zero(self, epsilon=None):
         epsilon = epsilon or self.params.get('epsilon')
-        return are_weights_within(self.get_weights(), -epsilon, epsilon)
+        return NeuralNetwork.are_weights_within(self.get_weights(), -epsilon, epsilon)

     def is_fixpoint(self, degree=1, epsilon=None):
         assert degree >= 1, "degree must be >= 1"
@@ -165,7 +130,7 @@ class NeuralNetwork(PrintingObject):
         for _ in range(degree):
             new_weights = self.apply_to_weights(new_weights)
-            if are_weights_diverged(new_weights):
+            if NeuralNetwork.are_weights_diverged(new_weights):
                 return False

         for layer_id, layer in enumerate(old_weights):
             for cell_id, cell in enumerate(layer):
@@ -184,6 +149,13 @@ class NeuralNetwork(PrintingObject):

 class WeightwiseNeuralNetwork(NeuralNetwork):

+    @staticmethod
+    def normalize_id(value, norm):
+        if norm > 1:
+            return float(value) / float(norm)
+        else:
+            return float(value)
+
     def __init__(self, width, depth, **kwargs):
         super().__init__(**kwargs)
         self.width = width
@@ -197,22 +169,40 @@ class WeightwiseNeuralNetwork(NeuralNetwork):
         stuff = np.transpose(np.array([[inputs[0]], [inputs[1]], [inputs[2]], [inputs[3]]]))
         return self.model.predict(stuff)[0][0]

-    def apply_to_weights(self, old_weights):
-        new_weights = copy.deepcopy(old_weights)
+    @classmethod
+    def compute_all_duplex_weight_points(cls, old_weights):
+        points = []
+        normal_points = []
         max_layer_id = len(old_weights) - 1
         for layer_id, layer in enumerate(old_weights):
             max_cell_id = len(layer) - 1
             for cell_id, cell in enumerate(layer):
                 max_weight_id = len(cell) - 1
                 for weight_id, weight in enumerate(cell):
-                    normal_layer_id = normalize_id(layer_id, max_layer_id)
-                    normal_cell_id = normalize_id(cell_id, max_cell_id)
-                    normal_weight_id = normalize_id(weight_id, max_weight_id)
-                    new_weight = self.apply(weight, normal_layer_id, normal_cell_id, normal_weight_id)
+                    normal_layer_id = cls.normalize_id(layer_id, max_layer_id)
+                    normal_cell_id = cls.normalize_id(cell_id, max_cell_id)
+                    normal_weight_id = cls.normalize_id(weight_id, max_weight_id)
+                    points += [[weight, layer_id, cell_id, weight_id]]
+                    normal_points += [[weight, normal_layer_id, normal_cell_id, normal_weight_id]]
+        return points, normal_points
+
+    @classmethod
+    def compute_all_weight_points(cls, all_weights):
+        return cls.compute_all_duplex_weight_points(all_weights)[0]
+
+    @classmethod
+    def compute_all_normal_weight_points(cls, all_weights):
+        return cls.compute_all_duplex_weight_points(all_weights)[1]
+
+    def apply_to_weights(self, old_weights):
+        new_weights = copy.deepcopy(self.get_weights())
+        for (weight_point, normal_weight_point) in zip(*self.__class__.compute_all_duplex_weight_points(old_weights)):
+            weight, layer_id, cell_id, weight_id = weight_point
+            _, normal_layer_id, normal_cell_id, normal_weight_id = normal_weight_point
+            new_weight = self.apply(*normal_weight_point)
             new_weights[layer_id][cell_id][weight_id] = new_weight

         if self.params.get("print_all_weight_updates", False) and not self.is_silent():
@@ -224,16 +214,9 @@ class WeightwiseNeuralNetwork(NeuralNetwork):

     def compute_samples(self):
         samples = []
-        new_weights = copy.deepcopy(self.get_weights())
-        max_layer_id = len(self.get_weights()) - 1
-        for layer_id, layer in enumerate(self.get_weights()):
-            max_cell_id = len(layer) - 1
-            for cell_id, cell in enumerate(layer):
-                max_weight_id = len(cell) - 1
-                for weight_id, weight in enumerate(cell):
-                    normal_layer_id = normalize_id(layer_id, max_layer_id)
-                    normal_cell_id = normalize_id(cell_id, max_cell_id)
-                    normal_weight_id = normalize_id(weight_id, max_weight_id)
+        for normal_weight_point in self.__class__.compute_all_normal_weight_points(self.get_weights()):
+            weight, normal_layer_id, normal_cell_id, normal_weight_id = normal_weight_point
             sample = np.transpose(np.array([[weight], [normal_layer_id], [normal_cell_id], [normal_weight_id]]))
             samples += [sample[0]]
         samples_array = np.asarray(samples)
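For intuition on the refactor above: compute_all_duplex_weight_points walks the nested weights once and returns two parallel lists, raw index points [weight, layer_id, cell_id, weight_id] for writing results back, and normalized points for feeding the network, where normalize_id divides an index by its maximum (and leaves it raw when the maximum is 0 or 1). A standalone sketch of that traversal:

    def normalize_id(value, norm):
        # scale an index into [0, 1] by its maximum; degenerate dimensions stay raw
        return float(value) / float(norm) if norm > 1 else float(value)

    def compute_all_duplex_weight_points(old_weights):
        points, normal_points = [], []
        max_layer_id = len(old_weights) - 1
        for layer_id, layer in enumerate(old_weights):
            max_cell_id = len(layer) - 1
            for cell_id, cell in enumerate(layer):
                max_weight_id = len(cell) - 1
                for weight_id, weight in enumerate(cell):
                    points.append([weight, layer_id, cell_id, weight_id])
                    normal_points.append([weight,
                                          normalize_id(layer_id, max_layer_id),
                                          normalize_id(cell_id, max_cell_id),
                                          normalize_id(weight_id, max_weight_id)])
        return points, normal_points

    weights = [[[0.5, -0.5], [0.25, 0.75]], [[0.1, 0.9]], [[0.0, 1.0]]]
    pts, norm_pts = compute_all_duplex_weight_points(weights)
    print(pts[3])       # [0.75, 0, 1, 1]
    print(norm_pts[3])  # [0.75, 0.0, 1.0, 1.0]  (layer 0 of 2 -> 0.0)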
@@ -307,19 +290,7 @@ class AggregatingNeuralNetwork(NeuralNetwork):

     def apply_to_weights(self, old_weights):
         # build aggregations from old_weights
         collection_size = self.get_amount_of_weights() // self.aggregates
-        collections = []
-        next_collection = []
-        current_weight_id = 0
-        for layer_id, layer in enumerate(old_weights):
-            for cell_id, cell in enumerate(layer):
-                for weight_id, weight in enumerate(cell):
-                    next_collection += [weight]
-                    if (current_weight_id + 1) % collection_size == 0:
-                        collections += [next_collection]
-                        next_collection = []
-                    current_weight_id += 1
-        collections[-1] += next_collection
-        leftovers = len(next_collection)
+        collections, leftovers = self.__class__.collect_weights(old_weights, collection_size)

         # call network
         old_aggregations = [self.get_aggregator()(collection) for collection in collections]
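collect_weights itself is not defined in any hunk shown here; judging from the inline loop it replaces, it presumably flattens the weights and chunks them into collection_size groups, folding a short tail into the last group and reporting the tail length. A hypothetical reconstruction from the deleted logic:

    def collect_weights(all_weights, collection_size):
        # flatten layers -> cells -> weights, then chunk into fixed-size collections
        collections, next_collection = [], []
        current_weight_id = 0
        for layer in all_weights:
            for cell in layer:
                for weight in cell:
                    next_collection.append(weight)
                    if (current_weight_id + 1) % collection_size == 0:
                        collections.append(next_collection)
                        next_collection = []
                    current_weight_id += 1
        collections[-1] += next_collection   # fold the tail into the last full group
        leftovers = len(next_collection)
        return collections, leftovers

    weights = [[[1, 2, 3], [4, 5]], [[6, 7]]]
    print(collect_weights(weights, 3))   # ([[1, 2, 3], [4, 5, 6, 7]], 1)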
@@ -335,14 +306,7 @@ class AggregatingNeuralNetwork(NeuralNetwork):
         new_weights_list = self.get_shuffler()(new_weights_list)

         # write back new weights
-        new_weights = copy.deepcopy(old_weights)
-        current_weight_id = 0
-        for layer_id, layer in enumerate(new_weights):
-            for cell_id, cell in enumerate(layer):
-                for weight_id, weight in enumerate(cell):
-                    new_weight = new_weights_list[current_weight_id]
-                    new_weights[layer_id][cell_id][weight_id] = new_weight
-                    current_weight_id += 1
+        new_weights = self.__class__.fill_weights(old_weights, new_weights_list)

         # return results
         if self.params.get("print_all_weight_updates", False) and not self.is_silent():
@@ -383,25 +347,6 @@ class AggregatingNeuralNetwork(NeuralNetwork):
         sample = np.transpose(np.array([[aggregations[i]] for i in range(self.aggregates)]))
         return [sample], [sample]

-    def is_fixpoint(self, degree=1, epsilon=None):
-        assert degree >= 1, "degree must be >= 1"
-        epsilon = epsilon or self.get_params().get('epsilon')
-        old_weights = self.get_weights()
-        new_weights = copy.deepcopy(old_weights)
-        for _ in range(degree):
-            new_weights = self.apply_to_weights(new_weights)
-            if are_weights_diverged(new_weights):
-                return False
-        for layer_id, layer in enumerate(old_weights):
-            for cell_id, cell in enumerate(layer):
-                for weight_id, weight in enumerate(cell):
-                    new_weight = new_weights[layer_id][cell_id][weight_id]
-                    if abs(new_weight - weight) >= epsilon:
-                        return False
-        return True
-

 class RecurrentNeuralNetwork(NeuralNetwork):
@@ -503,6 +448,8 @@ class LearningNeuralNetwork(NeuralNetwork):
             bar.postfix[1]["value"] = history.history['loss'][-1]
             bar.update()
+
+
 class TrainingNeuralNetworkDecorator(NeuralNetwork):

     def __init__(self, net, **kwargs):
@@ -591,13 +538,13 @@ if __name__ == '__main__':
     if False:  # ok so this works quite reliably
         with FixpointExperiment() as exp:
             run_count = 1000
-            net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2)).with_params(epsilon=0.1e-6)
+            net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2)).with_params(epsilon=0.0001)
             for run_id in tqdm(range(run_count+1)):
                 loss = net.compiled().train()
                 if run_id % 100 == 0:
                     net.print_weights()
                     # print(net.apply_to_network(net))
-                    print("Fixpoint? " + str(net.is_fixpoint(epsilon=0.0001)))
+                    print("Fixpoint? " + str(net.is_fixpoint()))
                     print("Loss " + str(loss))
                     print()
     if False:  # this does not work as the aggregation function screws over the fixpoint computation.... TODO: check for fixpoint in aggregated space...
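The two changes in this last hunk move the fixpoint tolerance from the is_fixpoint() call into the network's params, so it is configured once per network. Sketched usage under that reading (same classes as above):

    net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2)).with_params(epsilon=0.0001)
    # is_fixpoint() now falls back to params['epsilon'] instead of taking a per-call value
    print("Fixpoint? " + str(net.is_fixpoint()))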

code/util.py · Normal file · 39 additions

@@ -0,0 +1,39 @@
+class PrintingObject:
+
+    class SilenceSignal():
+        def __init__(self, obj, value):
+            self.obj = obj
+            self.new_silent = value
+        def __enter__(self):
+            self.old_silent = self.obj.get_silence()
+            self.obj.set_silence(self.new_silent)
+        def __exit__(self, exception_type, exception_value, traceback):
+            self.obj.set_silence(self.old_silent)
+
+    def __init__(self):
+        self.silent = True
+
+    def is_silent(self):
+        return self.silent
+
+    def get_silence(self):
+        return self.is_silent()
+
+    def set_silence(self, value=True):
+        self.silent = value
+        return self
+
+    def unset_silence(self):
+        self.silent = False
+        return self
+
+    def with_silence(self, value=True):
+        self.set_silence(value)
+        return self
+
+    def silence(self, value=True):
+        return self.__class__.SilenceSignal(self, value)
+
+    def _print(self, *args, **kwargs):
+        if not self.silent:
+            print(*args, **kwargs)
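Since PrintingObject now lives in its own module, a quick usage sketch (assuming code/util.py is importable): _print only emits while the object is not silent, and silence() is a context manager that temporarily overrides the flag and restores the previous state on exit.

    from util import PrintingObject

    obj = PrintingObject()           # silent by default
    obj._print("hidden")             # suppressed
    obj.unset_silence()
    obj._print("shown")              # printed
    with obj.silence(True):
        obj._print("hidden again")   # suppressed inside the context
    obj._print("shown again")        # previous state restored on exit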