From 1b7581e6566e320ccf3e154db8e51b6f4a105025 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Tue, 1 Feb 2022 18:17:11 +0100 Subject: [PATCH] MetaNetworks Debugged II --- experiments/__init__.py | 3 +- experiments/meta_task_exp.py | 114 +++++++++++++++++++++-------------- functionalities_test.py | 24 ++++---- network.py | 25 +++++++- 4 files changed, 105 insertions(+), 61 deletions(-) diff --git a/experiments/__init__.py b/experiments/__init__.py index 7a4e479..70ff1e0 100644 --- a/experiments/__init__.py +++ b/experiments/__init__.py @@ -2,4 +2,5 @@ from .mixed_setting_exp import run_mixed_experiment from .robustness_exp import run_robustness_experiment from .self_application_exp import run_SA_experiment from .self_train_exp import run_ST_experiment -from .soup_exp import run_soup_experiment \ No newline at end of file +from .soup_exp import run_soup_experiment +import functionalities_test \ No newline at end of file diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py index af2cda5..bc5de87 100644 --- a/experiments/meta_task_exp.py +++ b/experiments/meta_task_exp.py @@ -6,8 +6,16 @@ import platform import pandas as pd import torchmetrics - -from functionalities_test import test_for_fixpoints +import numpy as np +import torch +from matplotlib import pyplot as plt +import seaborn as sns +from torch import nn +from torch.nn import Flatten +from torch.utils.data import Dataset, DataLoader +from torchvision.datasets import MNIST +from torchvision.transforms import ToTensor, Compose, Resize +from tqdm import tqdm if platform.node() == 'CarbonX': debug = True @@ -28,23 +36,12 @@ else: DIR = None pass - -import numpy as np -import torch -from matplotlib import pyplot as plt -import seaborn as sns -from torch import nn -from torch.nn import Flatten -from torch.utils.data import Dataset, DataLoader -from torchvision.datasets import MNIST -from torchvision.transforms import ToTensor, Compose, Resize -from tqdm import tqdm - from network import MetaNet +from functionalities_test import test_for_fixpoints WORKER = 10 if not debug else 2 BATCHSIZE = 500 if not debug else 50 -EPOCH = 50 if not debug else 3 +EPOCH = 100 if not debug else 3 VALIDATION_FRQ = 5 if not debug else 1 SELF_TRAIN_FRQ = 1 if not debug else 1 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -78,7 +75,7 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False): if not final_model: ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp' else: - ckpt_path = Path(out_path) / f'trained_model_ckpt.tp' + ckpt_path = Path(out_path) / f'trained_model_ckpt_e{epoch_n}.tp' ckpt_path.parent.mkdir(exist_ok=True, parents=True) torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL) @@ -91,15 +88,16 @@ def validate(checkpoint_path, ratio=0.1): # initialize metric validmetric = torchmetrics.Accuracy() + ut = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)]) try: - datas = MNIST(str(data_path), transform=utility_transforms, train=False) + datas = MNIST(str(data_path), transform=ut, train=False) except RuntimeError: - datas = MNIST(str(data_path), transform=utility_transforms, train=False, download=True) + datas = MNIST(str(data_path), transform=ut, train=False, download=True) valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER) model = torch.load(checkpoint_path, map_location=DEVICE).eval() - n_samples = int(len(d) * ratio) + n_samples = int(len(valid_d) * ratio) with tqdm(total=n_samples, 
              desc='Validation Run: ') as pbar:
         for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
@@ -119,6 +117,10 @@ def validate(checkpoint_path, ratio=0.1):
     return acc
 
 
+def new_train_storage_df():
+    return pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
+
+
 def checkpoint_and_validate(model, out_path, epoch_n, final_model=False):
     out_path = Path(out_path)
     ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
@@ -130,18 +132,28 @@ def plot_training_result(path_to_dataframe):
     # load from Drive
     df = pd.read_csv(path_to_dataframe, index_col=0)
 
+    # Set up figure
     fig, ax1 = plt.subplots()  # initializes figure and plots
-    ax2 = ax1.twinx() # applies twinx to ax2, which is the second y-axis.
+    ax2 = ax1.twinx()  # applies twinx to ax2, which is the second y-axis
 
-    # plots the first set of data, and sets it to ax1.
-    data = df[df['Metric'] == 'BatchLoss']
-    # plots the second set, and sets to ax2.
-    sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax1, color='blue')
+    # plots the first set of data (losses) on ax1
+    data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
+    palette = sns.color_palette()[0:data.reset_index()['Metric'].unique().shape[0]]
+    sns.lineplot(data=data.reset_index(), x='Epoch', y='Score', hue='Metric',
+                 palette=palette, ax=ax1)
+
+    # plots the second set of data (accuracies) on ax2
     data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
-    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', legend=True)
+    palette = sns.color_palette()[len(palette):data['Metric'].unique().shape[0] + len(palette)]
+    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette, ax=ax2)
 
-    ax1.set(yscale='log')
+    ax1.set(yscale='log', ylabel='Losses')
     ax1.set_title('Training Lineplot')
+    ax2.set(ylabel='Accuracy')
+
+    fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
+    ax1.get_legend().remove()
+    ax2.get_legend().remove()
     plt.tight_layout()
     if debug:
         plt.show()
@@ -155,16 +167,17 @@ if __name__ == '__main__':
     training = False
     plotting = False
     particle_analysis = True
+    as_sparse_network_test = True
 
     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)
 
-    run_path = Path('output') / 'mnist_test_half_size'
+    run_path = Path('output') / 'mnist_self_train_100_NEW_STYLE'
     model_path = run_path / '0000_trained_model.zip'
+    df_store_path = run_path / 'train_store.csv'
 
     if training:
         utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
-
         try:
             dataset = MNIST(str(data_path), transform=utility_transforms)
         except RuntimeError:
@@ -177,7 +190,7 @@ if __name__ == '__main__':
         loss_fn = nn.CrossEntropyLoss()
         optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
 
-        train_store = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
+        train_store = new_train_storage_df()
         for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
             is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
             is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
@@ -187,12 +200,9 @@ if __name__ == '__main__':
             metric = None
             for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
                 if self_train and is_self_train_epoch:
-                    # Zero your gradients for every batch!
-                    optimizer.zero_grad()
-                    combined_self_train_loss = metanet.combined_self_train()
-                    combined_self_train_loss.backward()
-                    # Adjust learning weights
-                    optimizer.step()
+                    self_train_loss = metanet.combined_self_train(optimizer)
+                    step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item())
+                    train_store.loc[train_store.shape[0]] = step_log
 
                 # Zero your gradients for every batch!
                 optimizer.zero_grad()
@@ -206,7 +216,7 @@ if __name__ == '__main__':
                 optimizer.step()
 
                 step_log = dict(Epoch=epoch, Batch=batch,
-                                Metric='BatchLoss', Score=loss.item())
+                                Metric='Task Loss', Score=loss.item())
                 train_store.loc[train_store.shape[0]] = step_log
                 if is_validation_epoch:
                     metric(y.cpu(), batch_y.cpu())
@@ -223,23 +233,39 @@ if __name__ == '__main__':
                 validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                       Metric='Test Accuracy', Score=accuracy.item())
                 train_store.loc[train_store.shape[0]] = validation_log
+            if particle_analysis:
+                counter_dict = defaultdict(lambda: 0)
+                # This returns ID-functions
+                _ = test_for_fixpoints(counter_dict, list(metanet.particles))
+                for key, value in dict(counter_dict).items():
+                    step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
+                    train_store.loc[train_store.shape[0]] = step_log
+            train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists())
+            train_store = new_train_storage_df()
 
         accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
         validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
                               Metric='Test Accuracy', Score=accuracy.item())
         train_store.loc[train_store.shape[0]] = validation_log
-        train_store.to_csv(run_path / 'train_store.csv')
+        train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists())
 
     if plotting:
-        plot_training_result(run_path / 'train_store.csv')
+        plot_training_result(df_store_path)
 
     if particle_analysis:
-        model_path = next(run_path.glob('*.tp'))
+        model_path = next(run_path.glob('*ckpt.tp'))
         latest_model = torch.load(model_path, map_location=DEVICE).eval()
-        analysis_dict = defaultdict(dict)
         counter_dict = defaultdict(lambda: 0)
-        for particle in latest_model.particles:
-            analysis_dict[particle.name]['is_diverged'] = particle.are_weights_diverged()
-        test_for_fixpoints(counter_dict, latest_model.particles)
-
+        _ = test_for_fixpoints(counter_dict, list(latest_model.particles))
+        tqdm.write(str(dict(counter_dict)))
+        zero_ident = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero('identity_func')
+        zero_other = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero('other_func')
+    if as_sparse_network_test:
+        acc_pre = validate(model_path, ratio=1)
+        ident_ckpt = set_checkpoint(zero_ident, model_path.parent, -1, final_model=True)
+        ident_acc_post = validate(ident_ckpt, ratio=1)
+        tqdm.write(f'Zero_ident diff = {abs(ident_acc_post - acc_pre)}')
+        other_ckpt = set_checkpoint(zero_other, model_path.parent, -2, final_model=True)
+        other_acc_post = validate(other_ckpt, ratio=1)
+        tqdm.write(f'Zero_other diff = {abs(other_acc_post - acc_pre)}')
diff --git a/functionalities_test.py b/functionalities_test.py
index efc2c5d..9c351a3 100644
--- a/functionalities_test.py
+++ b/functionalities_test.py
@@ -1,16 +1,13 @@
 import copy
 from typing import Dict, List
 
-import numpy as np
+import torch
+from tqdm import tqdm
+
 from network import Net
 
 
 def is_divergent(network: Net) -> bool:
-    for i in network.input_weight_matrix():
-        weight_value = i[0].item()
-
-        if np.isnan(weight_value).any() or np.isinf(weight_value).any():
-            return True
-    return False
+    return network.input_weight_matrix().isinf().any().item() or network.input_weight_matrix().isnan().any().item()
 
 def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
@@ -19,13 +16,14 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
     target_data = network.create_target_weights(input_data)
     predicted_values = network(input_data)
 
-    return np.allclose(target_data.detach().numpy(), predicted_values.detach().numpy(),
-                       rtol=0, atol=epsilon)
+
+    return torch.allclose(target_data.detach(), predicted_values.detach(),
+                          rtol=0, atol=epsilon)
 
 
 def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
     target_data = network.create_target_weights(network.input_weight_matrix().detach())
-    result = np.allclose(target_data, np.zeros_like(target_data), rtol=0, atol=epsilon)
+    result = torch.allclose(target_data, torch.zeros_like(target_data), rtol=0, atol=epsilon)
     # result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
     return result
 
@@ -49,15 +47,15 @@ def is_secondary_fixpoint(network: Net, epsilon: float = pow(10, -5)) -> bool:
     second_output = network(input_data_2)
 
     # Perform the Check: all(epsilon > abs(input_data - second_output))
-    check_abs_within_epsilon = np.allclose(target_data.detach().numpy(), second_output.detach().numpy(),
-                                           rtol=0, atol=epsilon)
+    check_abs_within_epsilon = torch.allclose(target_data.detach(), second_output.detach(),
+                                              rtol=0, atol=epsilon)
     return check_abs_within_epsilon
 
 
 def test_for_fixpoints(fixpoint_counter: Dict, nets: List, id_functions=None):
     id_functions = id_functions or list()
 
-    for net in nets:
+    for net in tqdm(nets, desc='Fixpoint Tester', total=len(nets)):
         if is_divergent(net):
             fixpoint_counter["divergent"] += 1
             net.is_fixpoint = "divergent"
diff --git a/network.py b/network.py
index a27b64a..8f82333 100644
--- a/network.py
+++ b/network.py
@@ -9,6 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import optim, Tensor
+from tqdm import tqdm
 
 
 def prng():
@@ -391,6 +392,17 @@ class MetaNet(nn.Module):
                                             interface=self.width, width=self.out)
                                    )
 
+    def replace_with_zero(self, ident_key):
+        replaced_particles = 0
+        for particle in self.particles:
+            if particle.is_fixpoint == ident_key:
+                particle.load_state_dict(
+                    {key: torch.zeros_like(state) for key, state in particle.state_dict().items()}
+                )
+                replaced_particles += 1
+        tqdm.write(f'Particle Parameters replaced: {replaced_particles}')
+        return self
+
     def forward(self, x):
         tensor = x
         for meta_layer in self._meta_layer_list:
@@ -401,15 +413,22 @@ class MetaNet(nn.Module):
     def particles(self):
         return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)
 
-    def combined_self_train(self):
+    def combined_self_train(self, external_optimizer):
         losses = []
         for particle in self.particles:
+            # Zero your gradients for every batch!
+            external_optimizer.zero_grad()  # Integrate optimizer and backward function
             input_data = particle.input_weight_matrix()
             target_data = particle.create_target_weights(input_data)
             output = particle(input_data)
-            losses.append(F.mse_loss(output, target_data))
-        return torch.hstack(losses).sum(dim=-1, keepdim=True)
+            loss = F.mse_loss(output, target_data)
+            losses.append(loss.detach())
+            loss.backward()
+            # Adjust learning weights
+            external_optimizer.step()
+        # return torch.hstack(losses).sum(dim=-1, keepdim=True)
+        return sum(losses)
 
 
 if __name__ == '__main__':
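
Usage note (illustrative, not part of the applied diff): combined_self_train() now owns
the whole optimizer interaction, so a caller only passes the shared optimizer and reads
back the summed, detached per-particle MSE. A minimal sketch, assuming a final checkpoint
written by this run; the path below follows the set_checkpoint() naming with the run_path
and EPOCH values above and may differ for your output directory:

    import torch

    # Hypothetical checkpoint path, derived from run_path and EPOCH in this patch.
    ckpt = 'output/mnist_self_train_100_NEW_STYLE/trained_model_ckpt_e100.tp'
    metanet = torch.load(ckpt, map_location='cpu')
    optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)

    # One self-train pass: each particle fits its own weights as target;
    # gradients are zeroed, backpropagated and stepped inside the call,
    # so no external backward()/step() is needed.
    self_train_loss = metanet.combined_self_train(optimizer)
    print(self_train_loss.item())  # sum of detached per-particle MSE losses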