fft Network
This commit is contained in:
107
code/network.py
107
code/network.py
@ -80,14 +80,11 @@ class NeuralNetwork(PrintingObject):
|
||||
self.keras_params.update(kwargs)
|
||||
return self
|
||||
|
||||
def get_model(self):
    """Return the model object stored on this instance as ``self.model``."""
    return self.model
|
||||
|
||||
def get_weights(self):
    """Return the weights reported by the wrapped model.

    Delegates to ``self.model.get_weights()`` and returns its result
    unchanged.
    """
    # Read the model attribute directly; a second, unreachable return that
    # went through an accessor method has been dropped.
    return self.model.get_weights()
|
||||
|
||||
def set_weights(self, new_weights):
    """Hand ``new_weights`` to the wrapped model.

    Delegates to ``self.model.set_weights(new_weights)`` and returns
    whatever that call returns.
    """
    # Read the model attribute directly; a second, unreachable return that
    # went through an accessor method has been dropped.
    return self.model.set_weights(new_weights)
|
||||
|
||||
def apply_to_weights(self, old_weights):
    """Abstract hook: compute new weights from ``old_weights``.

    This base version always raises ``NotImplementedError``; subclasses
    in this file override it with a concrete implementation.
    """
    raise NotImplementedError
|
||||
@ -264,13 +261,13 @@ class AggregatingNeuralNetwork(NeuralNetwork):
|
||||
self.model.add(Dense(units=self.aggregates, **self.keras_params))
|
||||
|
||||
def get_aggregator(self):
    """Return the aggregation callable to use.

    Looks up ``'aggregator'`` in ``self.params`` and falls back to
    ``self.aggregate_average`` when the key is absent.
    """
    # Single lookup through the instance; the duplicate, unreachable return
    # that spelled the default as ``self.__class__...`` has been dropped.
    return self.params.get('aggregator', self.aggregate_average)
|
||||
|
||||
def get_deaggregator(self):
    """Return the de-aggregation callable to use.

    Looks up ``'deaggregator'`` in ``self.params`` and falls back to
    ``self.deaggregate_identically`` when the key is absent.
    """
    # Single lookup through the instance; the duplicate, unreachable return
    # that spelled the default as ``self.__class__...`` has been dropped.
    return self.params.get('deaggregator', self.deaggregate_identically)
|
||||
|
||||
def get_shuffler(self):
    """Return the shuffling callable to use.

    Looks up ``'shuffler'`` in ``self.params`` and falls back to
    ``self.shuffle_not`` when the key is absent.
    """
    # Single lookup through the instance; the duplicate, unreachable return
    # that spelled the default as ``self.__class__...`` has been dropped.
    return self.params.get('shuffler', self.shuffle_not)
