diff --git a/README.md b/README.md
index 39d4967..30fd865 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ Data Exchange: [Google Drive Folder](***REMOVED***)
 
 ### Fixpoint Tests:
 
-- [ ] Dropout Test
+- [X] Dropout Test
   - (Does the particle take part in the goal task, or is it only SRN?)
   - Zero_ident diff = -00.04999637603759766 %
@@ -29,6 +29,8 @@ Data Exchange: [Google Drive Folder](***REMOVED***)
   - does this already exist?
   - Hypernetwork?
     - arxiv: 1905.02898
+  - Sparse Networks
+    - Pruning
 
 ---
 
@@ -42,6 +44,16 @@ Data Exchange: [Google Drive Folder](***REMOVED***)
 | ![](./figures/sanity/sanity_3hidden_xtimesn.png) | ![](./figures/sanity/sanity_4hidden_xtimesn.png) |
 | SRNN x*n 6 Neurons Other_Func | SRNN x*n 10 Neurons Other_Func |
 | ![](./figures/sanity/sanity_6hidden_xtimesn.png) | ![](./figures/sanity/sanity_10hidden_xtimesn.png) |
+
+- [ ] Connectivity
+  - The network really does thin itself out.
+
+  |||
+  |---------------------------------------------------|----------------------------------------------------|
+  | 200 Epochs - 4 Neurons - α = 100 - RES | |
+  | ![](./figures/connectivity/training_lineplot.png) | ![](./figures/connectivity/training_particle_type_lp.png) |
+  | OTHER FUNCTIONS | IDENTITY FUNCTIONS |
+  | ![](./figures/connectivity/other.png) | ![](./figures/connectivity/identity.png) |
 
 - [ ] Training with smaller GNs
 
@@ -59,6 +71,7 @@ Data Exchange: [Google Drive Folder](***REMOVED***)
 - [ ] Test against a baseline dense network
   - [ ] with comparable neuron count
   - [ ] with equal total weight count
+  - [ ] Task/Goal instead of the SRNN task
 
 ---
 
diff --git a/as_line_plot.py b/as_line_plot.py
new file mode 100644
index 0000000..37a5c80
--- /dev/null
+++ b/as_line_plot.py
@@ -0,0 +1,44 @@
+import numpy as np
+import torch
+import pandas as pd
+import re
+from pathlib import Path
+import seaborn as sns
+from matplotlib import pyplot as plt
+from network import FixTypes
+
+
+if __name__ == '__main__':
+    p = Path(r'experiments\output\mn_st_200_4_alpha_100\trained_model_ckpt_e200.tp')
+    m = torch.load(p, map_location=torch.device('cpu'))
+    particles = [y for x in m._meta_layer_list for y in x.particles]
+    df = pd.DataFrame(columns=['type', 'layer', 'neuron', 'name', 'color'])
+    colors = []
+
+    for particle in particles:
+        # Particle names encode (layer, cell, weight) indices as digit groups.
+        l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", particle.name).split('_')]
+
+        color = sns.color_palette()[0 if particle.is_fixpoint == FixTypes.identity_func else 1]
+        # color = 'orange' if particle.is_fixpoint == FixTypes.identity_func else 'blue'
+        colors.append(color)
+        # Two rows per particle: its input position (layer-1, weight index) and
+        # its own position (layer, cell index) -> one line segment per particle.
+        df.loc[df.shape[0]] = (particle.is_fixpoint, l-1, w, particle.name, color)
+        df.loc[df.shape[0]] = (particle.is_fixpoint, l, c, particle.name, color)
+    # Normalise neuron positions to [0, 1] per layer for plotting.
+    for layer in list(df['layer'].unique()):
+        divisor = df.loc[(df['layer'] == layer), 'neuron'].max()
+        df.loc[(df['layer'] == layer), 'neuron'] /= divisor
+
+    print('gathered')
+    for n, (fixtype, color) in enumerate(zip([FixTypes.other_func, FixTypes.identity_func], ['blue', 'orange'])):
+        plt.clf()
+        ax = sns.lineplot(y='neuron', x='layer', hue='name', data=df[df['type'] == fixtype],
+                          legend=False, estimator=None,
+                          palette=[sns.color_palette()[n]] * (df[df['type'] == fixtype].shape[0]//2), lw=1)
+        # ax.set(yscale='log', ylabel='Neuron')
+        ax.set_title(fixtype)
+        plt.show()
+        print('plotted')
+
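For reference, a minimal sketch of the name parsing that `as_line_plot.py` relies on. The name below is hypothetical (the real names come from `MetaNet`'s particle naming scheme), and this sketch additionally filters empty fields, which the script's direct three-way unpack assumes never occur.

```python
import re

# Hypothetical particle name; assumed to carry its layer, cell and weight
# indices as the only digit groups in the string.
name = 'particle_l2_c3_w1'

# Keep digits and underscores (as the script's regex does), split on '_',
# drop empty fields, and read the last three as (layer, cell, weight).
fields = [x for x in re.sub('[^0-9_]', '', name).split('_') if x]
layer, cell, weight = (float(x) for x in fields[-3:])
print(layer, cell, weight)  # -> 2.0 3.0 1.0
```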
diff --git a/experiments/meta_task_exp.py b/experiments/meta_task_exp.py
index 3af066e..00b7c57 100644
--- a/experiments/meta_task_exp.py
+++ b/experiments/meta_task_exp.py
@@ -17,6 +17,7 @@ from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor, Compose, Resize
 from tqdm import tqdm
 
+
 if platform.node() == 'CarbonX':
     debug = True
     print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
@@ -36,8 +37,8 @@ else:
     DIR = None
     pass
 
-from network import MetaNet
-from functionalities_test import test_for_fixpoints, FixTypes
+from network import MetaNet, FixTypes
+from functionalities_test import test_for_fixpoints
 
 WORKER = 10 if not debug else 2
 debug = False
@@ -195,13 +196,14 @@ def flat_for_store(parameters):
 
 if __name__ == '__main__':
     self_train = True
-    training = True
+    training = False
     plotting = True
     particle_analysis = True
     as_sparse_network_test = True
-    self_train_alpha = 1
+    train_to_id_first = False
+    self_train_alpha = 100
     batch_train_beta = 1
-    weight_hidden_size = 5
+    weight_hidden_size = 4
     residual_skip = True
     dropout = 0
 
@@ -209,9 +211,11 @@ if __name__ == '__main__':
     data_path.mkdir(exist_ok=True, parents=True)
 
     st_str = f'{"" if self_train else "no_"}st'
+    a_str = f'_alpha_{self_train_alpha}' if self_train_alpha != 1 else ''
     res_str = f'{"" if residual_skip else "_no"}_res'
     dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
-    run_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{res_str}{dr_str}'
+    id_str = f'{f"_StToId" if train_to_id_first else ""}'
+    run_path = Path('output') / f'mn_{st_str}_{EPOCH}_{weight_hidden_size}{a_str}{res_str}{dr_str}{id_str}'
 
     model_path = run_path / '0000_trained_model.zip'
     df_store_path = run_path / 'train_store.csv'
@@ -245,8 +249,9 @@ if __name__ == '__main__':
                 metric = torchmetrics.Accuracy()
             else:
                 metric = None
+            init_st = train_to_id_first and all(x.is_fixpoint == FixTypes.identity_func for x in metanet.particles)
             for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
-                if self_train and is_self_train_epoch:
+                if (self_train and is_self_train_epoch) or init_st:
                     # Zero your gradients for every batch!
                     optimizer.zero_grad()
                     self_train_loss = metanet.combined_self_train() * self_train_alpha
@@ -255,44 +260,46 @@
                     optimizer.step()
                     step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item())
                     train_store.loc[train_store.shape[0]] = step_log
+                if train_to_id_first <= epoch:
+                    # Zero your gradients for every batch!
+                    optimizer.zero_grad()
+                    batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
+                    y = metanet(batch_x)
+                    # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
+                    loss = loss_fn(y, batch_y.to(torch.long)) * batch_train_beta
+                    loss.backward()
 
-                # Zero your gradients for every batch!
-                optimizer.zero_grad()
-                batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-                y = metanet(batch_x)
-                # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
-                loss = loss_fn(y, batch_y.to(torch.long)) * batch_train_beta
-                loss.backward()
+                    # Adjust learning weights
+                    optimizer.step()
 
-                # Adjust learning weights
-                optimizer.step()
-
-                step_log = dict(Epoch=epoch, Batch=batch,
-                                Metric='Task Loss', Score=loss.item())
-                train_store.loc[train_store.shape[0]] = step_log
-                if is_validation_epoch:
-                    metric(y.cpu(), batch_y.cpu())
+                    step_log = dict(Epoch=epoch, Batch=batch,
+                                    Metric='Task Loss', Score=loss.item())
+                    train_store.loc[train_store.shape[0]] = step_log
+                    if is_validation_epoch:
+                        metric(y.cpu(), batch_y.cpu())
                 if batch >= 3 and debug:
                     break
 
             if is_validation_epoch:
                 metanet = metanet.eval()
-                validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
-                                      Metric='Train Accuracy', Score=metric.compute().item())
-                train_store.loc[train_store.shape[0]] = validation_log
+                if train_to_id_first <= epoch:
+                    validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
+                                          Metric='Train Accuracy', Score=metric.compute().item())
+                    train_store.loc[train_store.shape[0]] = validation_log
 
                 accuracy = checkpoint_and_validate(metanet, run_path, epoch)
                 validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                       Metric='Test Accuracy', Score=accuracy.item())
                 train_store.loc[train_store.shape[0]] = validation_log
-            if particle_analysis:
-                counter_dict = defaultdict(lambda: 0)
-                # This returns ID-functions
-                _ = test_for_fixpoints(counter_dict, list(metanet.particles))
-                for key, value in dict(counter_dict).items():
-                    step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
-                    train_store.loc[train_store.shape[0]] = step_log
+            if particle_analysis and (init_st or is_validation_epoch):
+                counter_dict = defaultdict(lambda: 0)
+                # This returns ID-functions
+                _ = test_for_fixpoints(counter_dict, list(metanet.particles))
+                for key, value in dict(counter_dict).items():
+                    step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
+                    train_store.loc[train_store.shape[0]] = step_log
+            if init_st or is_validation_epoch:
                 for particle in metanet.particles:
                     weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
                     weight_store.loc[weight_store.shape[0]] = weight_log
 
@@ -355,7 +362,7 @@ if __name__ == '__main__':
         fig, ax = plt.subplots(ncols=2)
         labels = ['Full Network', 'Sparse, No Identity', 'Sparse, No Other']
         colors = sns.color_palette()[:diff_df.shape[0]] if diff_df.shape[0] >= 2 else sns.color_palette()[0]
-        barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', color=colors, ax=ax[0])
+        barplot = sns.barplot(data=diff_df, y='Accuracy', x='Particle Type', palette=colors, ax=ax[0])
         # noinspection PyUnboundLocalVariable
         for idx, patch in enumerate(barplot.patches):
             if idx != 0:
@@ -366,7 +373,7 @@ if __name__ == '__main__':
         ax[0].set_xlabel('Accuracy')
         # ax[0].legend()
 
-        ax[1].pie(counter_dict.values(), labels=counter_dict.keys(), colors=sns.color_palette()[:3], )
+        ax[1].pie(counter_dict.values(), labels=counter_dict.keys(), colors=colors, )
         ax[1].set_title('Particle Count for ')
         # ax[1].set_xlabel('')
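The schedule that `train_to_id_first` introduces, reduced to a runnable toy. This is a sketch only: the linear model and the weight-shrinking loss stand in for `MetaNet` and `combined_self_train()`. Note that the script's default `train_to_id_first = False` compares as `0`, so the `train_to_id_first <= epoch` gate lets task training start immediately.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the objects meta_task_exp.py builds.
model = nn.Linear(4, 2, bias=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
train_to_id_first = 2            # epoch threshold; False would behave like 0
self_train_alpha, batch_train_beta = 100, 1

for epoch in range(4):
    for _ in range(8):           # batches
        x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))

        # Phase 1: self-train step (dummy loss standing in for
        # metanet.combined_self_train()).
        optimizer.zero_grad()
        st_loss = model.weight.pow(2).sum() * self_train_alpha
        st_loss.backward()
        optimizer.step()

        # Phase 2: task loss, only once the threshold epoch is reached.
        if train_to_id_first <= epoch:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y) * batch_train_beta
            loss.backward()
            optimizer.step()
```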
diff --git a/figures/connectivity/identity.png b/figures/connectivity/identity.png
new file mode 100644
index 0000000..0745a92
Binary files /dev/null and b/figures/connectivity/identity.png differ
diff --git a/figures/connectivity/other.png b/figures/connectivity/other.png
new file mode 100644
index 0000000..b80ba17
Binary files /dev/null and b/figures/connectivity/other.png differ
diff --git a/figures/connectivity/training_lineplot.png b/figures/connectivity/training_lineplot.png
new file mode 100644
index 0000000..9c07b36
Binary files /dev/null and b/figures/connectivity/training_lineplot.png differ
diff --git a/figures/connectivity/training_particle_type_lp.png b/figures/connectivity/training_particle_type_lp.png
new file mode 100644
index 0000000..08f79c1
Binary files /dev/null and b/figures/connectivity/training_particle_type_lp.png differ
diff --git a/functionalities_test.py b/functionalities_test.py
index 093fca2..7165e30 100644
--- a/functionalities_test.py
+++ b/functionalities_test.py
@@ -3,20 +3,7 @@ from typing import Dict, List
 import torch
 from tqdm import tqdm
 
-from network import Net
-
-
-class FixTypes:
-
-    divergent = 'divergent'
-    fix_zero = 'fix_zero'
-    identity_func = 'identity_func'
-    fix_sec = 'fix_sec'
-    other_func = 'other_func'
-
-    @classmethod
-    def all_types(cls):
-        return [val for key, val in cls.__dict__.items() if isinstance(val, str) and not key.startswith('_')]
+from network import FixTypes, Net
 
 
 def is_divergent(network: Net) -> bool:
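Moving `FixTypes` into `network.py` reverses the old import direction (`functionalities_test` now pulls it in alongside `Net`), so both modules can share the labels without a circular import. A quick check of the relocated helper:

```python
from network import FixTypes

# all_types() collects every string-valued class attribute whose name does
# not start with an underscore, i.e. exactly the five fixpoint labels.
print(FixTypes.all_types())
# ['divergent', 'fix_zero', 'identity_func', 'fix_sec', 'other_func']
assert FixTypes.identity_func in FixTypes.all_types()
```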
diff --git a/helpers.py b/helpers.py
new file mode 100644
index 0000000..e69de29
diff --git a/network.py b/network.py
index f1a0b88..3cf0215 100644
--- a/network.py
+++ b/network.py
@@ -15,6 +15,18 @@ from tqdm import tqdm
 def prng():
     return random.random()
 
+
+class FixTypes:
+
+    divergent = 'divergent'
+    fix_zero = 'fix_zero'
+    identity_func = 'identity_func'
+    fix_sec = 'fix_sec'
+    other_func = 'other_func'
+
+    @classmethod
+    def all_types(cls):
+        return [val for key, val in cls.__dict__.items() if isinstance(val, str) and not key.startswith('_')]
 
 
 class Net(nn.Module):
 
@@ -79,7 +91,7 @@ class Net(nn.Module):
         self.trained = False
         self.number_trained = 0
 
-        self.is_fixpoint = ""
+        self.is_fixpoint = FixTypes.other_func
 
         self.layers = nn.ModuleList(
             [nn.Linear(i_size, h_size, False), nn.Linear(h_size, h_size, False),
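For the "Sparse Networks / Pruning" item on the README list (and the `as_sparse_network_test` flag), PyTorch's built-in magnitude pruning would be one possible starting point. This is a sketch of the general technique, not something this patch implements:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 4, bias=False)

# Zero out the 50% smallest-magnitude weights (unstructured L1 pruning).
prune.l1_unstructured(layer, name='weight', amount=0.5)
prune.remove(layer, 'weight')   # bake the mask into the weight tensor

print((layer.weight == 0).float().mean().item())  # -> 0.5
```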