tried to implement a method for additional insight
This commit is contained in:
@@ -347,6 +347,30 @@ class AggregatingNeuralNetwork(NeuralNetwork):
|
|||||||
sample = np.transpose(np.array([[aggregations[i]] for i in range(self.aggregates)]))
|
sample = np.transpose(np.array([[aggregations[i]] for i in range(self.aggregates)]))
|
||||||
return [sample], [sample]
|
return [sample], [sample]
|
||||||
|
|
||||||
|
def is_fixpoint_after_aggregation(self, degree=1, epsilon=None):
|
||||||
|
assert degree >= 1, "degree must be >= 1"
|
||||||
|
epsilon = epsilon or self.get_params().get('epsilon')
|
||||||
|
|
||||||
|
old_weights = self.get_weights()
|
||||||
|
old_aggregations, _ = self.get_aggregated_weights()
|
||||||
|
|
||||||
|
new_weights = copy.deepcopy(old_weights)
|
||||||
|
for _ in range(degree):
|
||||||
|
new_weights = self.apply_to_weights(new_weights)
|
||||||
|
if NeuralNetwork.are_weights_diverged(new_weights):
|
||||||
|
return False
|
||||||
|
collection_size = self.get_amount_of_weights() // self.aggregates
|
||||||
|
collections, leftovers = self.__class__.collect_weights(new_weights, collection_size)
|
||||||
|
new_aggregations = [self.get_aggregator()(collection) for collection in collections]
|
||||||
|
|
||||||
|
for aggregation_id,old_aggregation in enumerate(old_aggregations):
|
||||||
|
new_aggregation = new_aggregations[aggregation_id]
|
||||||
|
if abs(new_aggregation - old_aggregation) >= epsilon:
|
||||||
|
return False, new_aggregations
|
||||||
|
return True, new_aggregations
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RecurrentNeuralNetwork(NeuralNetwork):
|
class RecurrentNeuralNetwork(NeuralNetwork):
|
||||||
|
|
||||||
@@ -547,7 +571,7 @@ if __name__ == '__main__':
|
|||||||
print("Fixpoint? " + str(net.is_fixpoint()))
|
print("Fixpoint? " + str(net.is_fixpoint()))
|
||||||
print("Loss " + str(loss))
|
print("Loss " + str(loss))
|
||||||
print()
|
print()
|
||||||
if False: # this does not work as the aggregation function screws over the fixpoint computation.... TODO: check for fixpoint in aggregated space...
|
if True: # this does not work as the aggregation function screws over the fixpoint computation.... TODO: check for fixpoint in aggregated space...
|
||||||
with FixpointExperiment() as exp:
|
with FixpointExperiment() as exp:
|
||||||
run_count = 1000
|
run_count = 1000
|
||||||
net = TrainingNeuralNetworkDecorator(AggregatingNeuralNetwork(4, width=2, depth=2)).with_params(epsilon=0.1e-6)
|
net = TrainingNeuralNetworkDecorator(AggregatingNeuralNetwork(4, width=2, depth=2)).with_params(epsilon=0.1e-6)
|
||||||
@@ -555,8 +579,12 @@ if __name__ == '__main__':
|
|||||||
loss = net.compiled().train()
|
loss = net.compiled().train()
|
||||||
if run_id % 100 == 0:
|
if run_id % 100 == 0:
|
||||||
net.print_weights()
|
net.print_weights()
|
||||||
# print(net.apply_to_network(net))
|
old_aggs, _ = net.net.get_aggregated_weights()
|
||||||
print("Fixpoint? " + str(net.is_fixpoint(epsilon=0.0001)))
|
print("old weights agg: " + str(old_aggs))
|
||||||
|
fp, new_aggs = net.net.is_fixpoint_after_aggregation(epsilon=0.0001)
|
||||||
|
print("new weights agg: " + str(new_aggs))
|
||||||
|
print("Fixpoint? " + str(net.is_fixpoint()))
|
||||||
|
print("Fixpoint after Agg? " + str(fp))
|
||||||
print("Loss " + str(loss))
|
print("Loss " + str(loss))
|
||||||
print()
|
print()
|
||||||
if False: # this explodes in our faces completely... NAN everywhere TODO: Wtf is happening here?
|
if False: # this explodes in our faces completely... NAN everywhere TODO: Wtf is happening here?
|
||||||
@@ -571,7 +599,7 @@ if __name__ == '__main__':
|
|||||||
print("Fixpoint? " + str(net.is_fixpoint(epsilon=0.0001)))
|
print("Fixpoint? " + str(net.is_fixpoint(epsilon=0.0001)))
|
||||||
print("Loss " + str(loss))
|
print("Loss " + str(loss))
|
||||||
print()
|
print()
|
||||||
if True: # and this gets somewhat interesting... we can still achieve non-trivial fixpoints over multiple applications when training enough in-between
|
if False: # and this gets somewhat interesting... we can still achieve non-trivial fixpoints over multiple applications when training enough in-between
|
||||||
with MixedFixpointExperiment() as exp:
|
with MixedFixpointExperiment() as exp:
|
||||||
for run_id in range(1):
|
for run_id in range(1):
|
||||||
net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2)).with_params(epsilon=0.0001)
|
net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2)).with_params(epsilon=0.0001)
|
||||||
|
|||||||
Reference in New Issue
Block a user