network test

This commit is contained in:
Steffen Illium
2022-02-23 12:19:27 +01:00
parent 0bc3b62340
commit ebf133414c

View File

@ -270,20 +270,26 @@ def test_sparse_net():
def test_sparse_net_sef_train(): def test_sparse_net_sef_train():
net = SparseNetwork(30, 5, 6, 10) net = SparseNetwork(30, 5, 6, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9) optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9)
epochs = 120 optimizer_dict = {
key: torch.optim.SGD(layer.parameters(), lr=0.008, momentum=0.9) for key, layer in enumerate(net.sparselayers)
}
epochs = 1000
loss_fn = torch.nn.MSELoss(reduction="sum")
for _ in trange(epochs): for _ in trange(epochs):
optimizer.zero_grad() for layer, optim in zip(net.sparselayers, optimizer_dict.values()):
loss = net.combined_self_train() optim.zero_grad()
x, target_data = layer.get_self_train_inputs_and_targets()
loss.backward(retain_graph=True) output = layer(x)
optimizer.step() loss = loss_fn(output, target_data)
loss.backward()
optim.step()
# is each of the networks self-replicating? # is each of the networks self-replicating?
counter = defaultdict(lambda: 0) counter = defaultdict(lambda: 0)
id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles)) id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
counter = dict(counter) counter = dict(counter)
print(f"identity_fn after {epochs+1} self-train epochs: {counter}") print(f"identity_fn after {epochs} self-train epochs: {counter}")
def test_manual_for_loop(): def test_manual_for_loop():
@ -307,7 +313,7 @@ def test_manual_for_loop():
if __name__ == '__main__': if __name__ == '__main__':
# test_sparse_layer() test_sparse_layer()
test_sparse_net_sef_train() test_sparse_net_sef_train()
# test_sparse_net() # test_sparse_net()
# for comparison # for comparison