Visual Debugging

This commit is contained in:
Si11ium
2019-03-05 20:31:33 +01:00
parent 6625481091
commit de6aa68f23
6 changed files with 59 additions and 42 deletions

View File

@ -83,6 +83,9 @@ class NeuralNetwork(PrintingObject):
def get_weights(self):
return self.model.get_weights()
def get_weights_flat(self):
return np.hstack([weight.flatten() for weight in self.get_weights()])
def set_weights(self, new_weights):
return self.model.set_weights(new_weights)
@ -603,9 +606,9 @@ class TrainingNeuralNetworkDecorator(NeuralNetwork):
if __name__ == '__main__':
if False:
if True:
with FixpointExperiment() as exp:
for run_id in tqdm(range(1)):
for run_id in tqdm(range(100)):
# net = WeightwiseNeuralNetwork(width=2, depth=2).with_keras_params(activation='linear')
# net = AggregatingNeuralNetwork(aggregates=4, width=2, depth=2)\
net = FFTNeuralNetwork(aggregates=4, width=2, depth=2) \