|
||||
|
||||
def get_amount_of_weights(self):
|
||||
total_weights = 0
|
||||
@ -287,7 +284,7 @@ class AggregatingNeuralNetwork(NeuralNetwork):
|
||||
def apply_to_weights(self, old_weights):
|
||||
# build aggregations from old_weights
|
||||
collection_size = self.get_amount_of_weights() // self.aggregates
|
||||
collections, leftovers = self.__class__.collect_weights(old_weights, collection_size)
|
||||
collections, leftovers = self.collect_weights(old_weights, collection_size)
|
||||
|
||||
# call network
|
||||
old_aggregations = [self.get_aggregator()(collection) for collection in collections]
|
||||
@ -303,14 +300,14 @@ class AggregatingNeuralNetwork(NeuralNetwork):
|
||||
new_weights_list = self.get_shuffler()(new_weights_list)
|
||||
|
||||
# write back new weights
|
||||
new_weights = self.__class__.fill_weights(old_weights, new_weights_list)
|
||||
new_weights = self.fill_weights(old_weights, new_weights_list)
|
||||
|
||||
# return results
|
||||
if self.params.get("print_all_weight_updates", False) and not self.is_silent():
|
||||
print("updated old weight aggregations " + str(old_aggregations))
|
||||
print("to new weight aggregations " + str(new_aggregations))
|
||||
print("resulting in network weights ...")
|
||||
print(self.__class__.weights_to_string(new_weights))
|
||||
print(self.weights_to_string(new_weights))
|
||||
return new_weights
|
||||
|
||||
@staticmethod
|
||||
@ -367,6 +364,84 @@ class AggregatingNeuralNetwork(NeuralNetwork):
|
||||
return True, new_aggregations
|
||||
|
||||
|
||||
class FFTNeuralNetwork(NeuralNetwork):
    """Network that rewrites a weight list through a fixed-size FFT summary.

    ``apply_to_weights`` flattens a weight list, reduces it to
    ``aggregates`` Fourier coefficients, passes those through the Keras
    model and expands the model output back to one value per weight with
    the inverse FFT.
    """

    @staticmethod
    def _fft_shape(dims):
        """Return ``dims`` as the tuple that ``fftn``/``ifftn`` take for ``s``."""
        # NumPy documents ``s`` as a sequence of ints; a bare int is not
        # iterable there, so wrap scalars and pass sequences through.
        if np.ndim(dims) == 0:
            return (int(dims),)
        return tuple(dims)

    @staticmethod
    def aggregate_fft(weights, dims):
        """Flatten ``weights`` and return their FFT cut/padded to ``dims``.

        :param weights: iterable of arrays exposing ``flatten()``.
        :param dims: output length (int) or an ``s`` shape sequence.
        :return: the transform with a leading axis of length 1 added.
        """
        flat = np.hstack([weight.flatten() for weight in weights])
        spectrum = np.fft.fftn(flat, FFTNeuralNetwork._fft_shape(dims))
        # NOTE(review): NumPy FFT output is complex-valued; confirm the
        # model is meant to receive it as-is.
        return spectrum[None, ...]

    @staticmethod
    def deaggregate_identically(aggregate, dims):
        """Return the inverse FFT of ``aggregate`` cut/padded to ``dims``.

        :param aggregate: array of coefficients to invert.
        :param dims: output length (int) or an ``s`` shape sequence.
        """
        return np.fft.ifftn(aggregate, FFTNeuralNetwork._fft_shape(dims))

    @staticmethod
    def shuffle_not(weights_list):
        """Return ``weights_list`` untouched (the default shuffler)."""
        return weights_list

    @staticmethod
    def shuffle_random(weights_list):
        """Shuffle ``weights_list`` in place and return the same object."""
        import random
        random.shuffle(weights_list)
        return weights_list

    def __init__(self, aggregates, width, depth, **kwargs):
        """Build the dense stack: ``aggregates`` inputs and outputs.

        :param aggregates: number of FFT coefficients fed in and read out.
        :param width: units per hidden ``Dense`` layer.
        :param depth: number of hidden ``Dense`` layers.
        :param kwargs: forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        self.aggregates = aggregates
        self.width = width
        self.depth = depth
        # The first hidden layer fixes the input size; the remaining
        # depth - 1 hidden layers share the same width.
        self.model.add(Dense(units=width, input_dim=self.aggregates, **self.keras_params))
        for _ in range(depth - 1):
            self.model.add(Dense(units=width, **self.keras_params))
        self.model.add(Dense(units=self.aggregates, **self.keras_params))

    def get_shuffler(self):
        """Return ``params['shuffler']``, defaulting to ``shuffle_not``."""
        return self.params.get('shuffler', self.shuffle_not)

    def get_amount_of_weights(self):
        """Count the entries reached by iterating layer -> cell -> weight."""
        # Same traversal as a triple nested loop: each cell contributes its
        # length. A cell without a length still raises TypeError, as before.
        return sum(len(cell) for layer in self.get_weights() for cell in layer)

    def apply(self, inputs):
        """Run the model on ``inputs`` and return the first prediction row."""
        sample = np.asarray(inputs)
        return self.model.predict(sample)[0]

    def apply_to_weights(self, old_weights):
        """Compute a new weight structure from ``old_weights``.

        :param old_weights: weight list to transform; also passed to
            ``fill_weights`` as the structure to fill.
        :return: whatever ``fill_weights`` returns for the new values.
        """
        # build the aggregation from the weights that were passed in, not
        # from this network's own current weights
        old_aggregation = self.aggregate_fft(old_weights, self.aggregates)

        # call network
        new_aggregation = self.apply(old_aggregation)

        # generate list of new weights
        new_weights_list = self.deaggregate_identically(new_aggregation, self.get_amount_of_weights())
        new_weights_list = self.get_shuffler()(new_weights_list)

        # write back new weights
        # NOTE(review): fill_weights and weights_to_string are not defined in
        # this class; presumably inherited from NeuralNetwork -- confirm.
        new_weights = self.fill_weights(old_weights, new_weights_list)

        # return results
        if self.params.get("print_all_weight_updates", False) and not self.is_silent():
            print("updated old weight aggregations " + str(old_aggregation))
            print("to new weight aggregations " + str(new_aggregation))
            print("resulting in network weights ...")
            print(self.weights_to_string(new_weights))
        return new_weights

    def compute_samples(self):
        """Return the current weights as a one-sample input/target pair."""
        weights = self.get_weights()
        # NOTE(review): stacks the raw weight list rather than its FFT
        # aggregation; verify this matches the model's input_dim.
        sample = np.asarray(weights)[None, ...]
        return [sample], [sample]
|
||||
|
||||
|
||||
class RecurrentNeuralNetwork(NeuralNetwork):
|
||||
|
||||
def __init__(self, width, depth, **kwargs):
|
||||
@ -472,7 +547,6 @@ class TrainingNeuralNetworkDecorator(NeuralNetwork):
|
||||
def __init__(self, net, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.net = net
|
||||
self.model = None
|
||||
self.compile_params = dict(loss='mse', optimizer='sgd')
|
||||
self.model_compiled = False
|
||||
|
||||
@ -533,12 +607,11 @@ if __name__ == '__main__':
|
||||
with FixpointExperiment() as exp:
|
||||
for run_id in tqdm(range(100)):
|
||||
# net = WeightwiseNeuralNetwork(width=2, depth=2).with_keras_params(activation='linear')
|
||||
net = AggregatingNeuralNetwork(aggregates=4, width=2, depth=2).with_keras_params(activation='linear')\
|
||||
.with_params(shuffler=AggregatingNeuralNetwork.shuffle_random,
|
||||
print_all_weight_updates=False, use_bias=True)
|
||||
# net = AggregatingNeuralNetwork(aggregates=4, width=2, depth=2)\
|
||||
net = FFTNeuralNetwork(aggregates=4, width=2, depth=2) \
|
||||
.with_params(print_all_weight_updates=False, use_bias=False)
|
||||
# net = RecurrentNeuralNetwork(width=2, depth=2).with_keras_params(activation='linear')\
|
||||
# .with_params(print_all_weight_updates=True)
|
||||
|
||||
# net.print_weights()
|
||||
exp.run_net(net, 100)
|
||||
exp.log(exp.counters)
|
||||
@ -610,7 +683,7 @@ if __name__ == '__main__':
|
||||
# and this gets somewhat interesting... we can still achieve non-trivial fixpoints
|
||||
# over multiple applications when training enough in-between
|
||||
with MixedFixpointExperiment() as exp:
|
||||
for run_id in range(1):
|
||||
for run_id in range(100):
|
||||
net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2))\
|
||||
.with_params(epsilon=0.0001)
|
||||
exp.run_net(net, 500, 10)
|
||||
|
Reference in New Issue
Block a user