This commit is contained in:
Si11ium
2019-03-06 19:21:19 +01:00
parent 6ced18c2d7
commit bae997feab
3 changed files with 69 additions and 43 deletions

View File

@ -1,7 +1,6 @@
import math
import copy
import numpy as np
from tqdm import tqdm
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense
@ -43,6 +42,7 @@ class NeuralNetwork(PrintingObject):
for layer_id, layer in enumerate(network_weights):
for cell_id, cell in enumerate(layer):
for weight_id, weight in enumerate(cell):
# could be a chain comparission "lower_bound <= weight <= upper_bound"
if not (lower_bound <= weight and weight <= upper_bound):
return False
return True
@ -538,6 +538,7 @@ class LearningNeuralNetwork(NeuralNetwork):
self.depth = depth
self.features = features
self.compile_params = dict(loss='mse', optimizer='sgd')
self.model = Sequential()
self.model.add(Dense(units=self.width, input_dim=self.features, **self.keras_params))
for _ in range(self.depth-1):
self.model.add(Dense(units=self.width, **self.keras_params))
@ -591,7 +592,7 @@ class TrainingNeuralNetworkDecorator():
def compile_model(self, **kwargs):
compile_params = copy.deepcopy(self.compile_params)
compile_params.update(kwargs)
return self.get_model().compile(**compile_params)
return self.net.model.compile(**compile_params)
def compiled(self, **kwargs):
if not self.model_compiled:
@ -617,7 +618,7 @@ if __name__ == '__main__':
if False:
with FixpointExperiment() as exp:
for run_id in tqdm(range(100)):
# net = WeightwiseNeuralNetwork(width=2, depth=2).with_keras_params(activation='linear')
net = WeightwiseNeuralNetwork(width=2, depth=2).with_keras_params(activation='linear')
# net = AggregatingNeuralNetwork(aggregates=4, width=2, depth=2)\
# net = FFTNeuralNetwork(aggregates=4, width=2, depth=2) \
# .with_params(print_all_weight_updates=False, use_bias=False)