From 3da00c793bb1384717b9dc1fbd9d818b9ed1b4a2 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Wed, 23 Feb 2022 18:23:00 +0100 Subject: [PATCH] new sanity methode --- experiments/meta_task_exp.py | 14 +++++++---- network.py | 8 +++++- sanity_check_weights.py | 18 +++++++------ sparse_net.py | 49 ++++++++++++++++++++++-------------- 4 files changed, 57 insertions(+), 32 deletions(-) diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py index 1ae1281..c7fb8ad 100644 --- a/experiments/meta_task_exp.py +++ b/experiments/meta_task_exp.py @@ -45,7 +45,7 @@ from functionalities_test import test_for_fixpoints WORKER = 10 if not debug else 2 debug = False BATCHSIZE = 500 if not debug else 50 -EPOCH = 100 +EPOCH = 50 VALIDATION_FRQ = 3 if not debug else 1 SELF_TRAIN_FRQ = 1 if not debug else 1 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -279,9 +279,9 @@ if __name__ == '__main__': self_train = True training = True - train_to_id_first = True + train_to_id_first = False train_to_task_first = False - train_to_task_first_sequential = False + train_to_task_first_sequential = True force_st_for_n_from_last_epochs = 5 use_sparse_network = False @@ -303,10 +303,12 @@ if __name__ == '__main__': # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}' id_str = f'{f"_StToId" if train_to_id_first else ""}' tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}' + sprs_str = '_sprs' if use_sparse_network else '' f_str = f'_f_{force_st_for_n_from_last_epochs}' if \ force_st_for_n_from_last_epochs and train_to_task_first_sequential and train_to_task_first \ else "" - exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{id_str}{tsk_str}{f_str}' + config_str = f'{a_str}{res_str}{id_str}{tsk_str}{f_str}{sprs_str}' + exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}' for seed in range(n_seeds): seed_path = exp_path / str(seed) @@ -358,8 +360,8 @@ if __name__ == '__main__': force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch) ) and train_to_task_first_sequential and force_st_for_n_from_last_epochs for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'): - # Self Train + # Self Train if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st): # Transfer weights if use_sparse_network: @@ -376,6 +378,8 @@ if __name__ == '__main__': # Transfer weights if use_sparse_network: dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights) + + # Task Train if not init_st: # Zero your gradients for every batch! dense_optimizer.zero_grad() diff --git a/network.py b/network.py index 053194c..b09a9fc 100644 --- a/network.py +++ b/network.py @@ -11,6 +11,11 @@ from torch import optim, Tensor from tqdm import tqdm +def xavier_init(m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight.data) + + def prng(): return random.random() @@ -97,6 +102,7 @@ class Net(nn.Module): ) self._weight_pos_enc_and_mask = None + self.apply(xavier_init) @property def _weight_pos_enc(self): @@ -503,7 +509,7 @@ class MetaNetCompareBaseline(nn.Module): if __name__ == '__main__': - metanet = MetaNet(interface=3, depth=5, width=3, out=1, dropout=0.0, residual_skip=True) + metanet = MetaNet(interface=3, depth=5, width=3, out=1, residual_skip=True) next(metanet.particles).input_weight_matrix() metanet(torch.hstack([torch.full((2, 1), 1.0) for _ in range(metanet.interface)])) a = metanet.particles diff --git a/sanity_check_weights.py b/sanity_check_weights.py index 05ec507..d46457f 100644 --- a/sanity_check_weights.py +++ b/sanity_check_weights.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from tqdm import tqdm import pandas as pd from pathlib import Path @@ -15,18 +17,20 @@ def extract_weights_from_model(model:MetaNet)->dict: inpt[-1] = 1 inpt.long() - weights = {i:[] for i in range(model.depth)} + weights = defaultdict(list) layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]] - for i,layer in enumerate(layers): + for i, layer in enumerate(layers): for net in layer: weights[i].append(net(inpt).detach()) - return weights + return dict(weights) -def test_weights_as_model(model, weights:dict, data): + +def test_weights_as_model(model, new_weights:dict, data): TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out) + with torch.no_grad(): - for i, weight_set in weights.items(): - TransferNet._meta_layer_list[i].weight = torch.nn.Parameter(torch.tensor(weight_set).view(list(TransferNet.parameters())[i].shape)) + for weights, parameters in zip(new_weights.values(), TransferNet.parameters()): + parameters[:] = torch.Tensor(weights).view(parameters.shape)[:] TransferNet.eval() metric = torchmetrics.Accuracy() @@ -56,7 +60,7 @@ if __name__ == '__main__': d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER) loss_fn = nn.CrossEntropyLoss() - model = torch.load("0039_model_ckpt.tp", map_location=DEVICE).eval() + model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval() weights = extract_weights_from_model(model) test_weights_as_model(model, weights, d_test) diff --git a/sparse_net.py b/sparse_net.py index ad422ed..c9fc8e4 100644 --- a/sparse_net.py +++ b/sparse_net.py @@ -120,7 +120,7 @@ class SparseLayer(nn.Module): def test_sparse_layer(): net = SparseLayer(500) #50 parallel nets - loss_fn = torch.nn.MSELoss(reduction="sum") + loss_fn = torch.nn.MSELoss() optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9) # optimizer = torch.optim.SGD([layer.coalesce().values() for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9) @@ -138,9 +138,10 @@ def test_sparse_layer(): loss.backward() optimizer.step() - epsilon = pow(10, -5) - # is each of the networks self-replicating? - print(f"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}") + counter = defaultdict(lambda: 0) + id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles)) + counter = dict(counter) + print(f"identity_fn after {train_iteration + 1} self-train epochs: {counter}") def embed_batch(x, repeat_dim): @@ -239,7 +240,7 @@ class SparseNetwork(nn.Module): x, target_data = layer.get_self_train_inputs_and_targets() output = layer(x) - losses.append(F.mse_loss(output, target_data)) + losses.append(F.mse_loss(output, target_data) / layer.nr_nets) return torch.hstack(losses).sum(dim=-1, keepdim=True) def replace_weights_by_particles(self, particles): @@ -269,21 +270,31 @@ def test_sparse_net(): def test_sparse_net_sef_train(): net = SparseNetwork(30, 5, 6, 10) - optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9) - optimizer_dict = { - key: torch.optim.SGD(layer.parameters(), lr=0.008, momentum=0.9) for key, layer in enumerate(net.sparselayers) - } epochs = 1000 - loss_fn = torch.nn.MSELoss(reduction="sum") - - for _ in trange(epochs): - for layer, optim in zip(net.sparselayers, optimizer_dict.values()): - optim.zero_grad() - x, target_data = layer.get_self_train_inputs_and_targets() - output = layer(x) - loss = loss_fn(output, target_data) + if True: + optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9) + for _ in trange(epochs): + optimizer.zero_grad() + loss = net.combined_self_train() + print(loss) + exit() loss.backward() - optim.step() + optimizer.step() + + else: + optimizer_dict = { + key: torch.optim.SGD(layer.parameters(), lr=0.004, momentum=0.9) for key, layer in enumerate(net.sparselayers) + } + loss_fn = torch.nn.MSELoss(reduction="mean") + + for layer, optim in zip(net.sparselayers, optimizer_dict.values()): + for _ in trange(epochs): + optim.zero_grad() + x, target_data = layer.get_self_train_inputs_and_targets() + output = layer(x) + loss = loss_fn(output, target_data) + loss.backward() + optim.step() # is each of the networks self-replicating? counter = defaultdict(lambda: 0) @@ -313,7 +324,7 @@ def test_manual_for_loop(): if __name__ == '__main__': - test_sparse_layer() + # test_sparse_layer() test_sparse_net_sef_train() # test_sparse_net() # for comparison