diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py
index 60852d1..1ae1281 100644
--- a/experiments/meta_task_exp.py
+++ b/experiments/meta_task_exp.py
@@ -279,9 +279,12 @@ if __name__ == '__main__':
     self_train = True
     training = True
-    train_to_id_first = False
+    train_to_id_first = True
     train_to_task_first = False
-    train_to_task_first_sequential = True
+    train_to_task_first_sequential = False
+    force_st_for_n_from_last_epochs = 5
+
+    use_sparse_network = False
 
     tsk_threshold = 0.855
     self_train_alpha = 1
 
@@ -299,8 +302,11 @@ if __name__ == '__main__':
     res_str = f'{"" if residual_skip else "_no_res"}'
     # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
     id_str = f'{f"_StToId" if train_to_id_first else ""}'
-    tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first else ""}'
-    exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{id_str}{tsk_str}'
+    tsk_str = f'{f"_Tsk_{tsk_threshold}" if train_to_task_first and tsk_threshold != 1 else ""}'
+    f_str = f'_f_{force_st_for_n_from_last_epochs}' if \
+        force_st_for_n_from_last_epochs and train_to_task_first_sequential and train_to_task_first \
+        else ""
+    exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{id_str}{tsk_str}{f_str}'
 
     for seed in range(n_seeds):
         seed_path = exp_path / str(seed)
@@ -309,6 +315,8 @@ if __name__ == '__main__':
         df_store_path = seed_path / 'train_store.csv'
         weight_store_path = seed_path / 'weight_store.csv'
         srnn_parameters = dict()
+        for path in [model_path, df_store_path, weight_store_path]:
+            assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
 
         if training:
             utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
@@ -319,15 +327,18 @@ if __name__ == '__main__':
             d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
 
             interface = np.prod(dataset[0][0].shape)
-            sparse_metanet = SparseNetwork(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
-                                           weight_hidden_size=weight_hidden_size,).to(DEVICE)
             dense_metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
                                     weight_hidden_size=weight_hidden_size,).to(DEVICE)
+            sparse_metanet = SparseNetwork(interface, depth=5, width=6, out=10, residual_skip=residual_skip,
+                                           weight_hidden_size=weight_hidden_size
+                                           ).to(DEVICE) if use_sparse_network else dense_metanet
             meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters())
 
             loss_fn = nn.CrossEntropyLoss()
             dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.008, momentum=0.9)
-            sparse_optimizer = torch.optim.SGD(sparse_metanet.parameters(), lr=0.008, momentum=0.9)
+            sparse_optimizer = torch.optim.SGD(
+                sparse_metanet.parameters(), lr=0.008, momentum=0.9
+            ) if use_sparse_network else dense_optimizer
 
             train_store = new_storage_df('train', None)
             weight_store = new_storage_df('weights', meta_weight_count)
@@ -341,12 +352,18 @@ if __name__ == '__main__':
                     metric = torchmetrics.Accuracy()
                 else:
                     metric = None
-                init_st = train_to_id_first and not all(x.is_fixpoint == ft.identity_func for x in dense_metanet.particles)
+                init_st = train_to_id_first and not all(
+                    x.is_fixpoint == ft.identity_func for x in dense_metanet.particles
+                )
+                force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch)
+                            ) and train_to_task_first_sequential and force_st_for_n_from_last_epochs
                 for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
                     # Self Train
-                    if self_train and not init_tsk and (is_self_train_epoch or init_st):
+
+                    if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st):
                         # Transfer weights
-                        sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
+                        if use_sparse_network:
+                            sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
                         # Zero your gradients for every batch!
                         sparse_optimizer.zero_grad()
                         self_train_loss = sparse_metanet.combined_self_train() * self_train_alpha
@@ -357,7 +374,8 @@ if __name__ == '__main__':
                                        Metric='Self Train Loss', Score=self_train_loss.item())
                         train_store.loc[train_store.shape[0]] = step_log
                         # Transfer weights
-                        dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
+                        if use_sparse_network:
+                            dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
                     if not init_st:
                         # Zero your gradients for every batch!
                         dense_optimizer.zero_grad()
@@ -381,7 +399,7 @@ if __name__ == '__main__':
 
             if is_validation_epoch:
                 dense_metanet = dense_metanet.eval()
-                if train_to_id_first <= epoch:
+                if not init_st:
                     validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                           Metric='Train Accuracy', Score=metric.compute().item())
                     train_store.loc[train_store.shape[0]] = validation_log
@@ -438,9 +456,14 @@ if __name__ == '__main__':
                 print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
                 exit(1)
             latest_model = torch.load(model_path, map_location=DEVICE).eval()
-
-        run_particle_dropout_and_plot(seed_path)
-        plot_network_connectivity_by_fixtype(model_path)
+        try:
+            run_particle_dropout_and_plot(seed_path)
+        except ValueError as e:
+            print(e)
+        try:
+            plot_network_connectivity_by_fixtype(model_path)
+        except ValueError as e:
+            print(e)
 
     if n_seeds >= 2:
         pass
diff --git a/sanity_check_weights.py b/sanity_check_weights.py
index 7fc8f18..05ec507 100644
--- a/sanity_check_weights.py
+++ b/sanity_check_weights.py
@@ -48,7 +48,7 @@ if __name__ == '__main__':
     DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     WORKER = 0
     BATCHSIZE = 500
-    MNIST_TRANSFORM = Compose([ Resize((15, 15)), ToTensor(), Normalize((0.1307,), (0.3081,)), Flatten(start_dim=0)])
+    MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Normalize((0.1307,), (0.3081,)), Flatten(start_dim=0)])
     torch.manual_seed(42)
     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)
diff --git a/sparse_net.py b/sparse_net.py
index 004840e..a7cc53b 100644
--- a/sparse_net.py
+++ b/sparse_net.py
@@ -1,5 +1,8 @@
+from collections import defaultdict
+
 from torch import nn
 
+import functionalities_test
 from network import Net
 from functionalities_test import is_identity_function
 from tqdm import tqdm,trange
@@ -118,12 +121,12 @@ class SparseLayer(nn.Module):
 def test_sparse_layer():
     net = SparseLayer(500) #50 parallel nets
     loss_fn = torch.nn.MSELoss(reduction="sum")
-    optimizer = torch.optim.SGD(net.weights, lr=0.004, momentum=0.9)
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
     # optimizer = torch.optim.SGD([layer.coalesce().values() for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)
 
     for train_iteration in trange(1000):
         optimizer.zero_grad()
-        X,Y = net.get_self_train_inputs_and_targets()
+        X, Y = net.get_self_train_inputs_and_targets()
         out = net(X)
 
         loss = loss_fn(out, Y)
@@ -132,10 +135,10 @@ def test_sparse_layer():
         # print("OUT", out.shape)
         # print("LOSS", loss.item())
 
-        loss.backward(retain_graph=True)
+        loss.backward()
         optimizer.step()
 
-    epsilon=pow(10, -5)
+    epsilon = pow(10, -5)
 
     # is each of the networks self-replicating?
     print(f"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}")
@@ -261,6 +264,25 @@ def test_sparse_net():
     metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)
     batchx, batchy = next(iter(d))
     metanet(batchx)
+
+
+def test_sparse_net_self_train():
+    net = SparseNetwork(30, 5, 6, 10)
+    optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9)
+    epochs = 120
+
+    for _ in trange(epochs):
+        optimizer.zero_grad()
+        loss = net.combined_self_train()
+
+        loss.backward(retain_graph=True)
+        optimizer.step()
+
+    # is each of the networks self-replicating?
+    counter = defaultdict(lambda: 0)
+    id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
+    counter = dict(counter)
+    print(f"identity_fn after {epochs} self-train epochs: {counter}")
 
 
 def test_manual_for_loop():
@@ -284,7 +306,8 @@ def test_manual_for_loop():
 
 
 if __name__ == '__main__':
-    test_sparse_layer()
+    # test_sparse_layer()
+    test_sparse_net_self_train()
     # test_sparse_net()  # for comparison
-    test_manual_for_loop()
\ No newline at end of file
+    # test_manual_for_loop()
\ No newline at end of file
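
Reviewer note on the two switches this patch introduces, as a minimal sketch rather than repo code: `TinyNet` below is a hypothetical stand-in for `MetaNet`/`SparseNetwork`, and the epoch-window check assumes the training loop iterates over `range(EPOCH)`. With `use_sparse_network = False` the sparse handles simply alias the dense objects, so the self-train branch updates the dense net in place and both `replace_weights_by_particles`/`replace_particles` transfers are skipped; `force_st_for_n_from_last_epochs` then re-enables self-training for exactly the last n epochs when the sequential task-first schedule is active.

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for MetaNet / SparseNetwork, only to show the wiring."""
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x):
        return self.layer(x)


use_sparse_network = False
dense_metanet = TinyNet()
# Flag off: the 'sparse' names alias the dense objects, so stepping
# sparse_optimizer trains dense_metanet directly and no weight transfer is needed.
sparse_metanet = TinyNet() if use_sparse_network else dense_metanet
dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.008, momentum=0.9)
sparse_optimizer = (torch.optim.SGD(sparse_metanet.parameters(), lr=0.008, momentum=0.9)
                    if use_sparse_network else dense_optimizer)
assert sparse_metanet is dense_metanet and sparse_optimizer is dense_optimizer

# Forced self-train window: truthy only within the last n epochs of the run.
EPOCH, n, train_to_task_first_sequential = 100, 5, True
forced = [epoch for epoch in range(EPOCH)
          if (n >= (EPOCH - epoch)) and train_to_task_first_sequential and n]
assert forced == [95, 96, 97, 98, 99]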