model getter fixed
This commit is contained in:
parent
de6aa68f23
commit
20e9545b02
@ -575,7 +575,7 @@ class TrainingNeuralNetworkDecorator(NeuralNetwork):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def get_model(self):
|
def get_model(self):
|
||||||
return self.net.get_model()
|
return self.net.model
|
||||||
|
|
||||||
def apply_to_weights(self, old_weights):
|
def apply_to_weights(self, old_weights):
|
||||||
return self.net.apply_to_weights(old_weights)
|
return self.net.apply_to_weights(old_weights)
|
||||||
@ -606,13 +606,13 @@ class TrainingNeuralNetworkDecorator(NeuralNetwork):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
if True:
|
if False:
|
||||||
with FixpointExperiment() as exp:
|
with FixpointExperiment() as exp:
|
||||||
for run_id in tqdm(range(100)):
|
for run_id in tqdm(range(100)):
|
||||||
# net = WeightwiseNeuralNetwork(width=2, depth=2).with_keras_params(activation='linear')
|
# net = WeightwiseNeuralNetwork(width=2, depth=2).with_keras_params(activation='linear')
|
||||||
# net = AggregatingNeuralNetwork(aggregates=4, width=2, depth=2)\
|
# net = AggregatingNeuralNetwork(aggregates=4, width=2, depth=2)\
|
||||||
net = FFTNeuralNetwork(aggregates=4, width=2, depth=2) \
|
# net = FFTNeuralNetwork(aggregates=4, width=2, depth=2) \
|
||||||
.with_params(print_all_weight_updates=False, use_bias=False)
|
# .with_params(print_all_weight_updates=False, use_bias=False)
|
||||||
# net = RecurrentNeuralNetwork(width=2, depth=2).with_keras_params(activation='linear')\
|
# net = RecurrentNeuralNetwork(width=2, depth=2).with_keras_params(activation='linear')\
|
||||||
# .with_params(print_all_weight_updates=True)
|
# .with_params(print_all_weight_updates=True)
|
||||||
# net.print_weights()
|
# net.print_weights()
|
||||||
@ -636,10 +636,10 @@ if __name__ == '__main__':
|
|||||||
net.print_weights()
|
net.print_weights()
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
print(net.is_fixpoint(epsilon=0.1e-6))
|
print(net.is_fixpoint(epsilon=0.1e-6))
|
||||||
if False:
|
if True:
|
||||||
# ok so this works quite realiably
|
# ok so this works quite realiably
|
||||||
with FixpointExperiment() as exp:
|
with FixpointExperiment() as exp:
|
||||||
run_count = 1000
|
run_count = 100
|
||||||
net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2))\
|
net = TrainingNeuralNetworkDecorator(WeightwiseNeuralNetwork(width=2, depth=2))\
|
||||||
.with_params(epsilon=0.0001).with_keras_params(optimizer='sgd')
|
.with_params(epsilon=0.0001).with_keras_params(optimizer='sgd')
|
||||||
for run_id in tqdm(range(run_count+1)):
|
for run_id in tqdm(range(run_count+1)):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user