Visual Debugging
This commit is contained in:
@ -83,6 +83,9 @@ class NeuralNetwork(PrintingObject):
|
||||
def get_weights(self):
|
||||
return self.model.get_weights()
|
||||
|
||||
def get_weights_flat(self):
|
||||
return np.hstack([weight.flatten() for weight in self.get_weights()])
|
||||
|
||||
def set_weights(self, new_weights):
|
||||
return self.model.set_weights(new_weights)
|
||||
|
||||
@ -603,9 +606,9 @@ class TrainingNeuralNetworkDecorator(NeuralNetwork):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if False:
|
||||
if True:
|
||||
with FixpointExperiment() as exp:
|
||||
for run_id in tqdm(range(1)):
|
||||
for run_id in tqdm(range(100)):
|
||||
# net = WeightwiseNeuralNetwork(width=2, depth=2).with_keras_params(activation='linear')
|
||||
# net = AggregatingNeuralNetwork(aggregates=4, width=2, depth=2)\
|
||||
net = FFTNeuralNetwork(aggregates=4, width=2, depth=2) \
|
||||
|
Reference in New Issue
Block a user