diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py index 58ff112..7679cec 100644 --- a/experiments/meta_task_exp.py +++ b/experiments/meta_task_exp.py @@ -16,7 +16,7 @@ from torch.nn import Flatten from torch.utils.data import Dataset, DataLoader from torchvision.datasets import MNIST from torchvision.transforms import ToTensor, Compose, Resize -from tqdm import tqdm +from tqdm import tqdm, trange # noinspection DuplicatedCode if platform.node() == 'CarbonX': @@ -46,7 +46,7 @@ WORKER = 10 if not debug else 2 debug = False BATCHSIZE = 500 if not debug else 50 EPOCH = 100 -VALIDATION_FRQ = 4 if not debug else 1 +VALIDATION_FRQ = 3 if not debug else 1 SELF_TRAIN_FRQ = 1 if not debug else 1 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -292,24 +292,23 @@ if __name__ == '__main__': train_to_task_first = False sequential_task_train = True force_st_for_n_from_last_epochs = 5 - n_st_per_batch = 3 - activation = None # nn.ReLU() + n_st_per_batch = 10 + # activation = None # nn.ReLU() - use_sparse_network = False + use_sparse_network = True - for weight_hidden_size in [8]: + for weight_hidden_size in [4, 5, 6]: tsk_threshold = 0.85 weight_hidden_size = weight_hidden_size residual_skip = True - n_seeds = 3 + n_seeds = 1 data_path = Path('data') data_path.mkdir(exist_ok=True, parents=True) - assert not (train_to_task_first and train_to_id_first) st_str = f'{"" if self_train else "no_"}st{f"_n_{n_st_per_batch}" if n_st_per_batch else ""}' - ac_str = f'_{activation.__class__.__name__}' if activation is not None else '' + # ac_str = f'_{activation.__class__.__name__}' if activation is not None else '' res_str = f'{"" if residual_skip else "_no_res"}' # dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}' id_str = f'{f"_StToId" if train_to_id_first else ""}' @@ -318,7 +317,7 @@ if __name__ == '__main__': f_str = f'_f_{force_st_for_n_from_last_epochs}' if \ force_st_for_n_from_last_epochs and sequential_task_train and train_to_task_first else "" config_str = f'{res_str}{id_str}{tsk_str}{f_str}{sprs_str}' - exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}{ac_str}' + exp_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{config_str}' if not training: # noinspection PyRedeclaration @@ -326,10 +325,12 @@ if __name__ == '__main__': for seed in range(n_seeds): seed_path = exp_path / str(seed) + seed_path.mkdir(exist_ok=True, parents=True) model_path = seed_path / '0000_trained_model.zip' df_store_path = seed_path / 'train_store.csv' weight_store_path = seed_path / 'weight_store.csv' + init_st_store_path = seed_path / 'init_st_counter.csv' srnn_parameters = dict() if training: @@ -345,92 +346,139 @@ if __name__ == '__main__': d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER) interface = np.prod(dataset[0][0].shape) - dense_metanet = MetaNet(interface, depth=5, width=6, out=10, residual_skip=residual_skip, - weight_hidden_size=weight_hidden_size, activation=activation).to(DEVICE) - sparse_metanet = SparseNetwork(interface, depth=5, width=6, out=10, residual_skip=residual_skip, - weight_hidden_size=weight_hidden_size, activation=activation + dense_metanet = MetaNet(interface, depth=3, width=6, out=10, residual_skip=residual_skip, + weight_hidden_size=weight_hidden_size + ).to(DEVICE) + sparse_metanet = SparseNetwork(interface, depth=3, width=6, out=10, residual_skip=residual_skip, + weight_hidden_size=weight_hidden_size ).to(DEVICE) if use_sparse_network else 
dense_metanet meta_weight_count = sum(p.numel() for p in next(dense_metanet.particles).parameters()) loss_fn = nn.CrossEntropyLoss() - dense_optimizer = torch.optim.SGD(dense_metanet.parameters(), lr=0.004, momentum=0.9) - sparse_optimizer = torch.optim.SGD( - sparse_metanet.parameters(), lr=0.004, momentum=0.9 - ) if use_sparse_network else dense_optimizer + optimizer = torch.optim.SGD(sparse_metanet.parameters(), lr=0.004, momentum=0.9) train_store = new_storage_df('train', None) weight_store = new_storage_df('weights', meta_weight_count) - init_tsk = train_to_task_first - for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'): + + if train_to_task_first: + dense_metanet = dense_metanet.train() + for epoch in trange(10): + for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='Train - Batch'): + # Task Train + # Zero your gradients for every batch! + optimizer.zero_grad() + batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE) + y_pred = dense_metanet(batch_x) + + loss = loss_fn(y_pred, batch_y.to(torch.long)) + loss.backward() + + # Adjust learning weights + optimizer.step() + step_log = dict(Epoch=epoch, Batch=batch, + Metric='Task Loss', Score=loss.item()) + train_store.loc[train_store.shape[0]] = step_log + # Transfer weights + if use_sparse_network: + sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles) + + if train_to_id_first: + sparse_metanet = sparse_metanet.train() + init_st_epochs = 1500 + init_st_df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count']) + + for st_epoch in trange(init_st_epochs): + _ = sparse_metanet.combined_self_train(optimizer) + + if st_epoch % 500 == 0: + counter = defaultdict(lambda: 0) + id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles)) + counter = dict(counter) + tqdm.write(f"identity_fn after {st_epoch} self-train epochs: {counter}") + for key, value in counter.items(): + init_st_df.loc[init_st_df.shape[0]] = (st_epoch, key, value) + sparse_metanet.reset_diverged_particles() + counter = defaultdict(lambda: 0) + id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles)) + counter = dict(counter) + tqdm.write(f"identity_fn after {init_st_epochs} self-train epochs: {counter}") + for key, value in counter.items(): + init_st_df.loc[init_st_df.shape[0]] = (init_st_epochs, key, value) + init_st_df.to_csv(init_st_store_path, mode='w', index=False) + + c = pd.read_csv(init_st_store_path) + sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type') + plt.savefig(init_st_store_path.parent / f'{init_st_store_path.stem}.png', dpi=300) + + # Transfer weights + if use_sparse_network: + dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights) + + for epoch in trange(EPOCH, desc=f'Train - Epochs'): tqdm.write(f'{seed}: {exp_path}') - is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True - is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True + is_validation_epoch = (epoch % VALIDATION_FRQ == 0) if not debug else True + is_self_train_epoch = (epoch % SELF_TRAIN_FRQ == 0) if not debug else True sparse_metanet = sparse_metanet.train() dense_metanet = dense_metanet.train() if is_validation_epoch: metric = torchmetrics.Accuracy() else: metric = None - init_st = train_to_id_first and not all( - x.is_fixpoint == ft.identity_func for x in dense_metanet.particles - ) - force_st = (force_st_for_n_from_last_epochs >= (EPOCH - epoch) - ) and sequential_task_train and force_st_for_n_from_last_epochs - for batch, (batch_x, 
batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'): + + for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='Train - Batch'): # Self Train - if self_train and ((not init_tsk and (is_self_train_epoch or init_st)) or force_st): - # Transfer weights - if use_sparse_network: - sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles) + if is_self_train_epoch: for _ in range(n_st_per_batch): - self_train_loss = sparse_metanet.combined_self_train(sparse_optimizer, reduction='mean') + self_train_loss = sparse_metanet.combined_self_train(optimizer) # noinspection PyUnboundLocalVariable step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item()) train_store.loc[train_store.shape[0]] = step_log + # Clean Divergent + sparse_metanet.reset_diverged_particles() # Transfer weights if use_sparse_network: dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights) - dense_metanet.reset_diverged_particles() - # Task Train - if not init_st: - # Zero your gradients for every batch! - dense_optimizer.zero_grad() - batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE) - y_pred = dense_metanet(batch_x) - # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32)) - loss = loss_fn(y_pred, batch_y.to(torch.long)) - loss.backward() + # Zero your gradients for every batch! + optimizer.zero_grad() + batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE) + y_pred = dense_metanet(batch_x) - # Adjust learning weights - dense_optimizer.step() + loss = loss_fn(y_pred, batch_y.to(torch.long)) + loss.backward() - step_log = dict(Epoch=epoch, Batch=batch, - Metric='Task Loss', Score=loss.item()) - train_store.loc[train_store.shape[0]] = step_log - if is_validation_epoch: - metric(y_pred.cpu(), batch_y.cpu()) + # Adjust learning weights + optimizer.step() + + # Transfer weights + if use_sparse_network: + sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles) + + step_log = dict(Epoch=epoch, Batch=batch, + Metric='Task Loss', Score=loss.item()) + train_store.loc[train_store.shape[0]] = step_log + if is_validation_epoch: + metric(y_pred.cpu(), batch_y.cpu()) if batch >= 3 and debug: break if is_validation_epoch: dense_metanet = dense_metanet.eval() - if not init_st: - validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, - Metric='Train Accuracy', Score=metric.compute().item()) - train_store.loc[train_store.shape[0]] = validation_log + + validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, + Metric='Train Accuracy', Score=metric.compute().item()) + train_store.loc[train_store.shape[0]] = validation_log accuracy = checkpoint_and_validate(dense_metanet, seed_path, epoch).item() validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric='Test Accuracy', Score=accuracy) train_store.loc[train_store.shape[0]] = validation_log - if init_tsk or (train_to_task_first and sequential_task_train): - init_tsk = accuracy <= tsk_threshold - if init_st or is_validation_epoch: + + if is_validation_epoch: counter_dict = defaultdict(lambda: 0) # This returns ID-functions _ = test_for_fixpoints(counter_dict, list(dense_metanet.particles)) @@ -439,12 +487,14 @@ if __name__ == '__main__': step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value) train_store.loc[train_store.shape[0]] = step_log tqdm.write(f'Fixpoint Tester Results: {counter_dict}') - if init_st or is_validation_epoch: + for particle in dense_metanet.particles: weight_log = (epoch, 
particle.name, *flat_for_store(particle.parameters())) weight_store.loc[weight_store.shape[0]] = weight_log - train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False) - weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False) + train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), + index=False) + weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), + index=False) train_store = new_storage_df('train', None) weight_store = new_storage_df('weights', meta_weight_count) diff --git a/network.py b/network.py index 3db1be6..18799a3 100644 --- a/network.py +++ b/network.py @@ -445,10 +445,12 @@ class MetaNet(nn.Module): tensor = self._meta_layer_first(x) residual = None for idx, meta_layer in enumerate(self._meta_layer_list, start=1): - if idx % 2 == 1 and self.residual_skip: + # if idx % 2 == 1 and self.residual_skip: + if self.residual_skip: residual = tensor tensor = meta_layer(tensor) - if idx % 2 == 0 and self.residual_skip: + # if idx % 2 == 0 and self.residual_skip: + if self.residual_skip: tensor = tensor + residual tensor = self._meta_layer_last(tensor) return tensor diff --git a/sparse_net.py b/sparse_net.py index efca1b9..128a23e 100644 --- a/sparse_net.py +++ b/sparse_net.py @@ -1,25 +1,29 @@ from collections import defaultdict import pandas as pd +from matplotlib import pyplot as plt +import seaborn as sns from torch import nn import functionalities_test from network import Net -from functionalities_test import is_identity_function -from tqdm import tqdm,trange +from functionalities_test import is_identity_function, test_for_fixpoints, epsilon_error_margin +from tqdm import tqdm, trange import numpy as np from pathlib import Path import torch from torch.nn import Flatten from torch.utils.data import DataLoader -import torch.nn.functional as F + from torchvision.datasets import MNIST from torchvision.transforms import ToTensor, Compose, Resize def xavier_init(m): if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight.data) + return nn.init.xavier_uniform_(m.weight.data) + if isinstance(m, torch.Tensor): + return nn.init.xavier_uniform_(m) class SparseLayer(nn.Module): @@ -101,7 +105,9 @@ class SparseLayer(nn.Module): for weights in self.weights: if torch.isinf(weights).any() or torch.isnan(weights).any(): with torch.no_grad(): - xavier_init(weights) + where_nan = torch.nan_to_num(weights, -99, -99, -99) + mask = torch.where(where_nan == -99, 0, 1) + weights[:] = (where_nan * mask + torch.randn_like(weights) * (1 - mask))[:] @property def particle_weights(self): @@ -139,8 +145,9 @@ def test_sparse_layer(): optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9) # optimizer = torch.optim.SGD([layer.coalesce().values() for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9) df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count']) + train_iterations = 20000 - for train_iteration in trange(20000): + for train_iteration in trange(train_iterations): optimizer.zero_grad() X, Y = net.get_self_train_inputs_and_targets() output = net(X) @@ -163,12 +170,11 @@ def test_sparse_layer(): counter = defaultdict(lambda: 0) id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles)) counter = dict(counter) - tqdm.write(f"identity_fn after {train_iteration + 1} self-train epochs: {counter}") + tqdm.write(f"identity_fn after {train_iterations} self-train epochs: {counter}") for key, value in counter.items(): - 
df.loc[df.shape[0]] = (train_iteration, key, value) + df.loc[df.shape[0]] = (train_iterations, key, value) df.to_csv('counter.csv', mode='w') - import seaborn as sns - import matplotlib.pyplot as plt + c = pd.read_csv('counter.csv', index_col=0) sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type') plt.savefig('counter.png', dpi=300) @@ -191,6 +197,11 @@ def embed_vector(x, repeat_dim): class SparseNetwork(nn.Module): + + @property + def nr_nets(self): + return sum(x.nr_nets for x in self.sparselayers) + def __init__(self, input_dim, depth, width, out, residual_skip=True, activation=None, weight_interface=5, weight_hidden_size=2, weight_output_size=1 ): @@ -216,16 +227,13 @@ class SparseNetwork(nn.Module): if self.activation: tensor = self.activation(tensor) for nl_idx, network_layer in enumerate(self.hidden_layers): - # Sparse Layer pass + # if idx % 2 == 1 and self.residual_skip: + if self.residual_skip: + residual = tensor tensor = self.sparse_layer_forward(tensor, network_layer) - - if self.activation: - tensor = self.activation(tensor) - if nl_idx % 2 == 0 and self.residual_skip: - residual = tensor.clone() - if nl_idx % 2 == 1 and self.residual_skip: - # noinspection PyUnboundLocalVariable - tensor += residual + # if idx % 2 == 0 and self.residual_skip: + if self.residual_skip: + tensor = tensor + residual tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim) return tensor @@ -282,7 +290,7 @@ class SparseNetwork(nn.Module): output = layer(x) # loss = sum([loss_fn(out, target) for out, target in zip(output, target_data)]) / len(output) - loss = loss_fn(output, target_data) * 85 + loss = loss_fn(output, target_data) * layer.nr_nets losses.append(loss.detach()) loss.backward() @@ -311,39 +319,42 @@ def test_sparse_net(): data_dim = np.prod(dataset[0][0].shape) metanet = SparseNetwork(data_dim, depth=3, width=5, out=10) batchx, batchy = next(iter(d)) - metanet(batchx) - print(f"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}") + out = metanet(batchx) + + result = sum([torch.allclose(out[i], batchy[i], rtol=0, atol=epsilon_error_margin) for i in range(metanet.nr_nets)]) + # print(f"identity_fn after {train_iteration+1} self-train iterations: {result} /{net.nr_nets}") def test_sparse_net_sef_train(): - net = SparseNetwork(5, 5, 6, 10) - epochs = 10000 - df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count']) - optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9) - for epoch in trange(epochs): - _ = net.combined_self_train(optimizer) + sparse_metanet = SparseNetwork(15*15, 5, 6, 10).to('cuda') + init_st_store_path = Path('counter.csv') + optimizer = torch.optim.SGD(sparse_metanet.parameters(), lr=0.004, momentum=0.9) + init_st_epochs = 10000 + init_st_df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count']) - if epoch % 500 == 0: + for st_epoch in trange(init_st_epochs): + _ = sparse_metanet.combined_self_train(optimizer) + + if st_epoch % 500 == 0: counter = defaultdict(lambda: 0) - id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles)) + id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles)) counter = dict(counter) - tqdm.write(f"identity_fn after {epoch + 1} self-train epochs: {counter}") + tqdm.write(f"identity_fn after {st_epoch} self-train epochs: {counter}") for key, value in counter.items(): - df.loc[df.shape[0]] = (epoch, key, value) - 
net.reset_diverged_particles()
+                    init_st_df.loc[init_st_df.shape[0]] = (st_epoch, key, value)
+            sparse_metanet.reset_diverged_particles()
     counter = defaultdict(lambda: 0)
-    id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
+    id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
     counter = dict(counter)
-    tqdm.write(f"identity_fn after {epochs} self-train epochs: {counter}")
+    tqdm.write(f"identity_fn after {init_st_epochs} self-train epochs: {counter}")
     for key, value in counter.items():
-        df.loc[df.shape[0]] = (epoch, key, value)
-    df.to_csv('counter.csv', mode='w')
-    import seaborn as sns
-    import matplotlib.pyplot as plt
-    c = pd.read_csv('counter.csv', index_col=0)
+        init_st_df.loc[init_st_df.shape[0]] = (init_st_epochs, key, value)
+    init_st_df.to_csv(init_st_store_path, mode='w', index=False)
+
+    c = pd.read_csv(init_st_store_path)
     sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type')
-    plt.savefig('counter.png', dpi=300)
+    plt.savefig(init_st_store_path.parent / f'{init_st_store_path.stem}.png', dpi=300)


 def test_manual_for_loop():
@@ -353,7 +364,7 @@ def test_manual_for_loop():
     rounds = 1000

     for net in tqdm(nets):
-        optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
+        optimizer = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
         for i in range(rounds):
             optimizer.zero_grad()
             input_data = net.input_weight_matrix()
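
Two behavioural changes in this patch are easy to miss in the noise, so here are short, self-contained sketches of each. The names below (reset_diverged, forward_hidden) are illustrative stand-ins, not the repo's API; the real versions live in SparseLayer.reset_diverged_particles and in the forward methods of MetaNet and SparseNetwork.

First, the diverged-particle repair now replaces NaN/Inf weights with fresh noise via a sentinel mask instead of re-running xavier_init on the whole tensor. A minimal sketch mirroring the diff's masking trick, assuming -99 never occurs as a legitimate weight value:

    import torch

    def reset_diverged(weights: torch.Tensor) -> None:
        with torch.no_grad():
            # Map NaN, +inf and -inf to the sentinel value -99.
            cleaned = torch.nan_to_num(weights, -99, -99, -99)
            # 0 where the sentinel sits, 1 where the weight was finite.
            mask = torch.where(cleaned == -99, 0, 1)
            # Keep finite weights, blend Gaussian noise into the bad slots.
            weights[:] = cleaned * mask + torch.randn_like(weights) * (1 - mask)

    # Usage: corrupt a tensor, then repair it in place.
    w = torch.randn(4, 4)
    w[0, 0], w[1, 2] = float('nan'), float('inf')
    reset_diverged(w)
    assert torch.isfinite(w).all()

The upside over re-initialising the whole tensor is that only the diverged entries are touched, so surviving particles keep their trained weights.

Second, the residual edit in network.py and sparse_net.py is semantic, not cosmetic: a skip connection previously bridged every second hidden layer (the idx % 2 checks), whereas now every hidden layer gets its own skip. Schematically:

    def forward_hidden(tensor, layers, residual_skip=True):
        # After this diff: one skip per layer instead of one per layer pair.
        for layer in layers:
            residual = tensor if residual_skip else None
            tensor = layer(tensor)
            if residual_skip:
                tensor = tensor + residual
        return tensor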