diff --git a/code/network.py b/code/network.py index 591f7be..e9d3a33 100644 --- a/code/network.py +++ b/code/network.py @@ -173,19 +173,36 @@ class AggregatingNeuralNetwork(NeuralNetwork): def deaggregate_identically(aggregate, amount): return [aggregate for _ in range(amount)] + @staticmethod + def shuffle_not(weights_list): + return weights_list + + @staticmethod + def shuffle_random(weights_list): + import random + random.shuffle(weights_list) + return weights_list + def __init__(self, aggregates, width, depth, **kwargs): super().__init__(**kwargs) + self.aggregates = aggregates self.width = width self.depth = depth - self.aggregates = aggregates - self.aggregator = self.params.get('aggregator', self.__class__.aggregate_average) - self.deaggregator = self.params.get('deaggregator', self.__class__.deaggregate_identically) self.model = Sequential() self.model.add(Dense(units=width, input_dim=self.aggregates, **self.keras_params)) for _ in range(depth-1): self.model.add(Dense(units=width, **self.keras_params)) self.model.add(Dense(units=self.aggregates, **self.keras_params)) + def get_aggregator(self): + return self.params.get('aggregator', self.__class__.aggregate_average) + + def get_deaggregator(self): + return self.params.get('deaggregator', self.__class__.deaggregate_identically) + + def get_shuffler(self): + return self.params.get('shuffler', self.__class__.shuffle_not) + def get_amount_of_weights(self): total_weights = 0 for layer_id,layer in enumerate(self.get_weights()): @@ -215,15 +232,16 @@ class AggregatingNeuralNetwork(NeuralNetwork): collections[-1] += next_collection leftovers = len(next_collection) # call network - old_aggregations = [self.aggregator(collection) for collection in collections] + old_aggregations = [self.get_aggregator()(collection) for collection in collections] new_aggregations = self.apply(*old_aggregations) # generate list of new weights new_weights_list = [] for aggregation_id,aggregation in enumerate(new_aggregations): if aggregation_id == self.aggregates - 1: - new_weights_list += self.deaggregator(aggregation, collection_size + leftovers) + new_weights_list += self.get_deaggregator()(aggregation, collection_size + leftovers) else: - new_weights_list += self.deaggregator(aggregation, collection_size) + new_weights_list += self.get_deaggregator()(aggregation, collection_size) + new_weights_list = self.get_shuffler()(new_weights_list) # write back new weights new_weights = copy.deepcopy(old_weights) current_weight_id = 0 @@ -286,8 +304,8 @@ if __name__ == '__main__': with FixpointExperiment() as exp: for run_id in tqdm(range(100)): # net = WeightwiseNeuralNetwork(2, 2).with_keras_params(activation='linear') - # net = AggregatingNeuralNetwork(4, 2, 2).with_keras_params(activation='linear').with_params(print_all_weight_updates=False) - net = RecurrentNeuralNetwork(2, 2).with_keras_params(activation='linear').with_params(print_all_weight_updates=True) + net = AggregatingNeuralNetwork(4, 2, 2).with_keras_params(activation='linear').with_params(shuffler=AggregatingNeuralNetwork.shuffle_random, print_all_weight_updates=False) + # net = RecurrentNeuralNetwork(2, 2).with_keras_params(activation='linear').with_params(print_all_weight_updates=True) # net.print_weights() exp.run_net(net, 100) exp.log(exp.counters)