tried to implement a method for additional insight

This commit is contained in:
Thomas Gabor
2019-03-05 05:52:12 +01:00
parent 751c2480fa
commit 18c84d1483

View File

@ -346,6 +346,30 @@ class AggregatingNeuralNetwork(NeuralNetwork):
aggregations, _ = self.get_aggregated_weights()
sample = np.transpose(np.array([[aggregations[i]] for i in range(self.aggregates)]))
return [sample], [sample]
def is_fixpoint_after_aggregation(self, degree=1, epsilon=None):
assert degree >= 1, "degree must be >= 1"
epsilon = epsilon or self.get_params().get('epsilon')
old_weights = self.get_weights()
old_aggregations, _ = self.get_aggregated_weights()
new_weights = copy.deepcopy(old_weights)
for _ in range(degree):
new_weights = self.apply_to_weights(new_weights)
if NeuralNetwork.are_weights_diverged(new_weights):
return False
collection_size = self.get_amount_of_weights() // self.aggregates
collections, leftovers = self.__class__.collect_weights(new_weights, collection_size)
new_aggregations = [self.get_aggregator()(collection) for collection in collections]
for aggregation_id,old_aggregation in enumerate(old_aggregations):
new_aggregation = new_aggregations[aggregation_id]
if abs(new_aggregation - old_aggregation) >= epsilon:
return False, new_aggregations
return True, new_aggregations
class RecurrentNeuralNetwork(NeuralNetwork):
@ -547,7 +571,7 @@ if __name__ == '__main__':
print("Fixpoint? " + str(net.is_fixpoint()))
print("Loss " + str(loss))
print()
if False: # this does not work as the aggregation function screws over the fixpoint computation.... TODO: check for fixpoint in aggregated space...
if True: # this does not work as the aggregation function screws over the fixpoint computation.... TODO: check for fixpoint in aggregated space...
with FixpointExperiment() as exp:
run_count = 1000
net = TrainingNeuralNetworkDecorator(AggregatingNeuralNetwork(4, width=2, depth=2)).with_params(epsilon=0.1e-6)
@ -555,8 +579,12 @@ if __name__ == '__main__':
loss = net.compiled().train()
if run_id % 100 == 0:
net.print_weights()
# print(net.apply_to_network(net))
print("Fixpoint? " + str(net.is_fixpoint(epsilon=0.0001)))
old_aggs, _ = net.net.get_aggregated_weights()
print("old weights agg: " + str(old_aggs))
fp, new_aggs = net.net.is_fixpoint_after_aggregation(epsilon=0.0001)
print("new weights agg: " + str(new_aggs))
print("Fixpoint? " + str(net.is_fixpoint()))
print("Fixpoint after Agg? " + str(fp))
print("Loss " + str(loss))
print()
if False: # this explodes in our faces completely... NAN everywhere TODO: Wtf is happening here?
@ -571,7 +599,7 @@ if __name__ == '__main__':
print("Fixpoint? " + str(net.is_fixpoint(epsilon=0.0001)))
print("Loss " + str(loss))
print()
if True: # and this gets somewhat interesting... we can still achieve non-trivial fixpoints over multiple applications when training enough in-between
if False: # and this gets somewhat interesting... we can still achieve non-trivial fixpoints over multiple applications when training enough in-between
with MixedFixpointExperiment() as exp:
for run_id in range(1):
net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2)).with_params(epsilon=0.0001)