From ebf133414c1ee18b9b7817c1aadeb0775e6e44d9 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Wed, 23 Feb 2022 12:19:27 +0100 Subject: [PATCH] network test --- sparse_net.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/sparse_net.py b/sparse_net.py index a7cc53b..ad422ed 100644 --- a/sparse_net.py +++ b/sparse_net.py @@ -270,20 +270,26 @@ def test_sparse_net(): def test_sparse_net_sef_train(): net = SparseNetwork(30, 5, 6, 10) optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9) - epochs = 120 + optimizer_dict = { + key: torch.optim.SGD(layer.parameters(), lr=0.008, momentum=0.9) for key, layer in enumerate(net.sparselayers) + } + epochs = 1000 + loss_fn = torch.nn.MSELoss(reduction="sum") for _ in trange(epochs): - optimizer.zero_grad() - loss = net.combined_self_train() - - loss.backward(retain_graph=True) - optimizer.step() + for layer, optim in zip(net.sparselayers, optimizer_dict.values()): + optim.zero_grad() + x, target_data = layer.get_self_train_inputs_and_targets() + output = layer(x) + loss = loss_fn(output, target_data) + loss.backward() + optim.step() # is each of the networks self-replicating? counter = defaultdict(lambda: 0) id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles)) counter = dict(counter) - print(f"identity_fn after {epochs+1} self-train epochs: {counter}") + print(f"identity_fn after {epochs} self-train epochs: {counter}") def test_manual_for_loop(): @@ -307,7 +313,7 @@ def test_manual_for_loop(): if __name__ == '__main__': - # test_sparse_layer() + test_sparse_layer() test_sparse_net_sef_train() # test_sparse_net() # for comparison