Compare commits
46 Commits
Author | SHA1 | Date | |
---|---|---|---|
38a4af2baa | |||
0ba3994325 | |||
dd2458da4a | |||
ce5a36c8f4 | |||
b3d4987cb8 | |||
f3ff4c9239 | |||
69c904e156 | |||
b6c8859081 | |||
c52b398819 | |||
16c08d04d4 | |||
e167cc78c5 | |||
926b27b4ef | |||
78a919395b | |||
a3a587476c | |||
c0db8e19a3 | |||
e1a5383c04 | |||
9d8496a725 | |||
5b2b5b5beb | |||
3da00c793b | |||
ebf133414c | |||
0bc3b62340 | |||
f0ad875e79 | |||
bb12176f72 | |||
78e9c4d520 | |||
7166fa86ca | |||
f033a2448e | |||
2a710b40d7 | |||
f25cee5203 | |||
52081d176e | |||
5aea4f9f55 | |||
4f99251e68 | |||
28b5e85697 | |||
a4d1ee86dd | |||
62e640e1f0 | |||
8546cc7ddf | |||
14768ffc0a | |||
7ae3e96ec9 | |||
594bbaa3dd | |||
d4c25872c6 | |||
3746c7d26e | |||
df182d652a | |||
21c3e75177 | |||
382afa8642 | |||
6b1efd0c49 | |||
3fe4f49bca | |||
eb3b9b8958 |
108
README.md
@ -1,52 +1,84 @@
|
||||
# self-rep NN paper - ALIFE journal edition
|
||||
# Bureaucratic Cohort Swarms
|
||||
### Pruning Networks by SRNN
|
||||
###### Deadline: 28.02.22
|
||||
|
||||
- [x] Plateau / Pillar sizeWhat does happen to the fixpoints after noise introduction and retraining?Options beeing: Same Fixpoint, Similar Fixpoint (Basin),
|
||||
- Different Fixpoint?
|
||||
Yes, we did not found same (10-5)
|
||||
- Do they do the clustering thingy?
|
||||
Kind of: Small movement towards (MIM-Distance getting smaller) parent fixpoint.
|
||||
Small movement for everyone? -> Distribution
|
||||
## Experimente
|
||||
|
||||
- see `journal_basins.py` for the "train -> spawn with noise -> train again and see where they end up" functionality. Apply noise follows the `vary` function that was used in the paper robustness test with `+- prng() * eps`. Change if desired.
|
||||
### Fixpoint Tests:
|
||||
|
||||
- there is also a distance matrix for all-to-all particle comparisons (with distance parameter one of: `MSE`, `MAE` (mean absolute error = mean manhattan) and `MIM` (mean position invariant manhattan))
|
||||
- [X] Dropout Test
|
||||
- (Macht das Partikel beim Goal mit oder ist es nur SRN)
|
||||
- Zero_ident diff = -00.04999637603759766 %
|
||||
|
||||
- [ ] gnf(1) -> Aprox. Weight
|
||||
- Übersetung in ein Gewichtsskalar
|
||||
- Einbettung in ein Reguläres Netz
|
||||
|
||||
- [ ] Same Thing with Soup interaction. We would expect the same behaviour...Influence of interaction with near and far away particles.
|
||||
-
|
||||
-
|
||||
- [ ] Übersetzung in ein Explainable AI Framework
|
||||
- Rückschlüsse auf Mikro Netze
|
||||
|
||||
- [x] Robustness test with a trained NetworkTraining for high quality fixpoints, compare with the "perfect" fixpoint. Average Loss per application step
|
||||
- [ ] Visualiserung
|
||||
- Der Zugehörigkeit
|
||||
- Der Vernetzung
|
||||
|
||||
- see `journal_robustness.py` for robustness test modeled after cristians robustness-exp (with the exeption that we put noise on the weights). Has `synthetic` bool to switch to hand-modeled perfect fixpoint instead of naturally trained ones.
|
||||
- [ ] PCA()
|
||||
- Dataframe Epoch, Weight, dim_1, ..., dim_n
|
||||
- Visualisierung als Trajectory Cube
|
||||
|
||||
- Also added two difference between the "time-as-fixpoint" and "time-to-verge" (i.e. to divergence / zero).
|
||||
|
||||
- We might need to consult about the "average loss per application step", as I think application loss get gradually higher the worse the weights get. So the average might not tell us much here.
|
||||
|
||||
- [x] Adjust Self Training so that it favors second order fixpoints-> Second order test implementation (?)
|
||||
|
||||
- [x] Barplot over clones -> how many become a fixpoint cs how many diverge per noise level
|
||||
|
||||
- [x] Box-Plot of Avg. Distance of clones from parent
|
||||
|
||||
- [x] Search subspace between two fixpoints by linage(10**-5), check were they end up
|
||||
|
||||
- [x] How are basins / "attractor areas" shaped?
|
||||
|
||||
|
||||
# Future Todos:
|
||||
|
||||
- [ ] Find a statistik over weight space that provides a better init function
|
||||
- [ ] Test this init function on a mnist classifier - just for the lolz
|
||||
- [ ] Recherche zu Makro Mikro Netze Strukturen
|
||||
- gits das schon?
|
||||
- Hypernetwork?
|
||||
- arxiv: 1905.02898
|
||||
- Sparse Networks
|
||||
- Pruning
|
||||
|
||||
---
|
||||
## Notes:
|
||||
|
||||
- In the spawn-experiment we now fit and transform the PCA over *ALL* trajectories, instead of each net-history by its own. This can be toggled by the `plot_pca_together` parameter in `visualisation.py/plot_3d_self_train() & plot_3d()` (default: `False` but set `True` in the spawn-experiment class).
|
||||
### Tasks für Steffen:
|
||||
- [x] Sanity Check:
|
||||
|
||||
- I have also added a `start_time` property for the nets (default: `1`). This is intended to be set flexibly for e.g., clones (when they are spawned midway through the experiment), such that the PCA can start the plotting trace from this timestep. When we spawn clones we deepcopy their parent's saved weight_history too, so that the PCA transforms same lenght trajectories. With `plot_pca_together` that means that clones and their parents will literally be plotted perfectly overlayed on top, up until the spawn-time, where you can see the offset / noise we apply. By setting the start_time, you can avoid this overlap and avoid hiding the parent's trace color which gets plotted first (because the parent is always added to self.nets first). **But more importantly, you can effectively zoom into the plot, by setting the parents start-time to just shy of the end of first epoch (where they get checked on fixpoint-property and spawn clones) and the start-times of clones to the second epoch. This will make the plot begin at spawn time, cutting off the parents initial trajectory and zoom-in to the action (see. `journal_basins.py/spawn_and_continue()`).**
|
||||
- [x] Neuronen können lernen einen Eingabewert mit x zu multiplizieren?
|
||||
|
||||
- Now saving the whole experiment class as pickle dump (`experiment_pickle.p`, just like cristian), hope thats fine.
|
||||
| SRNN x*n 3 Neurons Identity_Func | SRNN x*n 4 Neurons Identity_Func |
|
||||
|---------------------------------------------------|----------------------------------------------------|
|
||||
|  |  |
|
||||
| SRNN x*n 6 Neurons Other_Func | SRNN x*n 10 Neurons Other_Func |
|
||||
|  |  |
|
||||
|
||||
- [ ] Connectivity
|
||||
- Das Netz dünnt sich wirklich aus.
|
||||
|
||||
|||
|
||||
|---------------------------------------------------|----------------------------------------------------|
|
||||
| 200 Epochs - 4 Neurons - \alpha 100 RES | |
|
||||
|  |  |
|
||||
| OTHER FUNTIONS | IDENTITY FUNCTIONS |
|
||||
|  |  |
|
||||
|
||||
- [ ] Training mit kleineren GNs
|
||||
|
||||
|
||||
- [ ] Weiter Trainieren -> 500 Epochs?
|
||||
- [x] Training ohne Residual Skip Connection
|
||||
- Ist anders:
|
||||
Self Training wird zunächst priorisiert, dann kommt langsam der eigentliche Task durch:
|
||||
|
||||
| No Residual Skip connections 8 Neurons in SRNN Alpha=100 | Residual Skip connections 8 Neurons in SRNN Alpha=100 |
|
||||
|------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------|
|
||||
|  |  |
|
||||
|  |  |
|
||||
|
||||
- [ ] Test mit Baseline Dense Network
|
||||
- [ ] mit vergleichbaren Neuron Count
|
||||
- [ ] mit gesamt Weight Count
|
||||
|
||||
- [ ] Task/Goal statt SRNN-Task
|
||||
|
||||
---
|
||||
|
||||
### Für Menschen mit zu viel Zeit:
|
||||
- [ ] Sparse Network Training der Self Replication
|
||||
- Just for the lulz and speeeeeeed)
|
||||
- (Spaß bei Seite, wäre wichtig für schnellere Forschung)
|
||||
<https://pytorch.org/docs/stable/sparse.html>
|
||||
|
||||
- Added a `requirement.txt` for quick venv / pip -r installs. Append as necessary.
|
||||
|
@ -1,6 +0,0 @@
|
||||
from .mixed_setting_exp import run_mixed_experiment
|
||||
from .robustness_exp import run_robustness_experiment
|
||||
from .self_application_exp import run_SA_experiment
|
||||
from .self_train_exp import run_ST_experiment
|
||||
from .soup_exp import run_soup_experiment
|
||||
import functionalities_test
|
@ -1,271 +0,0 @@
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import platform
|
||||
|
||||
import pandas as pd
|
||||
import torchmetrics
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
import seaborn as sns
|
||||
from torch import nn
|
||||
from torch.nn import Flatten
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision.transforms import ToTensor, Compose, Resize
|
||||
from tqdm import tqdm
|
||||
|
||||
if platform.node() == 'CarbonX':
|
||||
debug = True
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
print("@ Warning, Debugging Config@!!!!!! @")
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
else:
|
||||
debug = False
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if __package__ is None:
|
||||
DIR = Path(__file__).resolve().parent
|
||||
sys.path.insert(0, str(DIR.parent))
|
||||
__package__ = DIR.name
|
||||
else:
|
||||
DIR = None
|
||||
except NameError:
|
||||
DIR = None
|
||||
pass
|
||||
|
||||
from network import MetaNet
|
||||
from functionalities_test import test_for_fixpoints
|
||||
|
||||
WORKER = 10 if not debug else 2
|
||||
BATCHSIZE = 500 if not debug else 50
|
||||
EPOCH = 100 if not debug else 3
|
||||
VALIDATION_FRQ = 5 if not debug else 1
|
||||
SELF_TRAIN_FRQ = 1 if not debug else 1
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
if debug:
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
|
||||
class ToFloat:
|
||||
|
||||
def __call__(self, x):
|
||||
return x.to(torch.float32)
|
||||
|
||||
|
||||
class AddTaskDataset(Dataset):
|
||||
def __init__(self, length=int(5e5)):
|
||||
super().__init__()
|
||||
self.length = length
|
||||
self.prng = np.random.default_rng()
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, _):
|
||||
ab = self.prng.normal(size=(2,)).astype(np.float32)
|
||||
return ab, ab.sum(axis=-1, keepdims=True)
|
||||
|
||||
|
||||
def set_checkpoint(model, out_path, epoch_n, final_model=False):
|
||||
epoch_n = str(epoch_n)
|
||||
if not final_model:
|
||||
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
|
||||
else:
|
||||
ckpt_path = Path(out_path) / f'trained_model_ckpt_e{epoch_n}.tp'
|
||||
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||
return ckpt_path
|
||||
|
||||
|
||||
def validate(checkpoint_path, ratio=0.1):
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
import torchmetrics
|
||||
|
||||
# initialize metric
|
||||
validmetric = torchmetrics.Accuracy()
|
||||
ut = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
|
||||
|
||||
try:
|
||||
datas = MNIST(str(data_path), transform=ut, train=False)
|
||||
except RuntimeError:
|
||||
datas = MNIST(str(data_path), transform=ut, train=False, download=True)
|
||||
valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
|
||||
|
||||
model = torch.load(checkpoint_path, map_location=DEVICE).eval()
|
||||
n_samples = int(len(valid_d) * ratio)
|
||||
|
||||
with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
|
||||
for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
|
||||
valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
|
||||
y_valid = model(valid_batch_x)
|
||||
|
||||
# metric on current batch
|
||||
acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
|
||||
pbar.set_postfix_str(f'Acc: {acc}')
|
||||
pbar.update()
|
||||
if idx == n_samples:
|
||||
break
|
||||
|
||||
# metric on all batches using custom accumulation
|
||||
acc = validmetric.compute()
|
||||
tqdm.write(f"Avg. accuracy on all data: {acc}")
|
||||
return acc
|
||||
|
||||
|
||||
def new_train_storage_df():
|
||||
return pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
|
||||
|
||||
|
||||
def checkpoint_and_validate(model, out_path, epoch_n, final_model=False):
|
||||
out_path = Path(out_path)
|
||||
ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
|
||||
result = validate(ckpt_path)
|
||||
return result
|
||||
|
||||
|
||||
def plot_training_result(path_to_dataframe):
|
||||
# load from Drive
|
||||
df = pd.read_csv(path_to_dataframe, index_col=0)
|
||||
|
||||
# Set up figure
|
||||
fig, ax1 = plt.subplots() # initializes figure and plots
|
||||
ax2 = ax1.twinx() # applies twinx to ax2, which is the second y-axis.
|
||||
|
||||
# plots the first set of data
|
||||
data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
|
||||
palette = sns.color_palette()[0:data.reset_index()['Metric'].unique().shape[0]]
|
||||
sns.lineplot(data=data.groupby(['Epoch', 'Metric']).mean(), x='Epoch', y='Score', hue='Metric',
|
||||
palette=palette, ax=ax1)
|
||||
|
||||
# plots the second set of data
|
||||
data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
|
||||
palette = sns.color_palette()[len(palette):data.reset_index()['Metric'].unique().shape[0] + len(palette)]
|
||||
sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', palette=palette)
|
||||
|
||||
ax1.set(yscale='log', ylabel='Losses')
|
||||
ax1.set_title('Training Lineplot')
|
||||
ax2.set(ylabel='Accuracy')
|
||||
|
||||
fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
|
||||
ax1.get_legend().remove()
|
||||
ax2.get_legend().remove()
|
||||
plt.tight_layout()
|
||||
if debug:
|
||||
plt.show()
|
||||
else:
|
||||
plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'), dpi=300)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
self_train = False
|
||||
training = False
|
||||
plotting = False
|
||||
particle_analysis = True
|
||||
as_sparse_network_test = True
|
||||
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
run_path = Path('output') / 'mnist_self_train_100_NEW_STYLE'
|
||||
model_path = run_path / '0000_trained_model.zip'
|
||||
df_store_path = run_path / 'train_store.csv'
|
||||
|
||||
if training:
|
||||
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])
|
||||
try:
|
||||
dataset = MNIST(str(data_path), transform=utility_transforms)
|
||||
except RuntimeError:
|
||||
dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
|
||||
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
|
||||
|
||||
interface = np.prod(dataset[0][0].shape)
|
||||
metanet = MetaNet(interface, depth=4, width=6, out=10).to(DEVICE).train()
|
||||
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
|
||||
|
||||
train_store = new_train_storage_df()
|
||||
for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
|
||||
is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
|
||||
is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
|
||||
if is_validation_epoch:
|
||||
metric = torchmetrics.Accuracy()
|
||||
else:
|
||||
metric = None
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
|
||||
if self_train and is_self_train_epoch:
|
||||
self_train_loss = metanet.combined_self_train(optimizer)
|
||||
step_log = dict(Epoch=epoch, Batch=batch, Metric='Self Train Loss', Score=self_train_loss.item())
|
||||
train_store.loc[train_store.shape[0]] = step_log
|
||||
|
||||
# Zero your gradients for every batch!
|
||||
optimizer.zero_grad()
|
||||
batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
|
||||
y = metanet(batch_x)
|
||||
# loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
|
||||
loss = loss_fn(y, batch_y.to(torch.long))
|
||||
loss.backward()
|
||||
|
||||
# Adjust learning weights
|
||||
optimizer.step()
|
||||
|
||||
step_log = dict(Epoch=epoch, Batch=batch,
|
||||
Metric='Task Loss', Score=loss.item())
|
||||
train_store.loc[train_store.shape[0]] = step_log
|
||||
if is_validation_epoch:
|
||||
metric(y.cpu(), batch_y.cpu())
|
||||
|
||||
if batch >= 3 and debug:
|
||||
break
|
||||
|
||||
if is_validation_epoch:
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric='Train Accuracy', Score=metric.compute().item())
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
|
||||
accuracy = checkpoint_and_validate(metanet, run_path, epoch)
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric='Test Accuracy', Score=accuracy.item())
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
if particle_analysis:
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
# This returns ID-functions
|
||||
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
|
||||
for key, value in dict(counter_dict).items():
|
||||
step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = step_log
|
||||
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists())
|
||||
train_store = new_train_storage_df()
|
||||
|
||||
accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
|
||||
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
|
||||
Metric='Test Accuracy', Score=accuracy.item())
|
||||
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
train_store.to_csv(df_store_path)
|
||||
|
||||
if plotting:
|
||||
plot_training_result(df_store_path)
|
||||
|
||||
if particle_analysis:
|
||||
model_path = next(run_path.glob('*ckpt.tp'))
|
||||
latest_model = torch.load(model_path, map_location=DEVICE).eval()
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
_ = test_for_fixpoints(counter_dict, list(latest_model.particles))
|
||||
tqdm.write(str(dict(counter_dict)))
|
||||
zero_ident = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero('identity_func')
|
||||
zero_other = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero('other_func')
|
||||
if as_sparse_network_test:
|
||||
acc_pre = validate(model_path, ratio=1)
|
||||
ident_ckpt = set_checkpoint(zero_ident, model_path.parent, -1, final_model=True)
|
||||
ident_acc_post = validate(ident_ckpt, ratio=1)
|
||||
tqdm.write(f'Zero_ident diff = {abs(ident_acc_post-acc_pre)}')
|
||||
other_ckpt = set_checkpoint(zero_other, model_path.parent, -2, final_model=True)
|
||||
other_acc_post = validate(other_ckpt, ratio=1)
|
||||
tqdm.write(f'Zero_other diff = {abs(other_acc_post - acc_pre)}')
|
38
experiments/meta_task_small_utility.py
Normal file
@ -0,0 +1,38 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
class AddTaskDataset(Dataset):
|
||||
def __init__(self, length=int(1e3)):
|
||||
super().__init__()
|
||||
self.length = length
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, _):
|
||||
ab = torch.randn(size=(2,)).to(torch.float32)
|
||||
return ab, ab.sum(axis=-1, keepdims=True)
|
||||
|
||||
|
||||
def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
|
||||
# Zero your gradients for every batch!
|
||||
optimizer.zero_grad()
|
||||
btch_x, btch_y = btch_x.to(DEVICE), btch_y.to(DEVICE)
|
||||
y_prd = model(btch_x)
|
||||
|
||||
loss = loss_func(y_prd, btch_y.to(torch.float))
|
||||
loss.backward()
|
||||
|
||||
# Adjust learning weights
|
||||
optimizer.step()
|
||||
|
||||
stp_log = dict(Metric='Task Loss', Score=loss.item())
|
||||
|
||||
return stp_log, y_prd
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise(NotImplementedError('Get out of here'))
|
405
experiments/meta_task_utility.py
Normal file
@ -0,0 +1,405 @@
|
||||
import pickle
|
||||
import re
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchmetrics
|
||||
from matplotlib import pyplot as plt
|
||||
import seaborn as sns
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from functionalities_test import test_for_fixpoints, FixTypes as ft
|
||||
from sanity_check_weights import test_weights_as_model, extract_weights_from_model
|
||||
|
||||
WORKER = 10
|
||||
BATCHSIZE = 500
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
DATA_PATH = Path('data')
|
||||
DATA_PATH.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
PALETTE = sns.color_palette()
|
||||
PALETTE.insert(0, PALETTE.pop(1)) # Orange First
|
||||
|
||||
FINAL_CHECKPOINT_NAME = f'trained_model_ckpt_FINAL.tp'
|
||||
|
||||
|
||||
class AddGaussianNoise(object):
|
||||
def __init__(self, ratio=1e-4):
|
||||
self.ratio = ratio
|
||||
|
||||
def __call__(self, tensor: torch.Tensor):
|
||||
return tensor + (torch.randn_like(tensor, device=tensor.device) * self.ratio)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(ratio={self.ratio}'
|
||||
|
||||
|
||||
class ToFloat:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, x):
|
||||
return x.to(torch.float32)
|
||||
|
||||
|
||||
class AddTaskDataset(Dataset):
|
||||
def __init__(self, length=int(5e5)):
|
||||
super().__init__()
|
||||
self.length = length
|
||||
self.prng = np.random.default_rng()
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, _):
|
||||
ab = self.prng.normal(size=(2,)).astype(np.float32)
|
||||
return ab, ab.sum(axis=-1, keepdims=True)
|
||||
|
||||
|
||||
def set_checkpoint(model, out_path, epoch_n, final_model=False):
|
||||
if not final_model:
|
||||
epoch_n = str(epoch_n)
|
||||
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
|
||||
else:
|
||||
if isinstance(epoch_n, str):
|
||||
ckpt_path = Path(out_path) / f'{Path(FINAL_CHECKPOINT_NAME).stem}_{epoch_n}.tp'
|
||||
else:
|
||||
ckpt_path = Path(out_path) / FINAL_CHECKPOINT_NAME
|
||||
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||
py_store_path = Path(out_path) / 'exp_py.txt'
|
||||
if not py_store_path.exists():
|
||||
shutil.copy(__file__, py_store_path)
|
||||
return ckpt_path
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def validate(checkpoint_path, valid_loader, metric_class=torchmetrics.Accuracy):
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
|
||||
# initialize metric
|
||||
validmetric = metric_class()
|
||||
model = torch.load(checkpoint_path, map_location=DEVICE).eval()
|
||||
|
||||
with tqdm(total=len(valid_loader), desc='Validation Run: ') as pbar:
|
||||
for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_loader):
|
||||
valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
|
||||
y_valid = model(valid_batch_x)
|
||||
|
||||
# metric on current batch
|
||||
measure = validmetric(y_valid.cpu(), valid_batch_y.cpu())
|
||||
pbar.set_postfix_str(f'Measure: {measure}')
|
||||
pbar.update()
|
||||
|
||||
# metric on all batches using custom accumulation
|
||||
measure = validmetric.compute()
|
||||
tqdm.write(f"Avg. {validmetric._get_name()} on all data: {measure}")
|
||||
return measure
|
||||
|
||||
|
||||
def new_storage_df(identifier, weight_count):
|
||||
if identifier == 'train':
|
||||
return pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
|
||||
elif identifier == 'weights':
|
||||
return pd.DataFrame(columns=['Epoch', 'Weight', *(f'weight_{x}' for x in range(weight_count))])
|
||||
|
||||
|
||||
def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, keep_n=5, final_model=False,
|
||||
validation_metric=torchmetrics.Accuracy):
|
||||
out_path = Path(out_path)
|
||||
ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
|
||||
# Clean up Checkpoints
|
||||
if keep_n > 0:
|
||||
all_ckpts = sorted(list(ckpt_path.parent.iterdir()))
|
||||
while len(all_ckpts) > keep_n:
|
||||
all_ckpts.pop(0).unlink()
|
||||
elif keep_n == 0:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f'"keep_n" cannot be negative, but was: {keep_n}')
|
||||
|
||||
result = validate(ckpt_path, valid_loader, metric_class=validation_metric)
|
||||
return result
|
||||
|
||||
|
||||
def plot_training_particle_types(path_to_dataframe):
|
||||
plt.close('all')
|
||||
plt.clf()
|
||||
# load from Drive
|
||||
df = pd.read_csv(path_to_dataframe, index_col=False).sort_values('Metric')
|
||||
# Set up figure
|
||||
fig, ax = plt.subplots() # initializes figure and plots
|
||||
data = df.loc[df['Metric'].isin(ft.all_types())]
|
||||
fix_types = data['Metric'].unique()
|
||||
data = data.pivot(index='Epoch', columns='Metric', values='Score').reset_index().fillna(0)
|
||||
_ = plt.stackplot(data['Epoch'], *[data[fixtype] for fixtype in fix_types],
|
||||
labels=fix_types.tolist(), colors=PALETTE)
|
||||
|
||||
ax.set(ylabel='Particle Count', xlabel='Epoch')
|
||||
ax.yaxis.get_major_locator().set_params(integer=True)
|
||||
# ax.set_title('Particle Type Count')
|
||||
|
||||
fig.legend(loc="center right", title='Particle Type', bbox_to_anchor=(0.85, 0.5))
|
||||
plt.tight_layout()
|
||||
plt.savefig(Path(path_to_dataframe.parent / 'training_particle_type_lp.png'), dpi=300)
|
||||
|
||||
|
||||
def plot_training_result(path_to_dataframe, metric_name='Accuracy', plot_name=None):
|
||||
plt.clf()
|
||||
# load from Drive
|
||||
df = pd.read_csv(path_to_dataframe, index_col=False).sort_values('Metric')
|
||||
|
||||
# Check if this is a single lineplot or if aggregated
|
||||
group = ['Epoch', 'Metric']
|
||||
if 'Seed' in df.columns:
|
||||
group.append('Seed')
|
||||
|
||||
# Set up figure
|
||||
fig, ax1 = plt.subplots() # initializes figure and plots
|
||||
ax2 = ax1.twinx() # applies twinx to ax2, which is the second y-axis.
|
||||
|
||||
# plots the first set of data
|
||||
data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
|
||||
grouped_for_lineplot = data.groupby(group).mean()
|
||||
palette_len_1 = len(grouped_for_lineplot.droplevel(0).reset_index().Metric.unique())
|
||||
|
||||
sns.lineplot(data=grouped_for_lineplot, x='Epoch', y='Score', hue='Metric',
|
||||
palette=PALETTE[:palette_len_1], ax=ax1, ci='sd')
|
||||
|
||||
# plots the second set of data
|
||||
data = df[(df['Metric'] == f'Test {metric_name}') | (df['Metric'] == f'Train {metric_name}')]
|
||||
palette_len_2 = len(data.Metric.unique())
|
||||
sns.lineplot(data=data, x='Epoch', y='Score', hue='Metric',
|
||||
palette=PALETTE[palette_len_1:palette_len_2+palette_len_1], ci='sd')
|
||||
|
||||
ax1.set(yscale='log', ylabel='Losses')
|
||||
# ax1.set_title('Training Lineplot')
|
||||
ax2.set(ylabel=metric_name)
|
||||
if metric_name != 'Accuracy':
|
||||
ax2.set(yscale='log')
|
||||
|
||||
fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
|
||||
for ax in [ax1, ax2]:
|
||||
if legend := ax.get_legend():
|
||||
legend.remove()
|
||||
plt.tight_layout()
|
||||
plt.savefig(Path(path_to_dataframe.parent / ('training_lineplot.png' if plot_name is None else plot_name)), dpi=300)
|
||||
|
||||
|
||||
def plot_network_connectivity_by_fixtype(path_to_trained_model):
|
||||
m = torch.load(path_to_trained_model, map_location=DEVICE).eval()
|
||||
# noinspection PyProtectedMember
|
||||
particles = list(m.particles)
|
||||
df = pd.DataFrame(columns=['Type', 'Layer', 'Neuron', 'Name'])
|
||||
|
||||
for prtcl in particles:
|
||||
l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", prtcl.name).split('_')]
|
||||
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l-1, w, prtcl.name)
|
||||
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l, c, prtcl.name)
|
||||
for layer in list(df['Layer'].unique()):
|
||||
# Rescale
|
||||
divisor = df.loc[(df['Layer'] == layer), 'Neuron'].max()
|
||||
df.loc[(df['Layer'] == layer), 'Neuron'] /= divisor
|
||||
|
||||
tqdm.write(f'Connectivity Data gathered')
|
||||
df = df.sort_values('Type')
|
||||
n = 0
|
||||
for fixtype in ft.all_types():
|
||||
if df[df['Type'] == fixtype].shape[0] > 0:
|
||||
plt.clf()
|
||||
ax = sns.lineplot(y='Neuron', x='Layer', hue='Name', data=df[df['Type'] == fixtype],
|
||||
legend=False, estimator=None, lw=1)
|
||||
_ = sns.lineplot(y=[0, 1], x=[-1, df['Layer'].max()], legend=False, estimator=None, lw=0)
|
||||
ax.set_title(fixtype)
|
||||
ax.yaxis.get_major_locator().set_params(integer=True)
|
||||
ax.xaxis.get_major_locator().set_params(integer=True)
|
||||
ax.set_ylabel('Normalized Neuron Position (1/n)') # XAXIS Label
|
||||
lines = ax.get_lines()
|
||||
for line in lines:
|
||||
line.set_color(PALETTE[n])
|
||||
plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
|
||||
tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["Type"] == fixtype].shape[0] // 2}')
|
||||
n += 1
|
||||
else:
|
||||
# tqdm.write(f'No Connectivity {fixtype}')
|
||||
pass
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def run_particle_dropout_test(model_path, valid_loader, metric_class=torchmetrics.Accuracy):
|
||||
diff_store_path = model_path.parent / 'diff_store.csv'
|
||||
latest_model = torch.load(model_path, map_location=DEVICE).eval()
|
||||
prtcl_dict = defaultdict(lambda: 0)
|
||||
_ = test_for_fixpoints(prtcl_dict, list(latest_model.particles))
|
||||
tqdm.write(str(dict(prtcl_dict)))
|
||||
diff_df = pd.DataFrame(columns=['Particle Type', metric_class()._get_name(), 'Diff'])
|
||||
|
||||
acc_pre = validate(model_path, valid_loader, metric_class=metric_class).item()
|
||||
diff_df.loc[diff_df.shape[0]] = ('All Organism', acc_pre, 0)
|
||||
|
||||
for fixpoint_type in ft.all_types():
|
||||
new_model = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(fixpoint_type)
|
||||
if [x for x in new_model.particles if x.is_fixpoint == fixpoint_type]:
|
||||
new_ckpt = set_checkpoint(new_model, model_path.parent, fixpoint_type, final_model=True)
|
||||
acc_post = validate(new_ckpt, valid_loader, metric_class=metric_class).item()
|
||||
acc_diff = abs(acc_post - acc_pre)
|
||||
tqdm.write(f'Zero_ident diff = {acc_diff}')
|
||||
diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
|
||||
|
||||
diff_df.to_csv(diff_store_path, mode='w', header=True, index=False)
|
||||
return diff_store_path
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def plot_dropout_stacked_barplot(mdl_path, diff_store_path, metric_class=torchmetrics.Accuracy):
|
||||
metric_name = metric_class()._get_name()
|
||||
diff_df = pd.read_csv(diff_store_path).sort_values('Particle Type')
|
||||
particle_dict = defaultdict(lambda: 0)
|
||||
latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
|
||||
_ = test_for_fixpoints(particle_dict, list(latest_model.particles))
|
||||
particle_dict = dict(particle_dict)
|
||||
sorted_particle_dict = dict(sorted(particle_dict.items()))
|
||||
tqdm.write(str(sorted_particle_dict))
|
||||
plt.clf()
|
||||
fig, ax = plt.subplots(ncols=2)
|
||||
colors = PALETTE.copy()
|
||||
colors.insert(0, colors.pop(-1))
|
||||
palette_len = len(diff_df['Particle Type'].unique())
|
||||
_ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors[:palette_len], ci=None)
|
||||
|
||||
ax[0].set_title(f'{metric_name} after particle dropout')
|
||||
# ax[0].set_xlabel('Particle Type') # XAXIS Label
|
||||
ax[0].set_xticklabels(ax[0].get_xticklabels(), rotation=30)
|
||||
|
||||
ax[1].pie(sorted_particle_dict.values(), labels=sorted_particle_dict.keys(),
|
||||
colors=PALETTE[:len(sorted_particle_dict)])
|
||||
ax[1].set_title('Particle Count')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(Path(diff_store_path.parent / 'dropout_stacked_barplot.png'), dpi=300)
|
||||
|
||||
|
||||
def run_particle_dropout_and_plot(model_path, valid_loader, metric_class=torchmetrics.Accuracy):
|
||||
diff_store_path = run_particle_dropout_test(model_path, valid_loader=valid_loader, metric_class=metric_class)
|
||||
plot_dropout_stacked_barplot(model_path, diff_store_path, metric_class=metric_class)
|
||||
|
||||
|
||||
def flat_for_store(parameters):
|
||||
return (x.item() for y in parameters for x in y.detach().flatten())
|
||||
|
||||
|
||||
def train_self_replication(model, st_stps, **kwargs) -> dict:
|
||||
self_train_loss = model.combined_self_train(st_stps, **kwargs)
|
||||
# noinspection PyUnboundLocalVariable
|
||||
stp_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
|
||||
return stp_log
|
||||
|
||||
|
||||
def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
|
||||
# Zero your gradients for every batch!
|
||||
optimizer.zero_grad()
|
||||
btch_x, btch_y = btch_x.to(DEVICE), btch_y.to(DEVICE)
|
||||
y_prd = model(btch_x)
|
||||
|
||||
loss = loss_func(y_prd, btch_y.to(torch.long))
|
||||
loss.backward()
|
||||
|
||||
# Adjust learning weights
|
||||
optimizer.step()
|
||||
|
||||
stp_log = dict(Metric='Task Loss', Score=loss.item())
|
||||
|
||||
return stp_log, y_prd
|
||||
|
||||
|
||||
def highlight_fixpoints_vs_mnist_mean(mdl_path, dataloader):
|
||||
latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
|
||||
activation_vector = torch.as_tensor([[0, 0, 0, 0, 1]], dtype=torch.float32, device=DEVICE)
|
||||
binary_images = []
|
||||
real_images = []
|
||||
with torch.no_grad():
|
||||
# noinspection PyProtectedMember
|
||||
for cell in latest_model._meta_layer_first.meta_cell_list:
|
||||
cell_image_binary = torch.zeros((len(cell.meta_weight_list)), device=DEVICE)
|
||||
cell_image_real = torch.zeros((len(cell.meta_weight_list)), device=DEVICE)
|
||||
for idx, particle in enumerate(cell.particles):
|
||||
if particle.is_fixpoint == ft.identity_func:
|
||||
cell_image_binary[idx] += 1
|
||||
cell_image_real[idx] = particle(activation_vector).abs().squeeze().item()
|
||||
binary_images.append(cell_image_binary.reshape((15, 15)))
|
||||
real_images.append(cell_image_real.reshape((15, 15)))
|
||||
|
||||
binary_images = torch.stack(binary_images)
|
||||
real_images = torch.stack(real_images)
|
||||
|
||||
binary_image = torch.sum(binary_images, keepdim=True, dim=0)
|
||||
real_image = torch.sum(real_images, keepdim=True, dim=0)
|
||||
|
||||
mnist_images = [x for x, _ in dataloader]
|
||||
mnist_mean = torch.cat(mnist_images).reshape(10000, 15, 15).abs().sum(dim=0)
|
||||
|
||||
fig, axs = plt.subplots(1, 3)
|
||||
|
||||
for idx, (image, title) in enumerate(zip([binary_image, real_image, mnist_mean],
|
||||
["Particle Count", "Particle Value", "MNIST mean"])):
|
||||
img = axs[idx].imshow(image.squeeze().detach().cpu())
|
||||
img.axes.axis('off')
|
||||
img.axes.set_title('Random Noise')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(mdl_path.parent / 'heatmap.png', dpi=300)
|
||||
plt.clf()
|
||||
plt.close('all')
|
||||
|
||||
|
||||
def plot_training_results_over_n_seeds(exp_path, df_train_store_name='train_store.csv', metric_name='Accuracy'):
|
||||
combined_df_store_path = exp_path / f'comb_train_{exp_path.stem[:-1]}n.csv'
|
||||
# noinspection PyUnboundLocalVariable
|
||||
found_train_stores = exp_path.rglob(df_train_store_name)
|
||||
train_dfs = []
|
||||
for found_train_store in found_train_stores:
|
||||
train_store_df = pd.read_csv(found_train_store, index_col=False)
|
||||
train_store_df['Seed'] = int(found_train_store.parent.name)
|
||||
train_dfs.append(train_store_df)
|
||||
combined_train_df = pd.concat(train_dfs)
|
||||
combined_train_df.to_csv(combined_df_store_path, index=False)
|
||||
plot_training_result(combined_df_store_path, metric_name=metric_name,
|
||||
plot_name=f"{combined_df_store_path.stem}.png"
|
||||
)
|
||||
plt.clf()
|
||||
plt.close('all')
|
||||
|
||||
|
||||
def sanity_weight_swap(exp_path, dataloader, metric_class=torchmetrics.Accuracy):
|
||||
# noinspection PyProtectedMember
|
||||
metric_name = metric_class()._get_name()
|
||||
found_models = exp_path.rglob(f'*{FINAL_CHECKPOINT_NAME}')
|
||||
df = pd.DataFrame(columns=['Seed', 'Model', metric_name])
|
||||
for model_idx, found_model in enumerate(found_models):
|
||||
model = torch.load(found_model, map_location=DEVICE).eval()
|
||||
weights = extract_weights_from_model(model)
|
||||
|
||||
results = test_weights_as_model(model, weights, dataloader, metric_class=metric_class)
|
||||
for model_name, measurement in results.items():
|
||||
df.loc[df.shape[0]] = (model_idx, model_name, measurement)
|
||||
df.loc[df.shape[0]] = (model_idx, 'Difference', np.abs(np.subtract(*results.values())))
|
||||
|
||||
df.to_csv(exp_path / 'sanity_weight_swap.csv', index=False)
|
||||
_ = sns.boxplot(data=df, x='Model', y=metric_name)
|
||||
plt.tight_layout()
|
||||
plt.savefig(exp_path / 'sanity_weight_swap.png', dpi=300)
|
||||
plt.clf()
|
||||
plt.close('all')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise NotImplementedError('Test this here!!!')
|
@ -1,177 +0,0 @@
|
||||
import os.path
|
||||
import pickle
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiments.helpers import check_folder, summary_fixpoint_experiment, summary_fixpoint_percentage
|
||||
from functionalities_test import test_for_fixpoints
|
||||
from network import Net
|
||||
from visualization import plot_loss, bar_chart_fixpoints, line_chart_fixpoints
|
||||
from visualization import plot_3d_self_train
|
||||
|
||||
|
||||
class MixedSettingExperiment:
|
||||
def __init__(self, population_size, net_i_size, net_h_size, net_o_size, learning_rate, train_nets,
|
||||
epochs, SA_steps, ST_steps_between_SA, log_step_size, directory_name):
|
||||
super().__init__()
|
||||
self.population_size = population_size
|
||||
|
||||
self.net_input_size = net_i_size
|
||||
self.net_hidden_size = net_h_size
|
||||
self.net_out_size = net_o_size
|
||||
self.net_learning_rate = learning_rate
|
||||
self.train_nets = train_nets
|
||||
self.epochs = epochs
|
||||
self.SA_steps = SA_steps
|
||||
self.ST_steps_between_SA = ST_steps_between_SA
|
||||
self.log_step_size = log_step_size
|
||||
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
|
||||
self.loss_history = []
|
||||
|
||||
self.fixpoint_counters_history = []
|
||||
|
||||
self.directory_name = directory_name
|
||||
os.mkdir(self.directory_name)
|
||||
|
||||
self.nets = []
|
||||
self.populate_environment()
|
||||
|
||||
self.fixpoint_percentage()
|
||||
self.weights_evolution_3d_experiment()
|
||||
self.count_fixpoints()
|
||||
self.visualize_loss()
|
||||
|
||||
def populate_environment(self):
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
for i in loop_population_size:
|
||||
loop_population_size.set_description("Populating mixed experiment %s" % i)
|
||||
|
||||
net_name = f"mixed_net_{str(i)}"
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
self.nets.append(net)
|
||||
|
||||
loop_epochs = tqdm(range(self.epochs))
|
||||
for j in loop_epochs:
|
||||
loop_epochs.set_description("Running mixed experiment %s" % j)
|
||||
|
||||
for i in loop_population_size:
|
||||
net = self.nets[i]
|
||||
|
||||
if self.train_nets == "before_SA":
|
||||
for _ in range(self.ST_steps_between_SA):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
net.self_application(self.SA_steps, self.log_step_size)
|
||||
|
||||
elif self.train_nets == "after_SA":
|
||||
net.self_application(self.SA_steps, self.log_step_size)
|
||||
for _ in range(self.ST_steps_between_SA):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
print(
|
||||
f"\nLast weight matrix (epoch: {j}):\n{net.input_weight_matrix()}\nLossHistory: {net.loss_history[-10:]}")
|
||||
test_for_fixpoints(self.fixpoint_counters, self.nets)
|
||||
# Rounding the result not to run into other problems later regarding the exact representation of floating number
|
||||
fixpoints_percentage = round((self.fixpoint_counters["fix_zero"] + self.fixpoint_counters[
|
||||
"fix_sec"]) / self.population_size, 1)
|
||||
self.fixpoint_counters_history.append(fixpoints_percentage)
|
||||
|
||||
# Resetting the fixpoint counter. Last iteration not to be reset - it is important for the bar_chart_fixpoints().
|
||||
if j < self.epochs:
|
||||
self.reset_fixpoint_counters()
|
||||
|
||||
def weights_evolution_3d_experiment(self):
|
||||
exp_name = f"Mixed {str(len(self.nets))}"
|
||||
|
||||
# This batch size is not relevant for mixed settings because during an epoch there are more steps of SA & ST happening
|
||||
# and only they need the batch size. To not affect the number of epochs shown in the 3D plot, will send
|
||||
# forward the number "1" for batch size with the variable <irrelevant_batch_size>
|
||||
irrelevant_batch_size = 1
|
||||
plot_3d_self_train(self.nets, exp_name, self.directory_name, irrelevant_batch_size, True)
|
||||
|
||||
def count_fixpoints(self):
|
||||
exp_details = f"SA steps: {self.SA_steps}; ST steps: {self.ST_steps_between_SA}"
|
||||
|
||||
test_for_fixpoints(self.fixpoint_counters, self.nets)
|
||||
bar_chart_fixpoints(self.fixpoint_counters, self.population_size, self.directory_name, self.net_learning_rate,
|
||||
exp_details)
|
||||
|
||||
def fixpoint_percentage(self):
|
||||
line_chart_fixpoints(self.fixpoint_counters_history, self.epochs, self.ST_steps_between_SA,
|
||||
self.SA_steps, self.directory_name, self.population_size)
|
||||
|
||||
def visualize_loss(self):
|
||||
for i in range(len(self.nets)):
|
||||
net_loss_history = self.nets[i].loss_history
|
||||
self.loss_history.append(net_loss_history)
|
||||
|
||||
plot_loss(self.loss_history, self.directory_name)
|
||||
|
||||
def reset_fixpoint_counters(self):
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
|
||||
|
||||
def run_mixed_experiment(population_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate, train_nets,
|
||||
epochs, SA_steps, ST_steps_between_SA, batch_size, name_hash, runs, run_name):
|
||||
experiments = {}
|
||||
fixpoints_percentages = []
|
||||
|
||||
check_folder("mixed")
|
||||
|
||||
# Running the experiments
|
||||
for i in range(runs):
|
||||
directory_name = f"experiments/mixed/{run_name}_run_{i}_{str(population_size)}_nets_{SA_steps}_SA_{ST_steps_between_SA}_ST_{str(name_hash)}"
|
||||
|
||||
mixed_experiment = MixedSettingExperiment(
|
||||
population_size,
|
||||
net_input_size,
|
||||
net_hidden_size,
|
||||
net_out_size,
|
||||
net_learning_rate,
|
||||
train_nets,
|
||||
epochs,
|
||||
SA_steps,
|
||||
ST_steps_between_SA,
|
||||
batch_size,
|
||||
directory_name
|
||||
)
|
||||
pickle.dump(mixed_experiment, open(f"{directory_name}/full_experiment_pickle.p", "wb"))
|
||||
experiments[i] = mixed_experiment
|
||||
|
||||
# Building history of fixpoint percentages for summary
|
||||
fixpoint_counters_history = mixed_experiment.fixpoint_counters_history
|
||||
if not fixpoints_percentages:
|
||||
fixpoints_percentages = mixed_experiment.fixpoint_counters_history
|
||||
else:
|
||||
# Using list comprehension to make the sum of all the percentages
|
||||
fixpoints_percentages = [fixpoints_percentages[i] + fixpoint_counters_history[i] for i in
|
||||
range(len(fixpoints_percentages))]
|
||||
|
||||
# Building a summary of all the runs
|
||||
directory_name = f"experiments/mixed/summary_{run_name}_{runs}_runs_{str(population_size)}_nets_{str(name_hash)}"
|
||||
os.mkdir(directory_name)
|
||||
|
||||
summary_pre_title = "mixed"
|
||||
summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, directory_name,
|
||||
summary_pre_title)
|
||||
summary_fixpoint_percentage(runs, epochs, fixpoints_percentages, ST_steps_between_SA, SA_steps, directory_name,
|
||||
population_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise NotImplementedError('Test this here!!!')
|
@ -1,151 +0,0 @@
|
||||
import copy
|
||||
import os.path
|
||||
import pickle
|
||||
import random
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiments.helpers import check_folder, summary_fixpoint_experiment
|
||||
from functionalities_test import test_for_fixpoints, is_identity_function
|
||||
from network import Net
|
||||
from visualization import bar_chart_fixpoints, box_plot, write_file
|
||||
|
||||
|
||||
def add_noise(input_data, epsilon=pow(10, -5)):
|
||||
|
||||
output = copy.deepcopy(input_data)
|
||||
for k in range(len(input_data)):
|
||||
output[k][0] += random.random() * epsilon
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class RobustnessExperiment:
|
||||
def __init__(self, population_size, log_step_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
|
||||
ST_steps, directory_name) -> None:
|
||||
self.population_size = population_size
|
||||
self.log_step_size = log_step_size
|
||||
self.net_input_size = net_input_size
|
||||
self.net_hidden_size = net_hidden_size
|
||||
self.net_out_size = net_out_size
|
||||
|
||||
self.net_learning_rate = net_learning_rate
|
||||
|
||||
self.ST_steps = ST_steps
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
self.id_functions = []
|
||||
|
||||
self.directory_name = directory_name
|
||||
os.mkdir(self.directory_name)
|
||||
|
||||
self.nets = []
|
||||
# Create population:
|
||||
self.populate_environment()
|
||||
print("Nets:\n", self.nets)
|
||||
|
||||
self.count_fixpoints()
|
||||
[print(net.is_fixpoint) for net in self.nets]
|
||||
self.test_robustness()
|
||||
|
||||
def populate_environment(self):
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
for i in loop_population_size:
|
||||
loop_population_size.set_description("Populating robustness experiment %s" % i)
|
||||
|
||||
net_name = f"net_{str(i)}"
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
|
||||
for _ in range(self.ST_steps):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
self.nets.append(net)
|
||||
|
||||
def test_robustness(self):
|
||||
# test_for_fixpoints(self.fixpoint_counters, self.nets, self.id_functions)
|
||||
|
||||
zero_epsilon = pow(10, -5)
|
||||
data = [[0 for _ in range(10)] for _ in range(len(self.id_functions))]
|
||||
|
||||
for i in range(len(self.id_functions)):
|
||||
for j in range(10):
|
||||
original_net = self.id_functions[i]
|
||||
|
||||
# Creating a clone of the network. Not by copying it, but by creating a completely new network
|
||||
# and changing its weights to the original ones.
|
||||
original_net_clone = Net(original_net.input_size, original_net.hidden_size, original_net.out_size,
|
||||
original_net.name)
|
||||
# Extra safety for the value of the weights
|
||||
original_net_clone.load_state_dict(copy.deepcopy(original_net.state_dict()))
|
||||
|
||||
noisy_weights = add_noise(original_net_clone.input_weight_matrix(), epsilon=pow(10, -j))
|
||||
original_net_clone.apply_weights(noisy_weights)
|
||||
|
||||
# Testing if the new net is still an identity function after applying noise
|
||||
still_id_func = is_identity_function(original_net_clone, zero_epsilon)
|
||||
|
||||
# If the net is still an id. func. after applying the first run of noise, continue to apply it until otherwise
|
||||
while still_id_func and data[i][j] <= 1000:
|
||||
data[i][j] += 1
|
||||
|
||||
original_net_clone = original_net_clone.self_application(1, self.log_step_size)
|
||||
|
||||
still_id_func = is_identity_function(original_net_clone, zero_epsilon)
|
||||
|
||||
print(f"Data {data}")
|
||||
|
||||
if data.count(0) == 10:
|
||||
print(f"There is no network resisting the robustness test.")
|
||||
text = f"For this population of \n {self.population_size} networks \n there is no" \
|
||||
f" network resisting the robustness test."
|
||||
write_file(text, self.directory_name)
|
||||
else:
|
||||
box_plot(data, self.directory_name, self.population_size)
|
||||
|
||||
def count_fixpoints(self):
|
||||
exp_details = f"ST steps: {self.ST_steps}"
|
||||
|
||||
self.id_functions = test_for_fixpoints(self.fixpoint_counters, self.nets)
|
||||
bar_chart_fixpoints(self.fixpoint_counters, self.population_size, self.directory_name, self.net_learning_rate,
|
||||
exp_details)
|
||||
|
||||
|
||||
def run_robustness_experiment(population_size, batch_size, net_input_size, net_hidden_size, net_out_size,
|
||||
net_learning_rate, epochs, runs, run_name, name_hash):
|
||||
experiments = {}
|
||||
|
||||
check_folder("robustness")
|
||||
|
||||
# Running the experiments
|
||||
for i in range(runs):
|
||||
ST_directory_name = f"experiments/robustness/{run_name}_run_{i}_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
|
||||
robustness_experiment = RobustnessExperiment(
|
||||
population_size,
|
||||
batch_size,
|
||||
net_input_size,
|
||||
net_hidden_size,
|
||||
net_out_size,
|
||||
net_learning_rate,
|
||||
epochs,
|
||||
ST_directory_name
|
||||
)
|
||||
pickle.dump(robustness_experiment, open(f"{ST_directory_name}/full_experiment_pickle.p", "wb"))
|
||||
experiments[i] = robustness_experiment
|
||||
|
||||
# Building a summary of all the runs
|
||||
directory_name = f"experiments/robustness/summary_{run_name}_{runs}_runs_{str(population_size)}_nets_{str(name_hash)}"
|
||||
os.mkdir(directory_name)
|
||||
|
||||
summary_pre_title = "robustness"
|
||||
summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, directory_name,
|
||||
summary_pre_title)
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise NotImplementedError('Test this here!!!')
|
118
experiments/robustness_tester.py
Normal file
@ -0,0 +1,118 @@
|
||||
import pandas as pd
|
||||
import torch
|
||||
import random
|
||||
import copy
|
||||
|
||||
from tqdm import tqdm
|
||||
from functionalities_test import (is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent,
|
||||
FixTypes as FT)
|
||||
from network import Net
|
||||
from torch.nn import functional as F
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
def prng():
|
||||
return random.random()
|
||||
|
||||
|
||||
def generate_perfekt_synthetic_fixpoint_weights():
|
||||
return torch.tensor([[1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
|
||||
[1.0], [0.0], [0.0], [0.0],
|
||||
[1.0], [0.0]
|
||||
], dtype=torch.float32)
|
||||
|
||||
PALETTE = 10 * (
|
||||
"#377eb8",
|
||||
"#4daf4a",
|
||||
"#984ea3",
|
||||
"#e41a1c",
|
||||
"#ff7f00",
|
||||
"#a65628",
|
||||
"#f781bf",
|
||||
"#888888",
|
||||
"#a6cee3",
|
||||
"#b2df8a",
|
||||
"#cab2d6",
|
||||
"#fb9a99",
|
||||
"#fdbf6f",
|
||||
)
|
||||
|
||||
|
||||
def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
|
||||
model = torch.load(model_path, map_location='cpu')
|
||||
networks = [x for x in model.particles if x.is_fixpoint == FT.identity_func]
|
||||
time_to_vergence = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
|
||||
time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
|
||||
row_headers = []
|
||||
df = pd.DataFrame(columns=['setting', 'Noise Level', 'Self Train Steps', 'absolute_loss',
|
||||
'Time to convergence', 'Time as fixpoint'])
|
||||
with tqdm(total=(seeds * noise_levels * len(networks)), desc='Per Particle Robustness') as pbar:
|
||||
for setting, fixpoint in enumerate(networks): # 1 / n
|
||||
row_headers.append(fixpoint.name)
|
||||
for seed in range(seeds): # n / 1
|
||||
for noise_level in range(noise_levels):
|
||||
steps = 0
|
||||
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
|
||||
f"{fixpoint.name}_clone_noise_1e-{noise_level}")
|
||||
clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
|
||||
clone = clone.apply_noise(pow(10, -noise_level))
|
||||
|
||||
while not is_zero_fixpoint(clone) and not is_divergent(clone):
|
||||
# -> before
|
||||
clone_weight_pre_application = clone.input_weight_matrix()
|
||||
target_data_pre_application = clone.create_target_weights(clone_weight_pre_application)
|
||||
|
||||
clone.self_application(1, log_step_size)
|
||||
time_to_vergence[setting][noise_level] += 1
|
||||
# -> after
|
||||
clone_weight_post_application = clone.input_weight_matrix()
|
||||
target_data_post_application = clone.create_target_weights(clone_weight_post_application)
|
||||
|
||||
absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item()
|
||||
|
||||
if is_identity_function(clone):
|
||||
time_as_fixpoint[setting][noise_level] += 1
|
||||
# When this raises a Type Error, we found a second order fixpoint!
|
||||
steps += 1
|
||||
|
||||
df.loc[df.shape[0]] = [f'{setting}_{seed}', fr'$\mathregular{{10^{{-{noise_level}}}}}$',
|
||||
steps, absolute_loss,
|
||||
time_to_vergence[setting][noise_level],
|
||||
time_as_fixpoint[setting][noise_level]]
|
||||
pbar.update(1)
|
||||
|
||||
# Get the measuremts at the highest time_time_to_vergence
|
||||
df_sorted = df.sort_values('Self Train Steps', ascending=False).drop_duplicates(['setting', 'Noise Level'])
|
||||
df_melted = df_sorted.reset_index().melt(id_vars=['setting', 'Noise Level', 'Self Train Steps'],
|
||||
value_vars=['Time to convergence', 'Time as fixpoint'],
|
||||
var_name="Measurement",
|
||||
value_name="Steps").sort_values('Noise Level')
|
||||
|
||||
df_melted.to_csv(model_path.parent / 'robustness_boxplot.csv', index=False)
|
||||
|
||||
# Plotting
|
||||
# plt.rcParams.update({
|
||||
# "text.usetex": True,
|
||||
# "font.family": "sans-serif",
|
||||
# "font.size": 12,
|
||||
# "font.weight": 'bold',
|
||||
# "font.sans-serif": ["Helvetica"]})
|
||||
plt.clf()
|
||||
sns.set(style='whitegrid', font_scale=1)
|
||||
_ = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE)
|
||||
plt.tight_layout()
|
||||
|
||||
# sns.set(rc={'figure.figsize': (10, 50)})
|
||||
# bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box',
|
||||
# col='noise_level', col_wrap=3, showfliers=False)
|
||||
|
||||
filename = f"robustness_boxplot.png"
|
||||
filepath = model_path.parent / filename
|
||||
plt.savefig(str(filepath))
|
||||
plt.close('all')
|
||||
return time_as_fixpoint, time_to_vergence
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise NotImplementedError('Get out of here!')
|
@ -1,120 +0,0 @@
|
||||
import os.path
|
||||
import pickle
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiments.helpers import check_folder, summary_fixpoint_experiment
|
||||
from functionalities_test import test_for_fixpoints
|
||||
from network import Net
|
||||
from visualization import bar_chart_fixpoints
|
||||
from visualization import plot_3d_self_application
|
||||
|
||||
|
||||
class SelfApplicationExperiment:
|
||||
def __init__(self, population_size, log_step_size, net_input_size, net_hidden_size, net_out_size,
|
||||
net_learning_rate, application_steps, train_nets, directory_name, training_steps
|
||||
) -> None:
|
||||
self.population_size = population_size
|
||||
self.log_step_size = log_step_size
|
||||
self.net_input_size = net_input_size
|
||||
self.net_hidden_size = net_hidden_size
|
||||
self.net_out_size = net_out_size
|
||||
|
||||
self.net_learning_rate = net_learning_rate
|
||||
self.SA_steps = application_steps #
|
||||
|
||||
self.train_nets = train_nets
|
||||
self.ST_steps = training_steps
|
||||
|
||||
self.directory_name = directory_name
|
||||
os.mkdir(self.directory_name)
|
||||
|
||||
""" Creating the nets & making the SA steps & (maybe) also training the networks. """
|
||||
self.nets = []
|
||||
# Create population:
|
||||
self.populate_environment()
|
||||
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
|
||||
self.weights_evolution_3d_experiment()
|
||||
self.count_fixpoints()
|
||||
|
||||
def populate_environment(self):
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
for i in loop_population_size:
|
||||
loop_population_size.set_description("Populating SA experiment %s" % i)
|
||||
|
||||
net_name = f"SA_net_{str(i)}"
|
||||
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name
|
||||
)
|
||||
for _ in range(self.SA_steps):
|
||||
input_data = net.input_weight_matrix()
|
||||
target_data = net.create_target_weights(input_data)
|
||||
|
||||
if self.train_nets == "before_SA":
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
net.self_application(self.SA_steps, self.log_step_size)
|
||||
elif self.train_nets == "after_SA":
|
||||
net.self_application(self.SA_steps, self.log_step_size)
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
else:
|
||||
net.self_application(self.SA_steps, self.log_step_size)
|
||||
|
||||
self.nets.append(net)
|
||||
|
||||
def weights_evolution_3d_experiment(self):
|
||||
exp_name = f"SA_{str(len(self.nets))}_nets_3d_weights_PCA"
|
||||
plot_3d_self_application(self.nets, exp_name, self.directory_name, self.log_step_size)
|
||||
|
||||
def count_fixpoints(self):
|
||||
test_for_fixpoints(self.fixpoint_counters, self.nets)
|
||||
exp_details = f"{self.SA_steps} SA steps"
|
||||
bar_chart_fixpoints(self.fixpoint_counters, self.population_size, self.directory_name, self.net_learning_rate,
|
||||
exp_details)
|
||||
|
||||
|
||||
def run_SA_experiment(population_size, batch_size, net_input_size, net_hidden_size, net_out_size,
|
||||
net_learning_rate, runs, run_name, name_hash, application_steps, train_nets, training_steps):
|
||||
experiments = {}
|
||||
|
||||
check_folder("self_application")
|
||||
|
||||
# Running the experiments
|
||||
for i in range(runs):
|
||||
directory_name = f"experiments/self_application/{run_name}_run_{i}_{str(population_size)}_nets_{application_steps}_SA_{str(name_hash)}"
|
||||
|
||||
SA_experiment = SelfApplicationExperiment(
|
||||
population_size,
|
||||
batch_size,
|
||||
net_input_size,
|
||||
net_hidden_size,
|
||||
net_out_size,
|
||||
net_learning_rate,
|
||||
application_steps,
|
||||
train_nets,
|
||||
directory_name,
|
||||
training_steps
|
||||
)
|
||||
pickle.dump(SA_experiment, open(f"{directory_name}/full_experiment_pickle.p", "wb"))
|
||||
experiments[i] = SA_experiment
|
||||
|
||||
# Building a summary of all the runs
|
||||
directory_name = f"experiments/self_application/summary_{run_name}_{runs}_runs_{str(population_size)}_nets_{application_steps}_SA_{str(name_hash)}"
|
||||
os.mkdir(directory_name)
|
||||
|
||||
summary_pre_title = "SA"
|
||||
summary_fixpoint_experiment(runs, population_size, application_steps, experiments, net_learning_rate,
|
||||
directory_name,
|
||||
summary_pre_title)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise NotImplementedError('Test this here!!!')
|
@ -1,116 +0,0 @@
|
||||
import os.path
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiments.helpers import check_folder, summary_fixpoint_experiment
|
||||
from functionalities_test import test_for_fixpoints
|
||||
from network import Net
|
||||
from visualization import plot_loss, bar_chart_fixpoints
|
||||
from visualization import plot_3d_self_train
|
||||
|
||||
|
||||
|
||||
class SelfTrainExperiment:
|
||||
def __init__(self, population_size, log_step_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
|
||||
epochs, directory_name) -> None:
|
||||
self.population_size = population_size
|
||||
self.log_step_size = log_step_size
|
||||
self.net_input_size = net_input_size
|
||||
self.net_hidden_size = net_hidden_size
|
||||
self.net_out_size = net_out_size
|
||||
|
||||
self.net_learning_rate = net_learning_rate
|
||||
self.epochs = epochs
|
||||
|
||||
self.loss_history = []
|
||||
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
|
||||
self.directory_name = directory_name
|
||||
os.mkdir(self.directory_name)
|
||||
|
||||
self.nets = []
|
||||
# Create population:
|
||||
self.populate_environment()
|
||||
|
||||
self.weights_evolution_3d_experiment()
|
||||
self.count_fixpoints()
|
||||
self.visualize_loss()
|
||||
|
||||
def populate_environment(self):
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
for i in loop_population_size:
|
||||
loop_population_size.set_description("Populating ST experiment %s" % i)
|
||||
|
||||
net_name = f"ST_net_{str(i)}"
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
|
||||
for _ in range(self.epochs):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
print(f"\nLast weight matrix (epoch: {self.epochs}):\n{net.input_weight_matrix()}\nLossHistory: {net.loss_history[-10:]}")
|
||||
self.nets.append(net)
|
||||
|
||||
def weights_evolution_3d_experiment(self):
|
||||
exp_name = f"ST_{str(len(self.nets))}_nets_3d_weights_PCA"
|
||||
return plot_3d_self_train(self.nets, exp_name, self.directory_name, self.log_step_size)
|
||||
|
||||
def count_fixpoints(self):
|
||||
test_for_fixpoints(self.fixpoint_counters, self.nets)
|
||||
exp_details = f"Self-train for {self.epochs} epochs"
|
||||
bar_chart_fixpoints(self.fixpoint_counters, self.population_size, self.directory_name, self.net_learning_rate,
|
||||
exp_details)
|
||||
|
||||
def visualize_loss(self):
|
||||
for i in range(len(self.nets)):
|
||||
net_loss_history = self.nets[i].loss_history
|
||||
self.loss_history.append(net_loss_history)
|
||||
|
||||
plot_loss(self.loss_history, self.directory_name)
|
||||
|
||||
|
||||
def run_ST_experiment(population_size, batch_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
|
||||
epochs, runs, run_name, name_hash):
|
||||
experiments = {}
|
||||
logging_directory = Path('output') / 'self_training'
|
||||
logging_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Running the experiments
|
||||
for i in range(runs):
|
||||
experiment_name = f"{run_name}_run_{i}_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
this_exp_directory = logging_directory / experiment_name
|
||||
ST_experiment = SelfTrainExperiment(
|
||||
population_size,
|
||||
batch_size,
|
||||
net_input_size,
|
||||
net_hidden_size,
|
||||
net_out_size,
|
||||
net_learning_rate,
|
||||
epochs,
|
||||
this_exp_directory
|
||||
)
|
||||
with (this_exp_directory / 'full_experiment_pickle.p').open('wb') as f:
|
||||
pickle.dump(ST_experiment, f)
|
||||
experiments[i] = ST_experiment
|
||||
|
||||
# Building a summary of all the runs
|
||||
summary_name = f"/summary_{run_name}_{runs}_runs_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
summary_directory_name = logging_directory / summary_name
|
||||
summary_directory_name.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
summary_pre_title = "ST"
|
||||
summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, summary_directory_name,
|
||||
summary_pre_title)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise NotImplementedError('Test this here!!!')
|
@ -1,114 +0,0 @@
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiments.helpers import check_folder, summary_fixpoint_experiment
|
||||
from functionalities_test import test_for_fixpoints
|
||||
from network import SecondaryNet
|
||||
from visualization import plot_loss, bar_chart_fixpoints
|
||||
from visualization import plot_3d_self_train
|
||||
|
||||
|
||||
class SelfTrainExperimentSecondary:
|
||||
def __init__(self, population_size, log_step_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
|
||||
epochs, directory: Path) -> None:
|
||||
self.population_size = population_size
|
||||
self.log_step_size = log_step_size
|
||||
self.net_input_size = net_input_size
|
||||
self.net_hidden_size = net_hidden_size
|
||||
self.net_out_size = net_out_size
|
||||
|
||||
self.net_learning_rate = net_learning_rate
|
||||
self.epochs = epochs
|
||||
|
||||
self.loss_history = []
|
||||
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
|
||||
self.directory_name = Path(directory)
|
||||
self.directory_name.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.nets = []
|
||||
# Create population:
|
||||
self.populate_environment()
|
||||
|
||||
self.weights_evolution_3d_experiment()
|
||||
self.count_fixpoints()
|
||||
self.visualize_loss()
|
||||
|
||||
def populate_environment(self):
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
for i in loop_population_size:
|
||||
loop_population_size.set_description("Populating ST experiment %s" % i)
|
||||
|
||||
net_name = f"ST_net_{str(i)}"
|
||||
net = SecondaryNet(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
|
||||
for _ in range(self.epochs):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
print(f"\nLast weight matrix (epoch: {self.epochs}):\n{net.input_weight_matrix()}\nLossHistory: {net.loss_history[-10:]}")
|
||||
self.nets.append(net)
|
||||
|
||||
def weights_evolution_3d_experiment(self):
|
||||
exp_name = f"ST_{str(len(self.nets))}_nets_3d_weights_PCA"
|
||||
return plot_3d_self_train(self.nets, exp_name, self.directory_name, self.log_step_size)
|
||||
|
||||
def count_fixpoints(self):
|
||||
test_for_fixpoints(self.fixpoint_counters, self.nets)
|
||||
exp_details = f"Self-train for {self.epochs} epochs"
|
||||
bar_chart_fixpoints(self.fixpoint_counters, self.population_size, self.directory_name, self.net_learning_rate,
|
||||
exp_details)
|
||||
|
||||
def visualize_loss(self):
|
||||
for i in range(len(self.nets)):
|
||||
net_loss_history = self.nets[i].loss_history
|
||||
self.loss_history.append(net_loss_history)
|
||||
|
||||
plot_loss(self.loss_history, self.directory_name)
|
||||
|
||||
|
||||
def run_ST_experiment(population_size, batch_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
|
||||
epochs, runs, run_name, name_hash):
|
||||
experiments = {}
|
||||
logging_directory = Path('output') / 'self_training'
|
||||
logging_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Running the experiments
|
||||
for i in range(runs):
|
||||
experiment_name = f"{run_name}_run_{i}_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
this_exp_directory = logging_directory / experiment_name
|
||||
ST_experiment = SelfTrainExperimentSecondary(
|
||||
population_size,
|
||||
batch_size,
|
||||
net_input_size,
|
||||
net_hidden_size,
|
||||
net_out_size,
|
||||
net_learning_rate,
|
||||
epochs,
|
||||
this_exp_directory
|
||||
)
|
||||
with (this_exp_directory / 'full_experiment_pickle.p').open('wb') as f:
|
||||
pickle.dump(ST_experiment, f)
|
||||
experiments[i] = ST_experiment
|
||||
|
||||
# Building a summary of all the runs
|
||||
summary_name = f"/summary_{run_name}_{runs}_runs_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
summary_directory_name = logging_directory / summary_name
|
||||
summary_directory_name.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
summary_pre_title = "ST"
|
||||
summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, summary_directory_name,
|
||||
summary_pre_title)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise NotImplementedError('Test this here!!!')
|
@ -1,190 +0,0 @@
|
||||
import random
|
||||
import os.path
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiments.helpers import check_folder, summary_fixpoint_percentage, summary_fixpoint_experiment
|
||||
from functionalities_test import test_for_fixpoints
|
||||
from network import Net
|
||||
from visualization import plot_loss, bar_chart_fixpoints, plot_3d_soup, line_chart_fixpoints
|
||||
|
||||
|
||||
class SoupExperiment:
|
||||
def __init__(self, population_size, net_i_size, net_h_size, net_o_size, learning_rate, attack_chance,
|
||||
train_nets, ST_steps, epochs, log_step_size, directory: Union[str, Path]):
|
||||
super().__init__()
|
||||
self.population_size = population_size
|
||||
|
||||
self.net_input_size = net_i_size
|
||||
self.net_hidden_size = net_h_size
|
||||
self.net_out_size = net_o_size
|
||||
self.net_learning_rate = learning_rate
|
||||
self.attack_chance = attack_chance
|
||||
self.train_nets = train_nets
|
||||
# self.SA_steps = SA_steps
|
||||
self.ST_steps = ST_steps
|
||||
self.epochs = epochs
|
||||
self.log_step_size = log_step_size
|
||||
|
||||
self.loss_history = []
|
||||
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
# <self.fixpoint_counters_history> is used for keeping track of the amount of fixpoints in %
|
||||
self.fixpoint_counters_history = []
|
||||
|
||||
self.directory = Path(directory)
|
||||
self.directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.population = []
|
||||
self.populate_environment()
|
||||
|
||||
self.evolve()
|
||||
self.fixpoint_percentage()
|
||||
self.weights_evolution_3d_experiment()
|
||||
self.count_fixpoints()
|
||||
self.visualize_loss()
|
||||
|
||||
def populate_environment(self):
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
for i in tqdm(range(self.population_size)):
|
||||
loop_population_size.set_description("Populating soup experiment %s" % i)
|
||||
|
||||
net_name = f"soup_network_{i}"
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
self.population.append(net)
|
||||
|
||||
def population_self_train(self):
|
||||
# Self-training each network in the population
|
||||
for j in range(self.population_size):
|
||||
net = self.population[j]
|
||||
|
||||
for _ in range(self.ST_steps):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
def population_attack(self):
|
||||
# A network attacking another network with a given percentage
|
||||
if random.randint(1, 100) <= self.attack_chance:
|
||||
random_net1, random_net2 = random.sample(range(self.population_size), 2)
|
||||
random_net1 = self.population[random_net1]
|
||||
random_net2 = self.population[random_net2]
|
||||
print(f"\n Attack: {random_net1.name} -> {random_net2.name}")
|
||||
random_net1.attack(random_net2)
|
||||
|
||||
def evolve(self):
|
||||
""" Evolving consists of attacking & self-training. """
|
||||
|
||||
loop_epochs = tqdm(range(self.epochs))
|
||||
for i in loop_epochs:
|
||||
loop_epochs.set_description("Evolving soup %s" % i)
|
||||
|
||||
# A network attacking another network with a given percentage
|
||||
self.population_attack()
|
||||
|
||||
# Self-training each network in the population
|
||||
self.population_self_train()
|
||||
|
||||
# Testing for fixpoints after each batch of ST steps to see relevant data
|
||||
if i % self.ST_steps == 0:
|
||||
test_for_fixpoints(self.fixpoint_counters, self.population)
|
||||
fixpoints_percentage = round(self.fixpoint_counters["identity_func"] / self.population_size, 1)
|
||||
self.fixpoint_counters_history.append(fixpoints_percentage)
|
||||
|
||||
# Resetting the fixpoint counter. Last iteration not to be reset -
|
||||
# it is important for the bar_chart_fixpoints().
|
||||
if i < self.epochs:
|
||||
self.reset_fixpoint_counters()
|
||||
|
||||
def weights_evolution_3d_experiment(self):
|
||||
exp_name = f"soup_{self.population_size}_nets_{self.ST_steps}_training_{self.epochs}_epochs"
|
||||
return plot_3d_soup(self.population, exp_name, self.directory)
|
||||
|
||||
def count_fixpoints(self):
|
||||
test_for_fixpoints(self.fixpoint_counters, self.population)
|
||||
exp_details = f"Evolution steps: {self.epochs} epochs"
|
||||
bar_chart_fixpoints(self.fixpoint_counters, self.population_size, self.directory, self.net_learning_rate,
|
||||
exp_details)
|
||||
|
||||
def fixpoint_percentage(self):
|
||||
runs = self.epochs / self.ST_steps
|
||||
SA_steps = None
|
||||
line_chart_fixpoints(self.fixpoint_counters_history, runs, self.ST_steps, SA_steps, self.directory,
|
||||
self.population_size)
|
||||
|
||||
def visualize_loss(self):
|
||||
for i in range(len(self.population)):
|
||||
net_loss_history = self.population[i].loss_history
|
||||
self.loss_history.append(net_loss_history)
|
||||
|
||||
plot_loss(self.loss_history, self.directory)
|
||||
|
||||
def reset_fixpoint_counters(self):
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
|
||||
|
||||
def run_soup_experiment(population_size, attack_chance, net_input_size, net_hidden_size, net_out_size,
|
||||
net_learning_rate, epochs, batch_size, runs, run_name, name_hash, ST_steps, train_nets):
|
||||
experiments = {}
|
||||
fixpoints_percentages = []
|
||||
|
||||
check_folder("soup")
|
||||
|
||||
# Running the experiments
|
||||
for i in range(runs):
|
||||
# FIXME: Make this a pathlib.Path() Operation
|
||||
directory_name = f"experiments/soup/{run_name}_run_{i}_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
|
||||
soup_experiment = SoupExperiment(
|
||||
population_size,
|
||||
net_input_size,
|
||||
net_hidden_size,
|
||||
net_out_size,
|
||||
net_learning_rate,
|
||||
attack_chance,
|
||||
train_nets,
|
||||
ST_steps,
|
||||
epochs,
|
||||
batch_size,
|
||||
directory_name
|
||||
)
|
||||
pickle.dump(soup_experiment, open(f"{directory_name}/full_experiment_pickle.p", "wb"))
|
||||
experiments[i] = soup_experiment
|
||||
|
||||
# Building history of fixpoint percentages for summary
|
||||
fixpoint_counters_history = soup_experiment.fixpoint_counters_history
|
||||
if not fixpoints_percentages:
|
||||
fixpoints_percentages = soup_experiment.fixpoint_counters_history
|
||||
else:
|
||||
# Using list comprehension to make the sum of all the percentages
|
||||
fixpoints_percentages = [fixpoints_percentages[i] + fixpoint_counters_history[i] for i in
|
||||
range(len(fixpoints_percentages))]
|
||||
|
||||
# Creating a folder for the summary of the current runs
|
||||
# FIXME: Make this a pathlib.Path() Operation
|
||||
directory_name = f"experiments/soup/summary_{run_name}_{runs}_runs_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
os.mkdir(directory_name)
|
||||
|
||||
# Building a summary of all the runs
|
||||
summary_pre_title = "soup"
|
||||
summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, directory_name,
|
||||
summary_pre_title)
|
||||
SA_steps = None
|
||||
summary_fixpoint_percentage(runs, epochs, fixpoints_percentages, ST_steps, SA_steps, directory_name,
|
||||
population_size)
|
||||
|
@ -1,50 +0,0 @@
|
||||
import random
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiments.soup_exp import SoupExperiment
|
||||
from functionalities_test import test_for_fixpoints
|
||||
|
||||
|
||||
class MeltingSoupExperiment(SoupExperiment):
|
||||
|
||||
def __init__(self, melt_chance, *args, keep_population_size=True, **kwargs):
|
||||
super(MeltingSoupExperiment, self).__init__(*args, **kwargs)
|
||||
self.keep_population_size = keep_population_size
|
||||
self.melt_chance = melt_chance
|
||||
|
||||
def population_melt(self):
|
||||
# A network melting with another network by a given percentage
|
||||
if random.randint(1, 100) <= self.melt_chance:
|
||||
random_net1_idx, random_net2_idx, destroy_idx = random.sample(range(self.population_size), 3)
|
||||
random_net1 = self.population[random_net1_idx]
|
||||
random_net2 = self.population[random_net2_idx]
|
||||
print(f"\n Melt: {random_net1.name} -> {random_net2.name}")
|
||||
melted_network = random_net1.melt(random_net2)
|
||||
if self.keep_population_size:
|
||||
del self.population[destroy_idx]
|
||||
self.population.append(melted_network)
|
||||
|
||||
def evolve(self):
|
||||
""" Evolving consists of attacking, melting & self-training. """
|
||||
|
||||
loop_epochs = tqdm(range(self.epochs))
|
||||
for i in loop_epochs:
|
||||
loop_epochs.set_description("Evolving soup %s" % i)
|
||||
|
||||
self.population_attack()
|
||||
|
||||
self.population_melt()
|
||||
|
||||
self.population_self_train()
|
||||
|
||||
# Testing for fixpoints after each batch of ST steps to see relevant data
|
||||
if i % self.ST_steps == 0:
|
||||
test_for_fixpoints(self.fixpoint_counters, self.population)
|
||||
fixpoints_percentage = round(self.fixpoint_counters["identity_func"] / self.population_size, 1)
|
||||
self.fixpoint_counters_history.append(fixpoints_percentage)
|
||||
|
||||
# Resetting the fixpoint counter. Last iteration not to be reset -
|
||||
# it is important for the bar_chart_fixpoints().
|
||||
if i < self.epochs:
|
||||
self.reset_fixpoint_counters()
|
BIN
figures/connectivity/identity.png
Normal file
After Width: | Height: | Size: 98 KiB |
BIN
figures/connectivity/other.png
Normal file
After Width: | Height: | Size: 97 KiB |
BIN
figures/connectivity/training_lineplot.png
Normal file
After Width: | Height: | Size: 198 KiB |
BIN
figures/connectivity/training_particle_type_lp.png
Normal file
After Width: | Height: | Size: 91 KiB |
After Width: | Height: | Size: 187 KiB |
After Width: | Height: | Size: 93 KiB |
BIN
figures/res_no_res/mn_st_200_8_alpha_100_training_lineplot.png
Normal file
After Width: | Height: | Size: 176 KiB |
After Width: | Height: | Size: 94 KiB |
BIN
figures/sanity/sanity_10hidden_xtimesn.png
Normal file
After Width: | Height: | Size: 18 KiB |
BIN
figures/sanity/sanity_2hidden_xtimesn.png
Normal file
After Width: | Height: | Size: 18 KiB |
BIN
figures/sanity/sanity_3hidden_xtimesn.png
Normal file
After Width: | Height: | Size: 22 KiB |
BIN
figures/sanity/sanity_4hidden_xtimesn.png
Normal file
After Width: | Height: | Size: 24 KiB |
BIN
figures/sanity/sanity_6hidden_xtimesn.png
Normal file
After Width: | Height: | Size: 23 KiB |
@ -3,32 +3,34 @@ from typing import Dict, List
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from network import Net
|
||||
from network import FixTypes, Net
|
||||
|
||||
|
||||
epsilon_error_margin = pow(10, -5)
|
||||
|
||||
|
||||
def is_divergent(network: Net) -> bool:
|
||||
return network.input_weight_matrix().isinf().any().item() or network.input_weight_matrix().isnan().any().item()
|
||||
|
||||
|
||||
def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:
|
||||
def is_identity_function(network: Net, epsilon=epsilon_error_margin) -> bool:
|
||||
|
||||
input_data = network.input_weight_matrix()
|
||||
target_data = network.create_target_weights(input_data)
|
||||
predicted_values = network(input_data)
|
||||
|
||||
|
||||
return torch.allclose(target_data.detach(), predicted_values.detach(),
|
||||
rtol=0, atol=epsilon)
|
||||
|
||||
|
||||
def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
|
||||
def is_zero_fixpoint(network: Net, epsilon=epsilon_error_margin) -> bool:
|
||||
target_data = network.create_target_weights(network.input_weight_matrix().detach())
|
||||
result = torch.allclose(target_data, torch.zeros_like(target_data), rtol=0, atol=epsilon)
|
||||
# result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
|
||||
return result
|
||||
|
||||
|
||||
def is_secondary_fixpoint(network: Net, epsilon: float = pow(10, -5)) -> bool:
|
||||
def is_secondary_fixpoint(network: Net, epsilon: float = epsilon_error_margin) -> bool:
|
||||
""" Secondary fixpoint check is done like this: compare first INPUT with second OUTPUT.
|
||||
If they are within the boundaries, then is secondary fixpoint. """
|
||||
|
||||
@ -57,21 +59,21 @@ def test_for_fixpoints(fixpoint_counter: Dict, nets: List, id_functions=None):
|
||||
|
||||
for net in tqdm(nets, desc='Fixpoint Tester', total=len(nets)):
|
||||
if is_divergent(net):
|
||||
fixpoint_counter["divergent"] += 1
|
||||
net.is_fixpoint = "divergent"
|
||||
elif is_identity_function(net): # is default value
|
||||
fixpoint_counter["identity_func"] += 1
|
||||
net.is_fixpoint = "identity_func"
|
||||
id_functions.append(net)
|
||||
fixpoint_counter[FixTypes.divergent] += 1
|
||||
net.is_fixpoint = FixTypes.divergent
|
||||
elif is_zero_fixpoint(net):
|
||||
fixpoint_counter["fix_zero"] += 1
|
||||
net.is_fixpoint = "fix_zero"
|
||||
fixpoint_counter[FixTypes.fix_zero] += 1
|
||||
net.is_fixpoint = FixTypes.fix_zero
|
||||
elif is_identity_function(net): # is default value
|
||||
fixpoint_counter[FixTypes.identity_func] += 1
|
||||
net.is_fixpoint = FixTypes.identity_func
|
||||
id_functions.append(net)
|
||||
elif is_secondary_fixpoint(net):
|
||||
fixpoint_counter["fix_sec"] += 1
|
||||
net.is_fixpoint = "fix_sec"
|
||||
fixpoint_counter[FixTypes.fix_sec] += 1
|
||||
net.is_fixpoint = FixTypes.fix_sec
|
||||
else:
|
||||
fixpoint_counter["other_func"] += 1
|
||||
net.is_fixpoint = "other_func"
|
||||
fixpoint_counter[FixTypes.other_func] += 1
|
||||
net.is_fixpoint = FixTypes.other_func
|
||||
return id_functions
|
||||
|
||||
|
||||
@ -82,14 +84,14 @@ def changing_rate(x_new, x_old):
|
||||
def test_status(net: Net) -> Net:
|
||||
|
||||
if is_divergent(net):
|
||||
net.is_fixpoint = "divergent"
|
||||
net.is_fixpoint = FixTypes.divergent
|
||||
elif is_identity_function(net): # is default value
|
||||
net.is_fixpoint = "identity_func"
|
||||
net.is_fixpoint = FixTypes.identity_func
|
||||
elif is_zero_fixpoint(net):
|
||||
net.is_fixpoint = "fix_zero"
|
||||
net.is_fixpoint = FixTypes.fix_zero
|
||||
elif is_secondary_fixpoint(net):
|
||||
net.is_fixpoint = "fix_sec"
|
||||
net.is_fixpoint = FixTypes.fix_sec
|
||||
else:
|
||||
net.is_fixpoint = "other_func"
|
||||
net.is_fixpoint = FixTypes.other_func
|
||||
|
||||
return net
|
||||
|
@ -1,203 +0,0 @@
|
||||
import copy
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
import random
|
||||
import pickle
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from functionalities_test import is_identity_function, test_status
|
||||
from journal_basins import SpawnExperiment, mean_invariate_manhattan_distance
|
||||
from network import Net
|
||||
|
||||
from sklearn.metrics import mean_absolute_error as MAE
|
||||
from sklearn.metrics import mean_squared_error as MSE
|
||||
|
||||
|
||||
class SpawnLinspaceExperiment(SpawnExperiment):
|
||||
|
||||
def spawn_and_continue(self, number_clones: int = None):
|
||||
number_clones = number_clones or self.nr_clones
|
||||
|
||||
df = pd.DataFrame(
|
||||
columns=['clone', 'parent', 'parent2',
|
||||
'MAE_pre', 'MAE_post',
|
||||
'MSE_pre', 'MSE_post',
|
||||
'MIM_pre', 'MIM_post',
|
||||
'noise', 'status_pst'])
|
||||
|
||||
# For every initial net {i} after populating (that is fixpoint after first epoch);
|
||||
# parent = self.parents[0]
|
||||
# parent_clone = clone = Net(parent.input_size, parent.hidden_size, parent.out_size,
|
||||
# name=f"{parent.name}_clone_{0}", start_time=self.ST_steps)
|
||||
# parent_clone.apply_weights(torch.as_tensor(parent.create_target_weights(parent.input_weight_matrix())))
|
||||
# parent_clone = parent_clone.apply_noise(self.noise)
|
||||
# self.parents.append(parent_clone)
|
||||
pairwise_net_list = list(itertools.combinations(self.parents, 2))
|
||||
for net1, net2 in pairwise_net_list:
|
||||
# We set parent start_time to just before this epoch ended, so plotting is zoomed in. Comment out to
|
||||
# to see full trajectory (but the clones will be very hard to see).
|
||||
# Make one target to compare distances to clones later when they have trained.
|
||||
net1.start_time = self.ST_steps - 150
|
||||
net1_input_data = net1.input_weight_matrix().detach()
|
||||
net1_target_data = net1.create_target_weights(net1_input_data).detach()
|
||||
|
||||
net2.start_time = self.ST_steps - 150
|
||||
net2_input_data = net2.input_weight_matrix().detach()
|
||||
net2_target_data = net2.create_target_weights(net2_input_data).detach()
|
||||
|
||||
if is_identity_function(net1) and is_identity_function(net2):
|
||||
# if True:
|
||||
# Clone the fixpoint x times and add (+-)self.noise to weight-sets randomly;
|
||||
# To plot clones starting after first epoch (z=ST_steps), set that as start_time!
|
||||
# To make sure PCA will plot the same trajectory up until this point, we clone the
|
||||
# parent-net's weight history as well.
|
||||
|
||||
in_between_weights = np.linspace(net1_target_data, net2_target_data, number_clones, endpoint=False)
|
||||
# in_between_weights = np.logspace(net1_target_data, net2_target_data, number_clones, endpoint=False)
|
||||
|
||||
for j, in_between_weight in enumerate(in_between_weights):
|
||||
clone = Net(net1.input_size, net1.hidden_size, net1.out_size,
|
||||
name=f"{net1.name}_{net2.name}_clone_{str(j)}", start_time=self.ST_steps + 100)
|
||||
clone.apply_weights(torch.as_tensor(in_between_weight))
|
||||
|
||||
clone.s_train_weights_history = copy.deepcopy(net1.s_train_weights_history)
|
||||
clone.number_trained = copy.deepcopy(net1.number_trained)
|
||||
|
||||
# Pre Training distances (after noise application of course)
|
||||
clone_pre_weights = clone.create_target_weights(clone.input_weight_matrix()).detach()
|
||||
MAE_pre = MAE(net1_target_data, clone_pre_weights)
|
||||
MSE_pre = MSE(net1_target_data, clone_pre_weights)
|
||||
MIM_pre = mean_invariate_manhattan_distance(net1_target_data, clone_pre_weights)
|
||||
|
||||
try:
|
||||
# Then finish training each clone {j} (for remaining epoch-1 * ST_steps) ..
|
||||
for _ in range(self.epochs - 1):
|
||||
for _ in range(self.ST_steps):
|
||||
clone.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
if any([torch.isnan(x).any() for x in clone.parameters()]):
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
print("Ran into nan in 'in beetween weights' array.")
|
||||
df.loc[len(df)] = [j, net1.name, net2.name,
|
||||
MAE_pre, 0,
|
||||
MSE_pre, 0,
|
||||
MIM_pre, 0,
|
||||
self.noise, clone.is_fixpoint]
|
||||
continue
|
||||
|
||||
# Post Training distances for comparison
|
||||
clone_post_weights = clone.create_target_weights(clone.input_weight_matrix()).detach()
|
||||
MAE_post = MAE(net1_target_data, clone_post_weights)
|
||||
MSE_post = MSE(net1_target_data, clone_post_weights)
|
||||
MIM_post = mean_invariate_manhattan_distance(net1_target_data, clone_post_weights)
|
||||
|
||||
# .. log to data-frame and add to nets for 3d plotting if they are fixpoints themselves.
|
||||
test_status(clone)
|
||||
if is_identity_function(clone):
|
||||
print(f"Clone {j} (between {net1.name} and {net2.name}) is fixpoint."
|
||||
f"\nMSE({net1.name},{j}): {MSE_post}"
|
||||
f"\nMAE({net1.name},{j}): {MAE_post}"
|
||||
f"\nMIM({net1.name},{j}): {MIM_post}\n")
|
||||
self.nets.append(clone)
|
||||
|
||||
df.loc[len(df)] = [j, net1.name, net2.name,
|
||||
MAE_pre, MAE_post,
|
||||
MSE_pre, MSE_post,
|
||||
MIM_pre, MIM_post,
|
||||
self.noise, clone.is_fixpoint]
|
||||
|
||||
for net1, net2 in pairwise_net_list:
|
||||
try:
|
||||
value = 'MAE'
|
||||
c_selector = [f'{value}_pre', f'{value}_post']
|
||||
values = df.loc[(df['parent'] == net1.name) & (df['parent2'] == net2.name)][c_selector]
|
||||
this_min, this_max = values.values.min(), values.values.max()
|
||||
df.loc[(df['parent'] == net1.name) &
|
||||
(df['parent2'] == net2.name), c_selector] = (values - this_min) / (this_max - this_min)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
for parent in self.parents:
|
||||
for _ in range(self.epochs - 1):
|
||||
for _ in range(self.ST_steps):
|
||||
parent.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
self.df = df
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
NET_INPUT_SIZE = 4
|
||||
NET_OUT_SIZE = 1
|
||||
|
||||
# Define number of runs & name:
|
||||
ST_runs = 1
|
||||
ST_runs_name = "test-27"
|
||||
ST_steps = 2000
|
||||
ST_epochs = 2
|
||||
ST_log_step_size = 10
|
||||
|
||||
# Define number of networks & their architecture
|
||||
nr_clones = 25
|
||||
ST_population_size = 10
|
||||
ST_net_hidden_size = 2
|
||||
ST_net_learning_rate = 0.04
|
||||
ST_name_hash = random.getrandbits(32)
|
||||
|
||||
print(f"Running the Spawn experiment:")
|
||||
exp = SpawnLinspaceExperiment(
|
||||
population_size=ST_population_size,
|
||||
log_step_size=ST_log_step_size,
|
||||
net_input_size=NET_INPUT_SIZE,
|
||||
net_hidden_size=ST_net_hidden_size,
|
||||
net_out_size=NET_OUT_SIZE,
|
||||
net_learning_rate=ST_net_learning_rate,
|
||||
epochs=ST_epochs,
|
||||
st_steps=ST_steps,
|
||||
nr_clones=nr_clones,
|
||||
noise=1e-8,
|
||||
directory=Path('output') / 'spawn_basin' / f'{ST_name_hash}' / f'linage'
|
||||
)
|
||||
df = exp.df
|
||||
|
||||
directory = Path('output') / 'spawn_basin' / f'{ST_name_hash}' / 'linage'
|
||||
with (directory / f"experiment_pickle_{ST_name_hash}.p").open('wb') as f:
|
||||
pickle.dump(exp, f)
|
||||
print(f"\nSaved experiment to {directory}.")
|
||||
|
||||
# Boxplot with counts of nr_fixpoints, nr_other, nr_etc. on y-axis
|
||||
# sns.countplot(data=df, x="noise", hue="status_post")
|
||||
# plt.savefig(f"output/spawn_basin/{ST_name_hash}/fixpoint_status_countplot.png")
|
||||
|
||||
# Catplot (either kind="point" or "box") that shows before-after training distances to parent
|
||||
# mlt = df[["MIM_pre", "MIM_post", "noise"]].melt("noise", var_name="time", value_name='Average Distance')
|
||||
# sns.catplot(data=mlt, x="time", y="Average Distance", col="noise", kind="point", col_wrap=5, sharey=False)
|
||||
# plt.savefig(f"output/spawn_basin/{ST_name_hash}/clone_distance_catplot.png")
|
||||
|
||||
# Pointplot with pre and after parent Distances
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt, ticker
|
||||
|
||||
# ptplt = sns.pointplot(data=exp.df, x='MAE_pre', y='MAE_post', join=False)
|
||||
ptplt = sns.scatterplot(x=exp.df['MAE_pre'], y=exp.df['MAE_post'])
|
||||
# ptplt.set(xscale='log', yscale='log')
|
||||
x0, x1 = ptplt.axes.get_xlim()
|
||||
y0, y1 = ptplt.axes.get_ylim()
|
||||
lims = [max(x0, y0), min(x1, y1)]
|
||||
# This is the x=y line using transforms
|
||||
ptplt.plot(lims, lims, 'w', linestyle='dashdot', transform=ptplt.axes.transData)
|
||||
ptplt.plot([0, 1], [0, 1], ':k', transform=ptplt.axes.transAxes)
|
||||
ptplt.set(xlabel='Mean Absolute Distance before Self-Training',
|
||||
ylabel='Mean Absolute Distance after Self-Training')
|
||||
# ptplt.axes.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: round(float(x), 2)))
|
||||
# ptplt.xticks(rotation=45)
|
||||
#for ind, label in enumerate(ptplt.get_xticklabels()):
|
||||
# if ind % 10 == 0: # every 10th label is kept
|
||||
# label.set_visible(True)
|
||||
# else:
|
||||
# label.set_visible(False)
|
||||
|
||||
filepath = exp.directory / 'mim_dist_plot.pdf'
|
||||
plt.tight_layout()
|
||||
plt.savefig(filepath, dpi=600, format='pdf', bbox_inches='tight')
|
@ -1,315 +0,0 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import copy
|
||||
from functionalities_test import is_identity_function, test_status
|
||||
from network import Net
|
||||
from visualization import plot_3d_self_train, plot_loss
|
||||
import numpy as np
|
||||
from tabulate import tabulate
|
||||
from sklearn.metrics import mean_absolute_error as MAE
|
||||
from sklearn.metrics import mean_squared_error as MSE
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def prng():
|
||||
return random.random()
|
||||
|
||||
def l1(tup):
|
||||
a, b = tup
|
||||
return abs(a - b)
|
||||
|
||||
|
||||
def mean_invariate_manhattan_distance(x, y):
|
||||
# One of these one-liners that might be smart or really dumb. Goal is to find pairwise
|
||||
# distances of ascending values, ie. sum (abs(min1_X-min1_Y), abs(min2_X-min2Y) ...) / mean.
|
||||
# Idea was to find weight sets that have same values but just in different positions, that would
|
||||
# make this distance 0.
|
||||
try:
|
||||
return np.mean(list(map(l1, zip(sorted(x.detach().numpy()), sorted(y.detach().numpy())))))
|
||||
except AttributeError:
|
||||
return np.mean(list(map(l1, zip(sorted(x.numpy()), sorted(y.numpy())))))
|
||||
|
||||
|
||||
def distance_matrix(nets, distance="MIM", print_it=True):
|
||||
matrix = [[0 for _ in range(len(nets))] for _ in range(len(nets))]
|
||||
for net in range(len(nets)):
|
||||
weights = nets[net].input_weight_matrix()[:, 0]
|
||||
for other_net in range(len(nets)):
|
||||
other_weights = nets[other_net].input_weight_matrix()[:, 0]
|
||||
if distance in ["MSE"]:
|
||||
matrix[net][other_net] = MSE(weights, other_weights)
|
||||
elif distance in ["MAE"]:
|
||||
matrix[net][other_net] = MAE(weights, other_weights)
|
||||
elif distance in ["MIM"]:
|
||||
matrix[net][other_net] = mean_invariate_manhattan_distance(weights, other_weights)
|
||||
|
||||
if print_it:
|
||||
print(f"\nDistance matrix (all to all) [{distance}]:")
|
||||
headers = [i.name for i in nets]
|
||||
print(tabulate(matrix, showindex=headers, headers=headers, tablefmt='orgtbl'))
|
||||
return matrix
|
||||
|
||||
|
||||
def distance_from_parent(nets, distance="MIM", print_it=True):
|
||||
list_of_matrices = []
|
||||
parents = list(filter(lambda x: "clone" not in x.name and is_identity_function(x), nets))
|
||||
distance_range = range(10)
|
||||
for parent in parents:
|
||||
parent_weights = parent.create_target_weights(parent.input_weight_matrix())
|
||||
clones = list(filter(lambda y: parent.name in y.name and parent.name != y.name, nets))
|
||||
matrix = [[0 for _ in distance_range] for _ in range(len(clones))]
|
||||
|
||||
for dist in distance_range:
|
||||
for idx, clone in enumerate(clones):
|
||||
clone_weights = clone.create_target_weights(clone.input_weight_matrix())
|
||||
if distance in ["MSE"]:
|
||||
matrix[idx][dist] = MSE(parent_weights, clone_weights) < pow(10, -dist)
|
||||
elif distance in ["MAE"]:
|
||||
matrix[idx][dist] = MAE(parent_weights, clone_weights) < pow(10, -dist)
|
||||
elif distance in ["MIM"]:
|
||||
matrix[idx][dist] = mean_invariate_manhattan_distance(parent_weights, clone_weights) < pow(10,
|
||||
-dist)
|
||||
|
||||
if print_it:
|
||||
print(f"\nDistances from parent {parent.name} [{distance}]:")
|
||||
col_headers = [str(f"10e-{d}") for d in distance_range]
|
||||
row_headers = [str(f"clone_{i}") for i in range(len(clones))]
|
||||
print(tabulate(matrix, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
||||
|
||||
list_of_matrices.append(matrix)
|
||||
|
||||
return list_of_matrices
|
||||
|
||||
|
||||
class SpawnExperiment:
|
||||
|
||||
def __init__(self, population_size, log_step_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
|
||||
epochs, st_steps, nr_clones, noise, directory) -> None:
|
||||
self.population_size = population_size
|
||||
self.log_step_size = log_step_size
|
||||
self.net_input_size = net_input_size
|
||||
self.net_hidden_size = net_hidden_size
|
||||
self.net_out_size = net_out_size
|
||||
self.net_learning_rate = net_learning_rate
|
||||
self.epochs = epochs
|
||||
self.ST_steps = st_steps
|
||||
self.loss_history = []
|
||||
self.nets = []
|
||||
self.nr_clones = nr_clones
|
||||
self.noise = noise or 10e-5
|
||||
print("\nNOISE:", self.noise)
|
||||
|
||||
self.parents = []
|
||||
|
||||
self.directory = Path(directory)
|
||||
self.directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.populate_environment()
|
||||
self.spawn_and_continue()
|
||||
self.weights_evolution_3d_experiment()
|
||||
# self.visualize_loss()
|
||||
self.distance_matrix = distance_matrix(self.nets, print_it=False)
|
||||
self.parent_clone_distances = distance_from_parent(self.nets, print_it=False)
|
||||
|
||||
|
||||
def populate_environment(self):
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
for i in loop_population_size:
|
||||
loop_population_size.set_description("Populating experiment %s" % i)
|
||||
|
||||
net_name = f"ST_net_{str(i)}"
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
|
||||
for _ in range(self.ST_steps):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
self.nets.append(net)
|
||||
self.parents.append(net)
|
||||
|
||||
def spawn_and_continue(self, number_clones: int = None):
|
||||
number_clones = number_clones or self.nr_clones
|
||||
|
||||
df = pd.DataFrame(
|
||||
columns=['name', 'MAE_pre', 'MAE_post', 'MSE_pre', 'MSE_post', 'MIM_pre', 'MIM_post', 'noise',
|
||||
'status_post'])
|
||||
|
||||
# For every initial net {i} after populating (that is fixpoint after first epoch);
|
||||
for i in range(self.population_size):
|
||||
net = self.nets[i]
|
||||
# We set parent start_time to just before this epoch ended, so plotting is zoomed in. Comment out to
|
||||
# to see full trajectory (but the clones will be very hard to see).
|
||||
# Make one target to compare distances to clones later when they have trained.
|
||||
net.start_time = self.ST_steps - 350
|
||||
net_input_data = net.input_weight_matrix()
|
||||
net_target_data = net.create_target_weights(net_input_data)
|
||||
|
||||
if is_identity_function(net):
|
||||
print(f"\nNet {i} is fixpoint")
|
||||
|
||||
# Clone the fixpoint x times and add (+-)self.noise to weight-sets randomly;
|
||||
# To plot clones starting after first epoch (z=ST_steps), set that as start_time!
|
||||
# To make sure PCA will plot the same trajectory up until this point, we clone the
|
||||
# parent-net's weight history as well.
|
||||
for j in range(number_clones):
|
||||
clone = Net(net.input_size, net.hidden_size, net.out_size,
|
||||
f"ST_net_{str(i)}_clone_{str(j)}", start_time=self.ST_steps)
|
||||
clone.load_state_dict(copy.deepcopy(net.state_dict()))
|
||||
rand_noise = prng() * self.noise
|
||||
clone = clone.apply_noise(rand_noise)
|
||||
clone.s_train_weights_history = copy.deepcopy(net.s_train_weights_history)
|
||||
clone.number_trained = copy.deepcopy(net.number_trained)
|
||||
|
||||
# Pre Training distances (after noise application of course)
|
||||
clone_pre_weights = clone.create_target_weights(clone.input_weight_matrix())
|
||||
MAE_pre = MAE(net_target_data, clone_pre_weights)
|
||||
MSE_pre = MSE(net_target_data, clone_pre_weights)
|
||||
MIM_pre = mean_invariate_manhattan_distance(net_target_data, clone_pre_weights)
|
||||
|
||||
# Then finish training each clone {j} (for remaining epoch-1 * ST_steps) ..
|
||||
for _ in range(self.epochs - 1):
|
||||
for _ in range(self.ST_steps):
|
||||
clone.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
# Post Training distances for comparison
|
||||
clone_post_weights = clone.create_target_weights(clone.input_weight_matrix())
|
||||
MAE_post = MAE(net_target_data, clone_post_weights)
|
||||
MSE_post = MSE(net_target_data, clone_post_weights)
|
||||
MIM_post = mean_invariate_manhattan_distance(net_target_data, clone_post_weights)
|
||||
|
||||
# .. log to data-frame and add to nets for 3d plotting if they are fixpoints themselves.
|
||||
test_status(clone)
|
||||
if is_identity_function(clone):
|
||||
print(f"Clone {j} (of net_{i}) is fixpoint."
|
||||
f"\nMSE({i},{j}): {MSE_post}"
|
||||
f"\nMAE({i},{j}): {MAE_post}"
|
||||
f"\nMIM({i},{j}): {MIM_post}\n")
|
||||
self.nets.append(clone)
|
||||
|
||||
df.loc[clone.name] = [clone.name, MAE_pre, MAE_post, MSE_pre, MSE_post, MIM_pre, MIM_post, self.noise, clone.is_fixpoint]
|
||||
|
||||
# Finally take parent net {i} and finish it's training for comparison to clone development.
|
||||
for _ in range(self.epochs - 1):
|
||||
for _ in range(self.ST_steps):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
net_weights_after = net.create_target_weights(net.input_weight_matrix())
|
||||
print(f"Parent net's distance to original position."
|
||||
f"\nMSE(OG,new): {MAE(net_target_data, net_weights_after)}"
|
||||
f"\nMAE(OG,new): {MSE(net_target_data, net_weights_after)}"
|
||||
f"\nMIM(OG,new): {mean_invariate_manhattan_distance(net_target_data, net_weights_after)}\n")
|
||||
|
||||
self.df = df
|
||||
|
||||
def weights_evolution_3d_experiment(self):
|
||||
exp_name = f"ST_{str(len(self.nets))}_nets_3d_weights_PCA"
|
||||
return plot_3d_self_train(self.nets, exp_name, self.directory, self.log_step_size, plot_pca_together=True)
|
||||
|
||||
def visualize_loss(self):
|
||||
for i in range(len(self.nets)):
|
||||
net_loss_history = self.nets[i].loss_history
|
||||
self.loss_history.append(net_loss_history)
|
||||
plot_loss(self.loss_history, self.directory)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
NET_INPUT_SIZE = 4
|
||||
NET_OUT_SIZE = 1
|
||||
|
||||
# Define number of runs & name:
|
||||
ST_runs = 1
|
||||
ST_runs_name = "test-27"
|
||||
ST_steps = 2500
|
||||
ST_epochs = 2
|
||||
ST_log_step_size = 10
|
||||
|
||||
# Define number of networks & their architecture
|
||||
nr_clones = 10
|
||||
ST_population_size = 1
|
||||
ST_net_hidden_size = 2
|
||||
ST_net_learning_rate = 0.04
|
||||
ST_name_hash = random.getrandbits(32)
|
||||
|
||||
print(f"Running the Spawn experiment:")
|
||||
exp_list = []
|
||||
for noise_factor in range(2, 3):
|
||||
exp = SpawnExperiment(
|
||||
population_size=ST_population_size,
|
||||
log_step_size=ST_log_step_size,
|
||||
net_input_size=NET_INPUT_SIZE,
|
||||
net_hidden_size=ST_net_hidden_size,
|
||||
net_out_size=NET_OUT_SIZE,
|
||||
net_learning_rate=ST_net_learning_rate,
|
||||
epochs=ST_epochs,
|
||||
st_steps=ST_steps,
|
||||
nr_clones=nr_clones,
|
||||
noise=pow(10, -noise_factor),
|
||||
directory=Path('output') / 'spawn_basin' / f'{ST_name_hash}' / f'10e-{noise_factor}'
|
||||
)
|
||||
exp_list.append(exp)
|
||||
|
||||
directory = Path('output') / 'spawn_basin' / f'{ST_name_hash}'
|
||||
pickle.dump(exp_list, open(f"{directory}/experiment_pickle_{ST_name_hash}.p", "wb"))
|
||||
print(f"\nSaved experiment to {directory}.")
|
||||
|
||||
# Concat all dataframes, and add columns depending on where clone weights end up after training (rel. to parent)
|
||||
df = pd.concat([exp.df for exp in exp_list])
|
||||
df = df.dropna().reset_index()
|
||||
df["relative_distance"] = [ (df.loc[i]["MAE_pre"] - df.loc[i]["MAE_post"])/df.loc[i]["noise"] for i in range(len(df))]
|
||||
df["class"] = [ "approaching" if df.loc[i]["relative_distance"] > 0 else "distancing" if df.loc[i]["relative_distance"] < 0 else "stationary" for i in range(len(df))]
|
||||
|
||||
# Countplot of all fixpoint clone after training per class.
|
||||
ax = sns.catplot(kind="count", data=df, x="noise", hue="class", height=5.27, aspect=11.7/5.27, legend=False)
|
||||
ax.set_axis_labels("Noise Levels", "Clone Fixpoints After Training Count ", fontsize=15)
|
||||
ax.set_xticklabels(labels=('$\mathregular{10^{-10}}$', '$\mathregular{10^{-9}}$', '$\mathregular{10^{-8}}$', '$\mathregular{10^{-7}}$', '$\mathregular{10^{-6}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-4}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-2}}$', '$\mathregular{10^{-1}}$'), fontsize=15)
|
||||
plt.legend(bbox_to_anchor=(0.01, 0.85), loc=2, borderaxespad=0.)
|
||||
plt.legend(fontsize='large')
|
||||
plt.savefig(f"{directory}/clone_status_after_countplot_{ST_name_hash}.png")
|
||||
plt.clf()
|
||||
|
||||
# Catplot of before-after comparison of the clone's weights. Colors links depending on class (approaching, distancing, stationary (i.e., MAE=0)). Blue, orange and green are based on countplot above, should be save for colorblindness (see https://gist.github.com/mwaskom/b35f6ebc2d4b340b4f64a4e28e778486)-
|
||||
mlt = df.melt(id_vars=["name", "noise", "class"], value_vars=["MAE_pre", "MAE_post"], var_name="State", value_name="Distance")
|
||||
P = ["blue" if mlt.loc[i]["class"] == "approaching" else "orange" if mlt.loc[i]["class"] == "distancing" else "green" for i in range(len(mlt))]
|
||||
P = sns.color_palette(P, as_cmap=False)
|
||||
ax = sns.catplot(data=mlt, x="State", y="Distance", col="noise", hue="name", kind="point", palette=P, col_wrap=min(5, len(exp_list)), sharey=False, legend=False)
|
||||
ax.map(sns.boxplot, "State", "Distance", "noise", linewidth=0.8, order=["MAE_pre", "MAE_post"], whis=[0, 100])
|
||||
ax.set_axis_labels("", "Manhattan Distance To Parent Weights", fontsize=15)
|
||||
ax.set_xticklabels(labels=('after noise application', 'after training'), fontsize=15)
|
||||
# plt.ticklabel_format(style='sci', axis='x')
|
||||
plt.savefig(f"{directory}/before_after_distance_catplot_{ST_name_hash}.png")
|
||||
plt.clf()
|
||||
|
||||
# Catplot of child_nets L1 Prediction "progress" compared to parents. Computes one round of accuracy first. If net is a parent net (not a clone), then we reset weights to timestep of cloning first (from the weight history). So 5k (end) -> 2.5k training (in this experiment, so careful with len(history)/2, this might only work here!)
|
||||
df_acc = pd.DataFrame(columns=["name", "noise", "l1_acc", "Network Type"])
|
||||
for i in range(len(exp_list)):
|
||||
noise = exp_list[i].noise
|
||||
print(f"\nNoise: {noise}")
|
||||
for network in exp_list[i].nets:
|
||||
is_parent = "clone" not in network.name
|
||||
if is_parent:
|
||||
network.apply_weights(torch.tensor(network.s_train_weights_history[int(len(network.s_train_weights_history)/2)][0]))
|
||||
input_data = network.input_weight_matrix()
|
||||
target_data = network.create_target_weights(input_data)
|
||||
predicted_values = network(input_data)
|
||||
mse_loss = F.mse_loss(target_data, predicted_values).item()
|
||||
l1_loss = F.l1_loss(target_data, predicted_values).item()
|
||||
|
||||
df_acc.loc[len(df_acc)+1] = [network.name, noise, l1_loss, "parents" if is_parent else "child_nets"]
|
||||
print("MSE:", mse_loss, "\t", "L1: ", l1_loss, "\t", network.name)
|
||||
|
||||
# Note: If there are outliers then showfliers=False is necessary or it will zoom way to far out. If parent and child_nets accuracy is too far apart this plot might not work (only shows either parents or part of the child_nets).
|
||||
ax = sns.catplot(data=df_acc, y="l1_acc", x="noise", hue="Network Type", kind="box", legend=False, showfliers=False, height=5.27, aspect=11.7/5.27, sharey=False)
|
||||
ax.map(plt.axhline, y=10**-6, ls='--')
|
||||
ax.map(plt.axhline, y=10**-7, ls='--')
|
||||
ax.set_axis_labels("Noise levels", "L1 Prediction Loss After Training", fontsize=15)
|
||||
ax.set_xticklabels(labels=('$\mathregular{10^{-10}}$', '$\mathregular{10^{-9}}$', '$\mathregular{10^{-8}}$', '$\mathregular{10^{-7}}$', '$\mathregular{10^{-6}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-4}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-2}}$', '$\mathregular{10^{-1}}$'), fontsize=15)
|
||||
plt.legend(bbox_to_anchor=(0.01, 0.85), loc=2, borderaxespad=0.)
|
||||
plt.legend(fontsize='large')
|
||||
plt.savefig(f"{directory}/parent_vs_children_accuracy_{ST_name_hash}.png")
|
||||
plt.clf()
|
@ -1,246 +0,0 @@
|
||||
import pickle
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import random
|
||||
import copy
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from matplotlib.ticker import ScalarFormatter
|
||||
from tqdm import tqdm
|
||||
from tabulate import tabulate
|
||||
|
||||
from functionalities_test import is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent
|
||||
from network import Net
|
||||
from torch.nn import functional as F
|
||||
from visualization import plot_loss, bar_chart_fixpoints
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
def prng():
|
||||
return random.random()
|
||||
|
||||
|
||||
def generate_perfekt_synthetic_fixpoint_weights():
|
||||
return torch.tensor([[1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
|
||||
[1.0], [0.0], [0.0], [0.0],
|
||||
[1.0], [0.0]
|
||||
], dtype=torch.float32)
|
||||
|
||||
|
||||
PALETTE = 10 * (
|
||||
"#377eb8",
|
||||
"#4daf4a",
|
||||
"#984ea3",
|
||||
"#e41a1c",
|
||||
"#ff7f00",
|
||||
"#a65628",
|
||||
"#f781bf",
|
||||
"#888888",
|
||||
"#a6cee3",
|
||||
"#b2df8a",
|
||||
"#cab2d6",
|
||||
"#fb9a99",
|
||||
"#fdbf6f",
|
||||
)
|
||||
|
||||
|
||||
class RobustnessComparisonExperiment:
|
||||
|
||||
@staticmethod
|
||||
def apply_noise(network, noise: int):
|
||||
# Changing the weights of a network to values + noise
|
||||
for layer_id, layer_name in enumerate(network.state_dict()):
|
||||
for line_id, line_values in enumerate(network.state_dict()[layer_name]):
|
||||
for weight_id, weight_value in enumerate(network.state_dict()[layer_name][line_id]):
|
||||
# network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
|
||||
if prng() < 0.5:
|
||||
network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
|
||||
else:
|
||||
network.state_dict()[layer_name][line_id][weight_id] = weight_value - noise
|
||||
|
||||
return network
|
||||
|
||||
def __init__(self, population_size, log_step_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
|
||||
epochs, st_steps, synthetic, directory) -> None:
|
||||
self.population_size = population_size
|
||||
self.log_step_size = log_step_size
|
||||
self.net_input_size = net_input_size
|
||||
self.net_hidden_size = net_hidden_size
|
||||
self.net_out_size = net_out_size
|
||||
self.net_learning_rate = net_learning_rate
|
||||
self.epochs = epochs
|
||||
self.ST_steps = st_steps
|
||||
self.loss_history = []
|
||||
self.is_synthetic = synthetic
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
|
||||
self.directory = Path(directory)
|
||||
self.directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.id_functions = []
|
||||
self.nets = self.populate_environment()
|
||||
self.count_fixpoints()
|
||||
self.time_to_vergence, self.time_as_fixpoint = self.test_robustness(
|
||||
seeds=population_size if self.is_synthetic else 1)
|
||||
|
||||
def populate_environment(self):
|
||||
nets = []
|
||||
if self.is_synthetic:
|
||||
''' Either use perfect / hand-constructed fixpoint ... '''
|
||||
net_name = f"net_{str(0)}_synthetic"
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
net.apply_weights(generate_perfekt_synthetic_fixpoint_weights())
|
||||
nets.append(net)
|
||||
|
||||
else:
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
for i in loop_population_size:
|
||||
loop_population_size.set_description("Populating experiment %s" % i)
|
||||
|
||||
''' .. or use natural approach to train fixpoints from random initialisation. '''
|
||||
net_name = f"net_{str(i)}"
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
for _ in range(self.epochs):
|
||||
net.self_train(self.ST_steps, self.log_step_size, self.net_learning_rate)
|
||||
nets.append(net)
|
||||
return nets
|
||||
|
||||
def test_robustness(self, print_it=True, noise_levels=10, seeds=10):
|
||||
assert (len(self.id_functions) == 1 and seeds > 1) or (len(self.id_functions) > 1 and seeds == 1)
|
||||
time_to_vergence = [[0 for _ in range(noise_levels)] for _ in
|
||||
range(seeds if self.is_synthetic else len(self.id_functions))]
|
||||
time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in
|
||||
range(seeds if self.is_synthetic else len(self.id_functions))]
|
||||
row_headers = []
|
||||
|
||||
# This checks wether to use synthetic setting with multiple seeds
|
||||
# or multi network settings with a singlee seed
|
||||
|
||||
df = pd.DataFrame(columns=['setting', 'Noise Level', 'Self Train Steps', 'absolute_loss',
|
||||
'Time to convergence', 'Time as fixpoint'])
|
||||
with tqdm(total=max(len(self.id_functions), seeds)) as pbar:
|
||||
for i, fixpoint in enumerate(self.id_functions): # 1 / n
|
||||
row_headers.append(fixpoint.name)
|
||||
for seed in range(seeds): # n / 1
|
||||
setting = seed if self.is_synthetic else i
|
||||
|
||||
for noise_level in range(noise_levels):
|
||||
steps = 0
|
||||
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
|
||||
f"{fixpoint.name}_clone_noise_1e-{noise_level}")
|
||||
clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
|
||||
clone = clone.apply_noise(pow(10, -noise_level))
|
||||
|
||||
while not is_zero_fixpoint(clone) and not is_divergent(clone):
|
||||
# -> before
|
||||
clone_weight_pre_application = clone.input_weight_matrix()
|
||||
target_data_pre_application = clone.create_target_weights(clone_weight_pre_application)
|
||||
|
||||
clone.self_application(1, self.log_step_size)
|
||||
time_to_vergence[setting][noise_level] += 1
|
||||
# -> after
|
||||
clone_weight_post_application = clone.input_weight_matrix()
|
||||
target_data_post_application = clone.create_target_weights(clone_weight_post_application)
|
||||
|
||||
absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item()
|
||||
|
||||
if is_identity_function(clone):
|
||||
time_as_fixpoint[setting][noise_level] += 1
|
||||
# When this raises a Type Error, we found a second order fixpoint!
|
||||
steps += 1
|
||||
|
||||
df.loc[df.shape[0]] = [setting, f'$\mathregular{{10^{{-{noise_level}}}}}$',
|
||||
steps, absolute_loss,
|
||||
time_to_vergence[setting][noise_level],
|
||||
time_as_fixpoint[setting][noise_level]]
|
||||
pbar.update(1)
|
||||
|
||||
# Get the measuremts at the highest time_time_to_vergence
|
||||
df_sorted = df.sort_values('Self Train Steps', ascending=False).drop_duplicates(['setting', 'Noise Level'])
|
||||
df_melted = df_sorted.reset_index().melt(id_vars=['setting', 'Noise Level', 'Self Train Steps'],
|
||||
value_vars=['Time to convergence', 'Time as fixpoint'],
|
||||
var_name="Measurement",
|
||||
value_name="Steps").sort_values('Noise Level')
|
||||
# Plotting
|
||||
# plt.rcParams.update({
|
||||
# "text.usetex": True,
|
||||
# "font.family": "sans-serif",
|
||||
# "font.size": 12,
|
||||
# "font.weight": 'bold',
|
||||
# "font.sans-serif": ["Helvetica"]})
|
||||
sns.set(style='whitegrid', font_scale=2)
|
||||
bf = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE)
|
||||
synthetic = 'synthetic' if self.is_synthetic else 'natural'
|
||||
plt.tight_layout()
|
||||
|
||||
# sns.set(rc={'figure.figsize': (10, 50)})
|
||||
# bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box',
|
||||
# col='noise_level', col_wrap=3, showfliers=False)
|
||||
|
||||
filename = f"absolute_loss_perapplication_boxplot_grid_{'synthetic' if self.is_synthetic else 'wild'}.png"
|
||||
filepath = self.directory / filename
|
||||
plt.savefig(str(filepath))
|
||||
|
||||
if print_it:
|
||||
col_headers = [str(f"1e-{d}") for d in range(noise_levels)]
|
||||
|
||||
print(f"\nAppplications steps until divergence / zero: ")
|
||||
# print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
||||
|
||||
print(f"\nTime as fixpoint: ")
|
||||
# print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
||||
return time_as_fixpoint, time_to_vergence
|
||||
|
||||
def count_fixpoints(self):
|
||||
exp_details = f"ST steps: {self.ST_steps}"
|
||||
self.id_functions = test_for_fixpoints(self.fixpoint_counters, self.nets)
|
||||
bar_chart_fixpoints(self.fixpoint_counters, self.population_size, self.directory, self.net_learning_rate,
|
||||
exp_details)
|
||||
|
||||
def visualize_loss(self):
|
||||
for i in range(len(self.nets)):
|
||||
net_loss_history = self.nets[i].loss_history
|
||||
self.loss_history.append(net_loss_history)
|
||||
plot_loss(self.loss_history, self.directory)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
NET_INPUT_SIZE = 4
|
||||
NET_OUT_SIZE = 1
|
||||
|
||||
ST_steps = 1000
|
||||
ST_epochs = 5
|
||||
ST_log_step_size = 10
|
||||
ST_population_size = 1000
|
||||
ST_net_hidden_size = 2
|
||||
ST_net_learning_rate = 0.004
|
||||
ST_name_hash = random.getrandbits(32)
|
||||
ST_synthetic = False
|
||||
|
||||
print(f"Running the robustness comparison experiment:")
|
||||
exp = RobustnessComparisonExperiment(
|
||||
population_size=ST_population_size,
|
||||
log_step_size=ST_log_step_size,
|
||||
net_input_size=NET_INPUT_SIZE,
|
||||
net_hidden_size=ST_net_hidden_size,
|
||||
net_out_size=NET_OUT_SIZE,
|
||||
net_learning_rate=ST_net_learning_rate,
|
||||
epochs=ST_epochs,
|
||||
st_steps=ST_steps,
|
||||
synthetic=ST_synthetic,
|
||||
directory=Path('output') / 'journal_robustness' / f'{ST_name_hash}'
|
||||
)
|
||||
|
||||
directory = Path('output') / 'journal_robustness' / f'{ST_name_hash}'
|
||||
pickle.dump(exp, open(f"{directory}/experiment_pickle_{ST_name_hash}.p", "wb"))
|
||||
print(f"\nSaved experiment to {directory}.")
|
@ -1,341 +0,0 @@
|
||||
import pickle
|
||||
|
||||
import random
|
||||
import copy
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from sklearn.metrics import mean_absolute_error as MAE
|
||||
from sklearn.metrics import mean_squared_error as MSE
|
||||
from tabulate import tabulate
|
||||
from tqdm import tqdm
|
||||
|
||||
from functionalities_test import is_identity_function, test_status, is_zero_fixpoint, is_divergent, \
|
||||
is_secondary_fixpoint
|
||||
from journal_basins import mean_invariate_manhattan_distance
|
||||
from network import Net
|
||||
from visualization import plot_loss, plot_3d_soup
|
||||
|
||||
|
||||
def l1(tup):
|
||||
a, b = tup
|
||||
return abs(a - b)
|
||||
|
||||
|
||||
def distance_matrix(nets, distance="MIM", print_it=True):
|
||||
matrix = [[0 for _ in range(len(nets))] for _ in range(len(nets))]
|
||||
for net in range(len(nets)):
|
||||
weights = nets[net].input_weight_matrix()[:, 0]
|
||||
for other_net in range(len(nets)):
|
||||
other_weights = nets[other_net].input_weight_matrix()[:, 0]
|
||||
if distance in ["MSE"]:
|
||||
matrix[net][other_net] = MSE(weights, other_weights)
|
||||
elif distance in ["MAE"]:
|
||||
matrix[net][other_net] = MAE(weights, other_weights)
|
||||
elif distance in ["MIM"]:
|
||||
matrix[net][other_net] = mean_invariate_manhattan_distance(weights, other_weights)
|
||||
|
||||
if print_it:
|
||||
print(f"\nDistance matrix (all to all) [{distance}]:")
|
||||
headers = [i.name for i in nets]
|
||||
print(tabulate(matrix, showindex=headers, headers=headers, tablefmt='orgtbl'))
|
||||
return matrix
|
||||
|
||||
|
||||
def distance_from_parent(nets, distance="MIM", print_it=True):
|
||||
list_of_matrices = []
|
||||
parents = list(filter(lambda x: "clone" not in x.name and is_identity_function(x), nets))
|
||||
distance_range = range(10)
|
||||
for parent in parents:
|
||||
parent_weights = parent.create_target_weights(parent.input_weight_matrix())
|
||||
clones = list(filter(lambda y: parent.name in y.name and parent.name != y.name, nets))
|
||||
matrix = [[0 for _ in distance_range] for _ in range(len(clones))]
|
||||
|
||||
for dist in distance_range:
|
||||
for idx, clone in enumerate(clones):
|
||||
clone_weights = clone.create_target_weights(clone.input_weight_matrix())
|
||||
if distance in ["MSE"]:
|
||||
matrix[idx][dist] = MSE(parent_weights, clone_weights) < pow(10, -dist)
|
||||
elif distance in ["MAE"]:
|
||||
matrix[idx][dist] = MAE(parent_weights, clone_weights) < pow(10, -dist)
|
||||
elif distance in ["MIM"]:
|
||||
matrix[idx][dist] = mean_invariate_manhattan_distance(parent_weights, clone_weights) < pow(10,
|
||||
-dist)
|
||||
|
||||
if print_it:
|
||||
print(f"\nDistances from parent {parent.name} [{distance}]:")
|
||||
col_headers = [str(f"10e-{d}") for d in distance_range]
|
||||
row_headers = [str(f"clone_{i}") for i in range(len(clones))]
|
||||
print(tabulate(matrix, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
||||
|
||||
list_of_matrices.append(matrix)
|
||||
|
||||
return list_of_matrices
|
||||
|
||||
|
||||
class SoupSpawnExperiment:
|
||||
|
||||
|
||||
def __init__(self, population_size, log_step_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
|
||||
epochs, st_steps, attack_chance, nr_clones, noise, directory) -> None:
|
||||
self.population_size = population_size
|
||||
self.log_step_size = log_step_size
|
||||
self.net_input_size = net_input_size
|
||||
self.net_hidden_size = net_hidden_size
|
||||
self.net_out_size = net_out_size
|
||||
self.net_learning_rate = net_learning_rate
|
||||
self.epochs = epochs
|
||||
self.ST_steps = st_steps
|
||||
self.attack_chance = attack_chance
|
||||
self.loss_history = []
|
||||
self.nr_clones = nr_clones
|
||||
self.noise = noise or 10e-5
|
||||
print("\nNOISE:", self.noise)
|
||||
|
||||
self.directory = Path(directory)
|
||||
self.directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Populating environment & evolving entities
|
||||
self.parents = []
|
||||
self.clones = []
|
||||
self.parents_with_clones = []
|
||||
self.parents_clones_id_functions = []
|
||||
|
||||
self.populate_environment()
|
||||
|
||||
self.spawn_and_continue()
|
||||
# self.weights_evolution_3d_experiment(self.parents, "only_parents")
|
||||
self.weights_evolution_3d_experiment(self.clones, "only_clones")
|
||||
self.weights_evolution_3d_experiment(self.parents_with_clones, "parents_with_clones")
|
||||
# self.weights_evolution_3d_experiment(self.parents_clones_id_functions, "id_f_with_parents")
|
||||
|
||||
# self.visualize_loss()
|
||||
self.distance_matrix = distance_matrix(self.parents_clones_id_functions, print_it=False)
|
||||
self.parent_clone_distances = distance_from_parent(self.parents_clones_id_functions, print_it=False)
|
||||
|
||||
# self.save()
|
||||
|
||||
def populate_environment(self):
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
for i in loop_population_size:
|
||||
loop_population_size.set_description("Populating experiment %s" % i)
|
||||
|
||||
net_name = f"parent_net_{str(i)}"
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
|
||||
for _ in range(self.ST_steps):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
self.parents.append(net)
|
||||
self.parents_with_clones.append(net)
|
||||
|
||||
if is_identity_function(net):
|
||||
self.parents_clones_id_functions.append(net)
|
||||
print(f"\nNet {net.name} is identity function")
|
||||
|
||||
if is_divergent(net):
|
||||
print(f"\nNet {net.name} is divergent")
|
||||
|
||||
if is_zero_fixpoint(net):
|
||||
print(f"\nNet {net.name} is zero fixpoint")
|
||||
|
||||
if is_secondary_fixpoint(net):
|
||||
print(f"\nNet {net.name} is secondary fixpoint")
|
||||
|
||||
def evolve(self, population):
|
||||
print(f"Clone soup has a population of {len(population)} networks")
|
||||
|
||||
loop_epochs = tqdm(range(self.epochs - 1))
|
||||
for i in loop_epochs:
|
||||
loop_epochs.set_description("\nEvolving clone soup %s" % i)
|
||||
|
||||
# A network attacking another network with a given percentage
|
||||
if random.randint(1, 100) <= self.attack_chance:
|
||||
random_net1, random_net2 = random.sample(range(len(population)), 2)
|
||||
random_net1 = population[random_net1]
|
||||
random_net2 = population[random_net2]
|
||||
print(f"\n Attack: {random_net1.name} -> {random_net2.name}")
|
||||
random_net1.attack(random_net2)
|
||||
|
||||
# Self-training each network in the population
|
||||
for j in range(len(population)):
|
||||
net = population[j]
|
||||
|
||||
for _ in range(self.ST_steps):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
def spawn_and_continue(self, number_clones: int = None):
|
||||
number_clones = number_clones or self.nr_clones
|
||||
|
||||
df = pd.DataFrame(
|
||||
columns=['name', 'parent', 'MAE_pre', 'MAE_post', 'MSE_pre', 'MSE_post', 'MIM_pre', 'MIM_post', 'noise',
|
||||
'status_post'])
|
||||
|
||||
# MAE_pre, MSE_pre, MIM_pre = 0, 0, 0
|
||||
|
||||
# For every initial net {i} after populating (that is fixpoint after first epoch);
|
||||
for i in range(len(self.parents)):
|
||||
net = self.parents[i]
|
||||
# We set parent start_time to just before this epoch ended, so plotting is zoomed in. Comment out to
|
||||
# to see full trajectory (but the clones will be very hard to see).
|
||||
# Make one target to compare distances to clones later when they have trained.
|
||||
net.start_time = self.ST_steps - 150
|
||||
net_input_data = net.input_weight_matrix()
|
||||
net_target_data = net.create_target_weights(net_input_data)
|
||||
|
||||
# print(f"\nNet {i} is fixpoint")
|
||||
|
||||
# Clone the fixpoint x times and add (+-)self.noise to weight-sets randomly;
|
||||
# To plot clones starting after first epoch (z=ST_steps), set that as start_time!
|
||||
# To make sure PCA will plot the same trajectory up until this point, we clone the
|
||||
# parent-net's weight history as well.
|
||||
for j in range(number_clones):
|
||||
clone = Net(net.input_size, net.hidden_size, net.out_size,
|
||||
f"net_{str(i)}_clone_{str(j)}", start_time=self.ST_steps)
|
||||
clone.load_state_dict(copy.deepcopy(net.state_dict()))
|
||||
clone = clone.apply_noise(self.noise)
|
||||
clone.s_train_weights_history = copy.deepcopy(net.s_train_weights_history)
|
||||
clone.number_trained = copy.deepcopy(net.number_trained)
|
||||
|
||||
# Pre Training distances (after noise application of course)
|
||||
clone_pre_weights = clone.create_target_weights(clone.input_weight_matrix())
|
||||
MAE_pre = MAE(net_target_data, clone_pre_weights)
|
||||
MSE_pre = MSE(net_target_data, clone_pre_weights)
|
||||
MIM_pre = mean_invariate_manhattan_distance(net_target_data, clone_pre_weights)
|
||||
|
||||
df.loc[len(df)] = [clone.name, net.name, MAE_pre, 0, MSE_pre, 0, MIM_pre, 0, self.noise, ""]
|
||||
|
||||
net.child_nets.append(clone)
|
||||
self.clones.append(clone)
|
||||
self.parents_with_clones.append(clone)
|
||||
|
||||
self.evolve(self.clones)
|
||||
# evolve also with the parents together
|
||||
# self.evolve(self.parents_with_clones)
|
||||
|
||||
for i in range(len(self.parents)):
|
||||
net = self.parents[i]
|
||||
net_input_data = net.input_weight_matrix()
|
||||
net_target_data = net.create_target_weights(net_input_data)
|
||||
|
||||
for j in range(len(net.child_nets)):
|
||||
clone = net.child_nets[j]
|
||||
|
||||
# Post Training distances for comparison
|
||||
clone_post_weights = clone.create_target_weights(clone.input_weight_matrix())
|
||||
MAE_post = MAE(net_target_data, clone_post_weights)
|
||||
MSE_post = MSE(net_target_data, clone_post_weights)
|
||||
MIM_post = mean_invariate_manhattan_distance(net_target_data, clone_post_weights)
|
||||
|
||||
# .. log to data-frame and add to nets for 3d plotting if they are fixpoints themselves.
|
||||
test_status(clone)
|
||||
if is_identity_function(clone):
|
||||
print(f"Clone {j} (of net_{i}) is fixpoint."
|
||||
f"\nMSE({i},{j}): {MSE_post}"
|
||||
f"\nMAE({i},{j}): {MAE_post}"
|
||||
f"\nMIM({i},{j}): {MIM_post}\n")
|
||||
self.parents_clones_id_functions.append(clone)
|
||||
|
||||
# df.loc[df.name == clone.name, ["MAE_post", "MSE_post", "MIM_post"]] = [MAE_pre, MSE_pre, MIM_pre]
|
||||
|
||||
df.loc[df.name == clone.name, ["MAE_post", "MSE_post", "MIM_post", "status_post"]] = [MAE_post,
|
||||
MSE_post,
|
||||
MIM_post,
|
||||
clone.is_fixpoint]
|
||||
|
||||
# Finally take parent net {i} and finish it's training for comparison to clone development.
|
||||
for _ in range(self.epochs - 1):
|
||||
for _ in range(self.ST_steps):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
net_weights_after = net.create_target_weights(net.input_weight_matrix())
|
||||
print(f"Parent net's distance to original position."
|
||||
f"\nMSE(OG,new): {MAE(net_target_data, net_weights_after)}"
|
||||
f"\nMAE(OG,new): {MSE(net_target_data, net_weights_after)}"
|
||||
f"\nMIM(OG,new): {mean_invariate_manhattan_distance(net_target_data, net_weights_after)}\n")
|
||||
|
||||
self.df = df
|
||||
|
||||
def weights_evolution_3d_experiment(self, nets_population, suffix):
|
||||
exp_name = f"soup_basins_{str(len(nets_population))}_nets_3d_weights_PCA_{suffix}"
|
||||
return plot_3d_soup(nets_population, exp_name, self.directory)
|
||||
|
||||
def visualize_loss(self):
|
||||
for i in range(len(self.parents)):
|
||||
net_loss_history = self.parents[i].loss_history
|
||||
self.loss_history.append(net_loss_history)
|
||||
plot_loss(self.loss_history, self.directory)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
NET_INPUT_SIZE = 4
|
||||
NET_OUT_SIZE = 1
|
||||
|
||||
# Define number of runs & name:
|
||||
ST_runs = 3
|
||||
ST_runs_name = "test-27"
|
||||
soup_ST_steps = 1500
|
||||
soup_epochs = 2
|
||||
soup_log_step_size = 10
|
||||
|
||||
# Define number of networks & their architecture
|
||||
nr_clones = 5
|
||||
soup_population_size = 3
|
||||
soup_net_hidden_size = 2
|
||||
soup_net_learning_rate = 0.04
|
||||
soup_attack_chance = 10
|
||||
soup_name_hash = random.getrandbits(32)
|
||||
|
||||
print(f"Running the Soup-Spawn experiment:")
|
||||
exp_list = []
|
||||
for noise_factor in range(2, 5):
|
||||
exp = SoupSpawnExperiment(
|
||||
population_size=soup_population_size,
|
||||
log_step_size=soup_log_step_size,
|
||||
net_input_size=NET_INPUT_SIZE,
|
||||
net_hidden_size=soup_net_hidden_size,
|
||||
net_out_size=NET_OUT_SIZE,
|
||||
net_learning_rate=soup_net_learning_rate,
|
||||
epochs=soup_epochs,
|
||||
st_steps=soup_ST_steps,
|
||||
attack_chance=soup_attack_chance,
|
||||
nr_clones=nr_clones,
|
||||
noise=pow(10, -noise_factor),
|
||||
directory=Path('output') / 'soup_spawn_basin' / f'{soup_name_hash}' / f'10e-{noise_factor}'
|
||||
)
|
||||
exp_list.append(exp)
|
||||
|
||||
directory = Path('output') / 'soup_spawn_basin' / f'{soup_name_hash}'
|
||||
pickle.dump(exp_list, open(f"{directory}/experiment_pickle_{soup_name_hash}.p", "wb"))
|
||||
print(f"\nSaved experiment to {directory}.")
|
||||
|
||||
# Concat all dataframes, and add columns depending on where clone weights end up after training (rel. to parent)
|
||||
df = pd.concat([exp.df for exp in exp_list])
|
||||
df = df.dropna().reset_index()
|
||||
df["relative_distance"] = [ (df.loc[i]["MAE_pre"] - df.loc[i]["MAE_post"]) for i in range(len(df))]
|
||||
df["class"] = ["approaching" if df.loc[i]["relative_distance"] > 0 else "distancing" if df.loc[i]["relative_distance"] < 0 else "stationary" for i in range(len(df))]
|
||||
|
||||
# Countplot of all fixpoint clone after training per class. Uncomment and manually adjust xticklabels if x-ax size gets too small.
|
||||
ax = sns.catplot(kind="count", data=df, x="noise", hue="class", height=5.27, aspect=12.7 / 5.27)
|
||||
ax.set_axis_labels("Noise Levels", "Clone Fixpoints After Training Count ", fontsize=15)
|
||||
# ax.set_xticklabels(labels=('10e-10', '10e-9', '10e-8', '10e-7', '10e-6', '10e-5', '10e-4', '10e-3', '10e-2', '10e-1'), fontsize=15)
|
||||
plt.savefig(f"{directory}/clone_status_after_countplot_{soup_name_hash}.png")
|
||||
plt.clf()
|
||||
|
||||
# Catplot (either kind="point" or "box") that shows before-after training distances to parent
|
||||
mlt = df.melt(id_vars=["name", "noise", "class"], value_vars=["MAE_pre", "MAE_post"], var_name="State",
|
||||
value_name="Distance")
|
||||
P = ["blue" if mlt.loc[i]["class"] == "approaching" else "orange" if mlt.loc[i]["class"] == "distancing" else "green" for i in range(len(mlt))]
|
||||
# P = sns.color_palette(P, as_cmap=False)
|
||||
ax = sns.catplot(data=mlt, x="State", y="Distance", col="noise", hue="name", kind="point", palette=P,
|
||||
col_wrap=min(5, len(exp_list)), sharey=False, legend=False)
|
||||
ax.map(sns.boxplot, "State", "Distance", "noise", linewidth=0.8, order=["MAE_pre", "MAE_post"], whis=[0, 100])
|
||||
ax.set_axis_labels("", "Manhattan Distance To Parent Weights", fontsize=15)
|
||||
ax.set_xticklabels(labels=('after noise application', 'after training'), fontsize=15)
|
||||
plt.savefig(f"{directory}/before_after_distance_catplot_{soup_name_hash}.png")
|
||||
plt.clf()
|
@ -1,252 +0,0 @@
|
||||
import copy
|
||||
import random
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
from matplotlib.ticker import ScalarFormatter
|
||||
from tqdm import tqdm
|
||||
from matplotlib import pyplot as plt
|
||||
from torch.nn import functional as F
|
||||
from tabulate import tabulate
|
||||
|
||||
from functionalities_test import test_for_fixpoints, is_zero_fixpoint, is_divergent, is_identity_function
|
||||
from network import Net
|
||||
from visualization import plot_loss, bar_chart_fixpoints, plot_3d_soup, line_chart_fixpoints
|
||||
|
||||
|
||||
def prng():
|
||||
return random.random()
|
||||
|
||||
|
||||
class SoupRobustnessExperiment:
|
||||
|
||||
def __init__(self, population_size, net_i_size, net_h_size, net_o_size, learning_rate, attack_chance,
|
||||
train_nets, ST_steps, epochs, log_step_size, directory: Union[str, Path]):
|
||||
super().__init__()
|
||||
self.population_size = population_size
|
||||
|
||||
self.net_input_size = net_i_size
|
||||
self.net_hidden_size = net_h_size
|
||||
self.net_out_size = net_o_size
|
||||
self.net_learning_rate = learning_rate
|
||||
self.attack_chance = attack_chance
|
||||
self.train_nets = train_nets
|
||||
# self.SA_steps = SA_steps
|
||||
self.ST_steps = ST_steps
|
||||
self.epochs = epochs
|
||||
self.log_step_size = log_step_size
|
||||
|
||||
self.loss_history = []
|
||||
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
# <self.fixpoint_counters_history> is used for keeping track of the amount of fixpoints in %
|
||||
self.fixpoint_counters_history = []
|
||||
self.id_functions = []
|
||||
|
||||
self.directory = Path(directory)
|
||||
self.directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.population = []
|
||||
self.populate_environment()
|
||||
|
||||
self.evolve()
|
||||
self.fixpoint_percentage()
|
||||
self.weights_evolution_3d_experiment()
|
||||
self.count_fixpoints()
|
||||
self.visualize_loss()
|
||||
|
||||
self.time_to_vergence, self.time_as_fixpoint = self.test_robustness()
|
||||
|
||||
def populate_environment(self):
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
for i in tqdm(range(self.population_size)):
|
||||
loop_population_size.set_description("Populating soup experiment %s" % i)
|
||||
|
||||
net_name = f"soup_network_{i}"
|
||||
net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
|
||||
self.population.append(net)
|
||||
|
||||
def evolve(self):
|
||||
""" Evolving consists of attacking & self-training. """
|
||||
|
||||
loop_epochs = tqdm(range(self.epochs))
|
||||
for i in loop_epochs:
|
||||
loop_epochs.set_description("Evolving soup %s" % i)
|
||||
|
||||
# A network attacking another network with a given percentage
|
||||
if random.randint(1, 100) <= self.attack_chance:
|
||||
random_net1, random_net2 = random.sample(range(self.population_size), 2)
|
||||
random_net1 = self.population[random_net1]
|
||||
random_net2 = self.population[random_net2]
|
||||
print(f"\n Attack: {random_net1.name} -> {random_net2.name}")
|
||||
random_net1.attack(random_net2)
|
||||
|
||||
# Self-training each network in the population
|
||||
for j in range(self.population_size):
|
||||
net = self.population[j]
|
||||
|
||||
for _ in range(self.ST_steps):
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate)
|
||||
|
||||
# Testing for fixpoints after each batch of ST steps to see relevant data
|
||||
if i % self.ST_steps == 0:
|
||||
test_for_fixpoints(self.fixpoint_counters, self.population)
|
||||
fixpoints_percentage = round(self.fixpoint_counters["identity_func"] / self.population_size, 1)
|
||||
self.fixpoint_counters_history.append(fixpoints_percentage)
|
||||
|
||||
# Resetting the fixpoint counter. Last iteration not to be reset -
|
||||
# it is important for the bar_chart_fixpoints().
|
||||
if i < self.epochs:
|
||||
self.reset_fixpoint_counters()
|
||||
|
||||
def test_robustness(self, print_it=True, noise_levels=10, seeds=10):
|
||||
# assert (len(self.id_functions) == 1 and seeds > 1) or (len(self.id_functions) > 1 and seeds == 1)
|
||||
is_synthetic = True if len(self.id_functions) > 1 and seeds == 1 else False
|
||||
avg_time_to_vergence = [[0 for _ in range(noise_levels)] for _ in
|
||||
range(seeds if is_synthetic else len(self.id_functions))]
|
||||
avg_time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in
|
||||
range(seeds if is_synthetic else len(self.id_functions))]
|
||||
row_headers = []
|
||||
data_pos = 0
|
||||
# This checks wether to use synthetic setting with multiple seeds
|
||||
# or multi network settings with a singlee seed
|
||||
|
||||
df = pd.DataFrame(columns=['seed', 'noise_level', 'application_step', 'absolute_loss'])
|
||||
for i, fixpoint in enumerate(self.id_functions): # 1 / n
|
||||
row_headers.append(fixpoint.name)
|
||||
for seed in range(seeds): # n / 1
|
||||
for noise_level in range(noise_levels):
|
||||
self_application_steps = 1
|
||||
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
|
||||
f"{fixpoint.name}_clone_noise10e-{noise_level}")
|
||||
clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
|
||||
clone = clone.apply_noise(pow(10, -noise_level))
|
||||
|
||||
while not is_zero_fixpoint(clone) and not is_divergent(clone):
|
||||
if is_identity_function(clone):
|
||||
avg_time_as_fixpoint[i][noise_level] += 1
|
||||
|
||||
# -> before
|
||||
clone_weight_pre_application = clone.input_weight_matrix()
|
||||
target_data_pre_application = clone.create_target_weights(clone_weight_pre_application)
|
||||
|
||||
clone.self_application(1, self.log_step_size)
|
||||
avg_time_to_vergence[i][noise_level] += 1
|
||||
# -> after
|
||||
clone_weight_post_application = clone.input_weight_matrix()
|
||||
target_data_post_application = clone.create_target_weights(clone_weight_post_application)
|
||||
|
||||
absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item()
|
||||
|
||||
setting = i if is_synthetic else seed
|
||||
|
||||
df.loc[data_pos] = [setting, noise_level, self_application_steps, absolute_loss]
|
||||
data_pos += 1
|
||||
self_application_steps += 1
|
||||
|
||||
# calculate the average:
|
||||
df = df.replace([np.inf, -np.inf], np.nan)
|
||||
df = df.dropna()
|
||||
# sns.set(rc={'figure.figsize': (10, 50)})
|
||||
sns.set_theme(style="ticks")
|
||||
bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box',
|
||||
col='noise_level', col_wrap=3, showfliers=False)
|
||||
|
||||
directory = Path('output') / 'robustness'
|
||||
filename = f"absolute_loss_perapplication_boxplot_grid.png"
|
||||
filepath = directory / filename
|
||||
|
||||
plt.savefig(str(filepath))
|
||||
|
||||
if print_it:
|
||||
col_headers = [str(f"10-{d}") for d in range(noise_levels)]
|
||||
|
||||
print(f"\nAppplications steps until divergence / zero: ")
|
||||
print(tabulate(avg_time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
||||
|
||||
print(f"\nTime as fixpoint: ")
|
||||
print(tabulate(avg_time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
||||
|
||||
return avg_time_as_fixpoint, avg_time_to_vergence
|
||||
|
||||
def weights_evolution_3d_experiment(self):
|
||||
exp_name = f"soup_{self.population_size}_nets_{self.ST_steps}_training_{self.epochs}_epochs"
|
||||
return plot_3d_soup(self.population, exp_name, self.directory)
|
||||
|
||||
def count_fixpoints(self):
|
||||
self.id_functions = test_for_fixpoints(self.fixpoint_counters, self.population)
|
||||
exp_details = f"Evolution steps: {self.epochs} epochs"
|
||||
bar_chart_fixpoints(self.fixpoint_counters, self.population_size, self.directory, self.net_learning_rate,
|
||||
exp_details)
|
||||
|
||||
def fixpoint_percentage(self):
|
||||
runs = self.epochs / self.ST_steps
|
||||
SA_steps = None
|
||||
line_chart_fixpoints(self.fixpoint_counters_history, runs, self.ST_steps, SA_steps, self.directory,
|
||||
self.population_size)
|
||||
|
||||
def visualize_loss(self):
|
||||
for i in range(len(self.population)):
|
||||
net_loss_history = self.population[i].loss_history
|
||||
self.loss_history.append(net_loss_history)
|
||||
|
||||
plot_loss(self.loss_history, self.directory)
|
||||
|
||||
def reset_fixpoint_counters(self):
|
||||
self.fixpoint_counters = {
|
||||
"identity_func": 0,
|
||||
"divergent": 0,
|
||||
"fix_zero": 0,
|
||||
"fix_weak": 0,
|
||||
"fix_sec": 0,
|
||||
"other_func": 0
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
NET_INPUT_SIZE = 4
|
||||
NET_OUT_SIZE = 1
|
||||
|
||||
soup_epochs = 100
|
||||
soup_log_step_size = 5
|
||||
soup_ST_steps = 20
|
||||
# soup_SA_steps = 10
|
||||
|
||||
# Define number of networks & their architecture
|
||||
soup_population_size = 4
|
||||
soup_net_hidden_size = 2
|
||||
soup_net_learning_rate = 0.04
|
||||
|
||||
# soup_attack_chance in %
|
||||
soup_attack_chance = 10
|
||||
|
||||
# not used yet: soup_train_nets has 3 possible values "no", "before_SA", "after_SA".
|
||||
soup_train_nets = "no"
|
||||
soup_name_hash = random.getrandbits(32)
|
||||
soup_synthetic = True
|
||||
|
||||
print(f"Running the robustness comparison experiment:")
|
||||
SoupRobustnessExperiment(
|
||||
population_size=soup_population_size,
|
||||
net_i_size=NET_INPUT_SIZE,
|
||||
net_h_size=soup_net_hidden_size,
|
||||
net_o_size=NET_OUT_SIZE,
|
||||
learning_rate=soup_net_learning_rate,
|
||||
attack_chance=soup_attack_chance,
|
||||
train_nets=soup_train_nets,
|
||||
ST_steps=soup_ST_steps,
|
||||
epochs=soup_epochs,
|
||||
log_step_size=soup_log_step_size,
|
||||
directory=Path('output') / 'robustness' / f'{soup_name_hash}'
|
||||
)
|
150
main.py
@ -1,150 +0,0 @@
|
||||
from experiments import *
|
||||
import random
|
||||
|
||||
|
||||
# TODO maybe add also SA to the soup
|
||||
|
||||
def run_experiments(run_ST, run_SA, run_soup, run_mixed, run_robustness):
|
||||
if run_ST:
|
||||
print(f"Running the ST experiment:")
|
||||
run_ST_experiment(ST_population_size, ST_log_step_size, NET_INPUT_SIZE, ST_net_hidden_size, NET_OUT_SIZE,
|
||||
ST_net_learning_rate,
|
||||
ST_epochs, ST_runs, ST_runs_name, ST_name_hash)
|
||||
if run_SA:
|
||||
print(f"\n Running the SA experiment:")
|
||||
run_SA_experiment(SA_population_size, SA_log_step_size, NET_INPUT_SIZE, SA_net_hidden_size, NET_OUT_SIZE,
|
||||
SA_net_learning_rate, SA_runs, SA_runs_name, SA_name_hash,
|
||||
SA_steps, SA_train_nets, SA_ST_steps)
|
||||
if run_soup:
|
||||
print(f"\n Running the soup experiment:")
|
||||
run_soup_experiment(soup_population_size, soup_attack_chance, NET_INPUT_SIZE, soup_net_hidden_size,
|
||||
NET_OUT_SIZE, soup_net_learning_rate, soup_epochs, soup_log_step_size, soup_runs,
|
||||
soup_runs_name, soup_name_hash, soup_ST_steps, soup_train_nets)
|
||||
if run_mixed:
|
||||
print(f"\n Running the mixed experiment:")
|
||||
run_mixed_experiment(mixed_population_size, NET_INPUT_SIZE, mixed_net_hidden_size, NET_OUT_SIZE,
|
||||
mixed_net_learning_rate, mixed_train_nets, mixed_epochs, mixed_SA_steps,
|
||||
mixed_ST_steps_between_SA, mixed_log_step_size, mixed_name_hash, mixed_total_runs,
|
||||
mixed_runs_name)
|
||||
if run_robustness:
|
||||
print(f"Running the robustness experiment:")
|
||||
run_robustness_experiment(rob_population_size, rob_log_step_size, NET_INPUT_SIZE, rob_net_hidden_size,
|
||||
NET_OUT_SIZE, rob_net_learning_rate, rob_ST_steps, rob_runs, rob_runs_name,
|
||||
rob_name_hash)
|
||||
|
||||
if not run_ST and not run_SA and not run_soup and not run_mixed and not run_robustness:
|
||||
print(f"No experiments to be run.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Constants:
|
||||
NET_INPUT_SIZE = 4
|
||||
NET_OUT_SIZE = 1
|
||||
run_ST_experiment_bool = False
|
||||
run_SA_experiment_bool = False
|
||||
run_soup_experiment_bool = False
|
||||
run_mixed_experiment_bool = False
|
||||
run_robustness_bool = True
|
||||
|
||||
""" ------------------------------------- Self-training (ST) experiment ------------------------------------- """
|
||||
|
||||
# Define number of runs & name:
|
||||
ST_runs = 1
|
||||
ST_runs_name = "test-27"
|
||||
ST_epochs = 1000
|
||||
ST_log_step_size = 10
|
||||
|
||||
# Define number of networks & their architecture
|
||||
ST_population_size = 1
|
||||
ST_net_hidden_size = 2
|
||||
|
||||
ST_net_learning_rate = 0.04
|
||||
|
||||
ST_name_hash = random.getrandbits(32)
|
||||
|
||||
""" ----------------------------------- Self-application (SA) experiment ----------------------------------- """
|
||||
# Define number of runs, name, etc.:
|
||||
SA_runs_name = "test-17"
|
||||
SA_runs = 2
|
||||
SA_steps = 100
|
||||
SA_app_batch_size = 5
|
||||
SA_train_batch_size = 5
|
||||
SA_log_step_size = 5
|
||||
|
||||
# Define number of networks & their architecture
|
||||
SA_population_size = 10
|
||||
SA_net_hidden_size = 2
|
||||
|
||||
SA_net_learning_rate = 0.04
|
||||
|
||||
# SA_train_nets has 3 possible values "no", "before_SA", "after_SA".
|
||||
SA_train_nets = "no"
|
||||
SA_ST_steps = 300
|
||||
|
||||
SA_name_hash = random.getrandbits(32)
|
||||
|
||||
""" -------------------------------------------- Soup experiment -------------------------------------------- """
|
||||
# Define number of runs, name, etc.:
|
||||
soup_runs = 1
|
||||
soup_runs_name = "test-16"
|
||||
soup_epochs = 100
|
||||
soup_log_step_size = 5
|
||||
soup_ST_steps = 20
|
||||
# soup_SA_steps = 10
|
||||
|
||||
# Define number of networks & their architecture
|
||||
soup_population_size = 5
|
||||
soup_net_hidden_size = 2
|
||||
soup_net_learning_rate = 0.04
|
||||
|
||||
# soup_attack_chance in %
|
||||
soup_attack_chance = 10
|
||||
|
||||
# not used yet: soup_train_nets has 3 possible values "no", "before_SA", "after_SA".
|
||||
soup_train_nets = "no"
|
||||
|
||||
soup_name_hash = random.getrandbits(32)
|
||||
|
||||
""" ------------------------------------------- Mixed experiment -------------------------------------------- """
|
||||
|
||||
# Define number of runs, name, etc.:
|
||||
mixed_runs_name = "test-17"
|
||||
mixed_total_runs = 2
|
||||
|
||||
# Define number of networks & their architecture
|
||||
mixed_population_size = 5
|
||||
mixed_net_hidden_size = 2
|
||||
|
||||
mixed_epochs = 10
|
||||
# Set the <batch_size> to the same value as <ST_steps_between_SA> to see the weights plotted
|
||||
# ONLY after each epoch, and not after a certain amount of steps.
|
||||
mixed_log_step_size = 5
|
||||
mixed_ST_steps_between_SA = 50
|
||||
mixed_SA_steps = 4
|
||||
|
||||
mixed_net_learning_rate = 0.04
|
||||
|
||||
# mixed_train_nets has 2 possible values "before_SA", "after_SA".
|
||||
mixed_train_nets = "after_SA"
|
||||
|
||||
mixed_name_hash = random.getrandbits(32)
|
||||
|
||||
""" ----------------------------------------- Robustness experiment ----------------------------------------- """
|
||||
# Define number of runs & name:
|
||||
rob_runs = 1
|
||||
rob_runs_name = "test-07"
|
||||
rob_ST_steps = 1500
|
||||
rob_log_step_size = 10
|
||||
|
||||
# Define number of networks & their architecture
|
||||
rob_population_size = 1
|
||||
rob_net_hidden_size = 2
|
||||
|
||||
rob_net_learning_rate = 0.04
|
||||
|
||||
rob_name_hash = random.getrandbits(32)
|
||||
|
||||
""" ---------------------------------------- Running the experiment ----------------------------------------- """
|
||||
|
||||
run_experiments(run_ST_experiment_bool, run_SA_experiment_bool, run_soup_experiment_bool, run_mixed_experiment_bool,
|
||||
run_robustness_bool)
|
273
meta_task_exp.py
Normal file
@ -0,0 +1,273 @@
|
||||
# # # Imports
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import platform
|
||||
|
||||
import torchmetrics
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import Flatten
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision.transforms import ToTensor, Compose, Resize
|
||||
from tqdm import tqdm
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
from experiments.meta_task_utility import (ToFloat, new_storage_df, train_task, checkpoint_and_validate, flat_for_store,
|
||||
plot_training_result, plot_training_particle_types,
|
||||
plot_network_connectivity_by_fixtype, run_particle_dropout_and_plot,
|
||||
highlight_fixpoints_vs_mnist_mean, AddGaussianNoise,
|
||||
plot_training_results_over_n_seeds, sanity_weight_swap,
|
||||
FINAL_CHECKPOINT_NAME)
|
||||
from experiments.robustness_tester import test_robustness
|
||||
from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer
|
||||
|
||||
|
||||
if platform.node() == 'CarbonX':
|
||||
debug = True
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
print("@ Warning, Debugging Config@!!!!!! @")
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
else:
|
||||
debug = False
|
||||
|
||||
|
||||
from network import MetaNet, FixTypes
|
||||
from functionalities_test import test_for_fixpoints
|
||||
|
||||
|
||||
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)]) # , AddGaussianNoise()])
|
||||
WORKER = 10 if not debug else 2
|
||||
debug = False
|
||||
BATCHSIZE = 2000 if not debug else 50
|
||||
EPOCH = 200
|
||||
VALIDATION_FRQ = 3 if not debug else 1
|
||||
VAL_METRIC_CLASS = torchmetrics.Accuracy
|
||||
# noinspection PyProtectedMember
|
||||
VAL_METRIC_NAME = VAL_METRIC_CLASS()._get_name()
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
DATA_PATH = Path('data')
|
||||
DATA_PATH.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
try:
|
||||
plot_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False)
|
||||
except RuntimeError:
|
||||
plot_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False, download=True)
|
||||
plot_loader = DataLoader(plot_dataset, batch_size=BATCHSIZE, shuffle=False,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
training = True
|
||||
plotting = True
|
||||
robustnes = False
|
||||
n_st = 300 # per batch !!
|
||||
activation = None # nn.ReLU()
|
||||
|
||||
train_to_task_first = True
|
||||
min_task_acc = 0.85
|
||||
|
||||
residual_skip = True
|
||||
add_gauss = False
|
||||
|
||||
alpha_st_modulator = 0
|
||||
|
||||
for weight_hidden_size in [5]:
|
||||
|
||||
weight_hidden_size = weight_hidden_size
|
||||
n_seeds = 3
|
||||
depth = 3
|
||||
width = 5
|
||||
out = 10
|
||||
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
|
||||
a_str = f'_aStM_{alpha_st_modulator}' if alpha_st_modulator not in [1, 0] else ''
|
||||
res_str = '_no_res' if not residual_skip else ''
|
||||
st_str = f'_nst_{n_st}'
|
||||
tsk_str = f'_tsktr_{min_task_acc}' if train_to_task_first else ''
|
||||
w_str = f'_w{width}wh{weight_hidden_size}d{depth}'
|
||||
|
||||
config_str = f'{res_str}{ac_str}{st_str}{tsk_str}{a_str}{w_str}'
|
||||
exp_path = Path('output') / f'mn_st_{EPOCH}{config_str}'
|
||||
last_accuracy = 0
|
||||
|
||||
for seed in range(0, n_seeds):
|
||||
seed_path = exp_path / str(seed)
|
||||
|
||||
df_store_path = seed_path / 'train_store.csv'
|
||||
weight_store_path = seed_path / 'weight_store.csv'
|
||||
srnn_parameters = dict()
|
||||
|
||||
if training:
|
||||
# Check if files do exist on project location, warn and break.
|
||||
for path in [df_store_path, weight_store_path]:
|
||||
assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
|
||||
try:
|
||||
train_dataset = MNIST(str(DATA_PATH), transform=utility_transforms)
|
||||
except RuntimeError:
|
||||
train_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, download=True)
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
try:
|
||||
valid_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False)
|
||||
except RuntimeError:
|
||||
valid_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False, download=True)
|
||||
valid_loader = DataLoader(valid_dataset, batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
interface = np.prod(train_dataset[0][0].shape)
|
||||
metanet = MetaNet(interface, depth=depth, width=width, out=out,
|
||||
residual_skip=residual_skip, weight_hidden_size=weight_hidden_size,
|
||||
activation=activation
|
||||
).to(DEVICE)
|
||||
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
|
||||
|
||||
train_store = new_storage_df('train', None)
|
||||
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
|
||||
|
||||
for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'):
|
||||
is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
|
||||
metanet = metanet.train()
|
||||
|
||||
# Init metrics, even we do not need:
|
||||
metric = VAL_METRIC_CLASS()
|
||||
n_st_per_batch = max(n_st // len(train_loader), 1)
|
||||
|
||||
if is_validation_epoch:
|
||||
for particle in metanet.particles:
|
||||
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
|
||||
weight_store.loc[weight_store.shape[0]] = weight_log
|
||||
|
||||
do_self_train = not train_to_task_first or last_accuracy >= min_task_acc
|
||||
train_to_task_first = train_to_task_first if not do_self_train else False
|
||||
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(train_loader),
|
||||
total=len(train_loader), desc='MetaNet Train - Batch'
|
||||
):
|
||||
# Self Train
|
||||
if do_self_train:
|
||||
self_train_loss = metanet.combined_self_train(n_st_per_batch, alpha=alpha_st_modulator,
|
||||
reduction='mean', per_particle=False)
|
||||
# noinspection PyUnboundLocalVariable
|
||||
st_step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
|
||||
st_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = st_step_log
|
||||
|
||||
# Task Train
|
||||
tsk_step_log, y_pred = train_task(metanet, optimizer, loss_fn, batch_x, batch_y)
|
||||
tsk_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = tsk_step_log
|
||||
metric(y_pred.cpu(), batch_y.cpu())
|
||||
|
||||
last_accuracy = metric.compute().item()
|
||||
if is_validation_epoch:
|
||||
metanet = metanet.eval()
|
||||
try:
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric=f'Train {VAL_METRIC_NAME}', Score=last_accuracy)
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, epoch, keep_n=5).item()
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy)
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
|
||||
if is_validation_epoch:
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
# This returns ID-functions
|
||||
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
|
||||
counter_dict = dict(counter_dict)
|
||||
for key, value in counter_dict.items():
|
||||
val_step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = val_step_log
|
||||
tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
|
||||
|
||||
# FLUSH to disk
|
||||
if is_validation_epoch:
|
||||
train_store.to_csv(df_store_path, mode='a',
|
||||
header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a',
|
||||
header=not weight_store_path.exists(), index=False)
|
||||
train_store = new_storage_df('train', None)
|
||||
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
|
||||
|
||||
###########################################################
|
||||
# EPOCHS endet
|
||||
metanet = metanet.eval()
|
||||
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
# This returns ID-functions
|
||||
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
|
||||
for key, value in dict(counter_dict).items():
|
||||
step_log = dict(Epoch=int(EPOCH)+1, Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = step_log
|
||||
accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, EPOCH, final_model=True)
|
||||
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy.item())
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
for particle in metanet.particles:
|
||||
weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
|
||||
weight_store.loc[weight_store.shape[0]] = weight_log
|
||||
|
||||
# FLUSH to disk
|
||||
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
|
||||
|
||||
try:
|
||||
model_path = next(seed_path.glob(f'{FINAL_CHECKPOINT_NAME}'))
|
||||
except StopIteration:
|
||||
print('Model pattern did not trigger.')
|
||||
print(f'Search path was: {seed_path}:')
|
||||
print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
|
||||
exit(1)
|
||||
|
||||
if plotting:
|
||||
highlight_fixpoints_vs_mnist_mean(model_path, plot_loader)
|
||||
|
||||
plot_training_result(df_store_path)
|
||||
plot_training_particle_types(df_store_path)
|
||||
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
run_particle_dropout_and_plot(model_path, valid_loader=plot_loader, metric_class=VAL_METRIC_CLASS)
|
||||
except (ValueError, NameError) as e:
|
||||
print(e)
|
||||
try:
|
||||
plot_network_connectivity_by_fixtype(model_path)
|
||||
except (ValueError, NameError) as e:
|
||||
print(e)
|
||||
|
||||
highlight_fixpoints_vs_mnist_mean(model_path, plot_loader)
|
||||
|
||||
try:
|
||||
for fixtype in FixTypes.all_types():
|
||||
plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=fixtype)
|
||||
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=fixtype)
|
||||
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
if robustnes:
|
||||
try:
|
||||
test_robustness(model_path, seeds=1)
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
|
||||
if 2 <= n_seeds <= sum(list(x.is_dir() for x in exp_path.iterdir())):
|
||||
if plotting:
|
||||
|
||||
plot_training_results_over_n_seeds(exp_path, metric_name=VAL_METRIC_NAME)
|
||||
|
||||
sanity_weight_swap(exp_path, plot_loader, VAL_METRIC_CLASS)
|
244
meta_task_exp_small.py
Normal file
@ -0,0 +1,244 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchmetrics
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiments.meta_task_small_utility import AddTaskDataset, train_task
|
||||
from experiments.robustness_tester import test_robustness
|
||||
from network import MetaNet
|
||||
from functionalities_test import test_for_fixpoints, FixTypes as ft
|
||||
from experiments.meta_task_utility import new_storage_df, flat_for_store, plot_training_result, \
|
||||
plot_training_particle_types, run_particle_dropout_and_plot, plot_network_connectivity_by_fixtype, \
|
||||
checkpoint_and_validate, plot_training_results_over_n_seeds, sanity_weight_swap, FINAL_CHECKPOINT_NAME
|
||||
from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer
|
||||
|
||||
WORKER = 0
|
||||
BATCHSIZE = 50
|
||||
EPOCH = 60
|
||||
VALIDATION_FRQ = 3
|
||||
VAL_METRIC_CLASS = torchmetrics.MeanAbsoluteError
|
||||
# noinspection PyProtectedMember
|
||||
VAL_METRIC_NAME = VAL_METRIC_CLASS()._get_name()
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
plot_loader = DataLoader(AddTaskDataset(), batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
training = False
|
||||
plotting = True
|
||||
robustness = True
|
||||
attack = False
|
||||
attack_ratio = 0.01
|
||||
melt = False
|
||||
melt_ratio = 0.01
|
||||
n_st = 200
|
||||
activation = None # nn.ReLU()
|
||||
|
||||
for weight_hidden_size in [3]:
|
||||
|
||||
weight_hidden_size = weight_hidden_size
|
||||
residual_skip = True
|
||||
n_seeds = 10
|
||||
depth = 3
|
||||
width = 3
|
||||
out = 1
|
||||
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
|
||||
res_str = f'{"" if residual_skip else "_no_res"}'
|
||||
att_str = f'_att_{attack_ratio}' if attack else ''
|
||||
mlt_str = f'_mlt_{melt_ratio}' if melt else ''
|
||||
w_str = f'_w{width}wh{weight_hidden_size}d{depth}'
|
||||
# dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
|
||||
|
||||
config_str = f'{res_str}{att_str}{ac_str}{mlt_str}{w_str}'
|
||||
exp_path = Path('output') / f'add_st_{EPOCH}{config_str}'
|
||||
|
||||
# if not training:
|
||||
# # noinspection PyRedeclaration
|
||||
# exp_path = Path('output') / f'add_st_{n_st}_{weight_hidden_size}'
|
||||
|
||||
for seed in range(n_seeds):
|
||||
seed_path = exp_path / str(seed)
|
||||
|
||||
df_store_path = seed_path / 'train_store.csv'
|
||||
weight_store_path = seed_path / 'weight_store.csv'
|
||||
srnn_parameters = dict()
|
||||
|
||||
valid_data = AddTaskDataset()
|
||||
vali_load = DataLoader(valid_data, batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
if training:
|
||||
# Check if files do exist on project location, warn and break.
|
||||
for path in [df_store_path, weight_store_path]:
|
||||
assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
|
||||
|
||||
train_data = AddTaskDataset()
|
||||
train_load = DataLoader(train_data, batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
interface = np.prod(train_data[0][0].shape)
|
||||
metanet = MetaNet(interface, depth=depth, width=width, out=out,
|
||||
residual_skip=residual_skip, weight_hidden_size=weight_hidden_size,
|
||||
activation=activation
|
||||
).to(DEVICE)
|
||||
|
||||
loss_fn = nn.MSELoss()
|
||||
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
|
||||
|
||||
train_store = new_storage_df('train', None)
|
||||
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
|
||||
|
||||
for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'):
|
||||
is_validation_epoch = epoch % VALIDATION_FRQ == 0
|
||||
metanet = metanet.train()
|
||||
|
||||
# Init metrics, even we do not need:
|
||||
metric = VAL_METRIC_CLASS()
|
||||
n_st_per_batch = max(1, (n_st // len(train_load)))
|
||||
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(train_load),
|
||||
total=len(train_load), desc='MetaNet Train - Batch'
|
||||
):
|
||||
# Self Train
|
||||
self_train_loss = metanet.combined_self_train(n_st_per_batch,
|
||||
reduction='mean', per_particle=False)
|
||||
# noinspection PyUnboundLocalVariable
|
||||
st_step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
|
||||
st_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = st_step_log
|
||||
|
||||
# Attack
|
||||
if attack:
|
||||
after_attack_loss = metanet.make_particles_attack(attack_ratio)
|
||||
st_step_log = dict(Metric='After Attack Loss', Score=after_attack_loss.item())
|
||||
st_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = st_step_log
|
||||
|
||||
# Melt
|
||||
if melt:
|
||||
after_melt_loss = metanet.make_particles_melt(melt_ratio)
|
||||
st_step_log = dict(Metric='After Melt Loss', Score=after_melt_loss.item())
|
||||
st_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = st_step_log
|
||||
|
||||
# Task Train
|
||||
tsk_step_log, y_pred = train_task(metanet, optimizer, loss_fn, batch_x, batch_y)
|
||||
tsk_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = tsk_step_log
|
||||
metric(y_pred.cpu(), batch_y.cpu())
|
||||
|
||||
if is_validation_epoch:
|
||||
metanet = metanet.eval()
|
||||
if metric.total.item():
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric=f'Train {VAL_METRIC_NAME}', Score=metric.compute().item())
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
|
||||
mae = checkpoint_and_validate(metanet, vali_load, seed_path, epoch,
|
||||
validation_metric=VAL_METRIC_CLASS).item()
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=mae)
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
|
||||
if is_validation_epoch:
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
# This returns ID-functions
|
||||
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
|
||||
counter_dict = dict(counter_dict)
|
||||
for key, value in counter_dict.items():
|
||||
val_step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = val_step_log
|
||||
tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
|
||||
|
||||
# FLUSH to disk
|
||||
if is_validation_epoch:
|
||||
for particle in metanet.particles:
|
||||
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
|
||||
weight_store.loc[weight_store.shape[0]] = weight_log
|
||||
train_store.to_csv(df_store_path, mode='a',
|
||||
header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a',
|
||||
header=not weight_store_path.exists(), index=False)
|
||||
train_store = new_storage_df('train', None)
|
||||
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
|
||||
|
||||
###########################################################
|
||||
# EPOCHS endet
|
||||
metanet = metanet.eval()
|
||||
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
# This returns ID-functions
|
||||
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
|
||||
for key, value in dict(counter_dict).items():
|
||||
step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = step_log
|
||||
accuracy = checkpoint_and_validate(metanet, vali_load, seed_path, EPOCH, final_model=True,
|
||||
validation_metric=VAL_METRIC_CLASS)
|
||||
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy.item())
|
||||
for particle in metanet.particles:
|
||||
weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
|
||||
weight_store.loc[weight_store.shape[0]] = weight_log
|
||||
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
|
||||
if plotting:
|
||||
|
||||
plot_training_result(df_store_path, metric_name=VAL_METRIC_NAME)
|
||||
plot_training_particle_types(df_store_path)
|
||||
|
||||
try:
|
||||
model_path = next(seed_path.glob(f'*{FINAL_CHECKPOINT_NAME}'))
|
||||
except StopIteration:
|
||||
print('####################################################')
|
||||
print('ERROR: Model pattern did not trigger.')
|
||||
print(f'INFO: Search path was: {seed_path}:')
|
||||
print(f'INFO: Found Models are: {list(seed_path.rglob(".tp"))}')
|
||||
print('####################################################')
|
||||
exit(1)
|
||||
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
run_particle_dropout_and_plot(model_path, valid_loader=plot_loader, metric_class=VAL_METRIC_CLASS)
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
try:
|
||||
plot_network_connectivity_by_fixtype(model_path)
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
try:
|
||||
tqdm.write('Trajectory plotting ...')
|
||||
plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.identity_func)
|
||||
plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
|
||||
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.identity_func)
|
||||
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
|
||||
tqdm.write('Trajectory plotting Done')
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
if robustness:
|
||||
try:
|
||||
test_robustness(model_path, seeds=10)
|
||||
pass
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
|
||||
if 2 <= n_seeds == sum(list(x.is_dir() for x in exp_path.iterdir())):
|
||||
if plotting:
|
||||
|
||||
plot_training_results_over_n_seeds(exp_path, metric_name=VAL_METRIC_NAME)
|
||||
|
||||
sanity_weight_swap(exp_path, plot_loader, VAL_METRIC_CLASS)
|
105
meta_task_sanity_exp.py
Normal file
@ -0,0 +1,105 @@
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import platform
|
||||
|
||||
import pandas as pd
|
||||
import torch.optim
|
||||
from matplotlib import pyplot as plt
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
from tqdm import trange, tqdm
|
||||
from tqdm.contrib import tenumerate
|
||||
|
||||
|
||||
if platform.node() == 'CarbonX':
|
||||
debug = True
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
print("@ Warning, Debugging Config@!!!!!! @")
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
else:
|
||||
debug = False
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if __package__ is None:
|
||||
DIR = Path(__file__).resolve().parent
|
||||
sys.path.insert(0, str(DIR.parent))
|
||||
__package__ = DIR.name
|
||||
else:
|
||||
DIR = None
|
||||
except NameError:
|
||||
DIR = None
|
||||
pass
|
||||
|
||||
import functionalities_test
|
||||
from network import Net
|
||||
|
||||
|
||||
class MultiplyByXTaskDataset(Dataset):
|
||||
def __init__(self, x=0.23, length=int(5e5)):
|
||||
super().__init__()
|
||||
self.length = length
|
||||
self.x = x
|
||||
self.prng = np.random.default_rng()
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, _):
|
||||
ab = self.prng.normal(size=(1,)).astype(np.float32)
|
||||
return ab, ab * self.x
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = Net(5, 4, 1, lr=0.004)
|
||||
multiplication_target = 0.03
|
||||
st_steps = 0
|
||||
|
||||
loss_fn = nn.MSELoss()
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
|
||||
|
||||
train_frame = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
|
||||
|
||||
dataset = MultiplyByXTaskDataset(x=multiplication_target, length=1000000)
|
||||
dataloader = DataLoader(dataset=dataset, batch_size=8000, num_workers=0)
|
||||
for epoch in trange(30):
|
||||
mean_batch_loss = []
|
||||
mean_self_tain_loss = []
|
||||
|
||||
for batch, (batch_x, batch_y) in tenumerate(dataloader):
|
||||
self_train_loss, _ = net.self_train(1000 // 20, save_history=False)
|
||||
is_fixpoint = functionalities_test.is_identity_function(net)
|
||||
if not is_fixpoint:
|
||||
st_steps += 2
|
||||
|
||||
if is_fixpoint:
|
||||
tqdm.write(f'is fixpoint after st : {is_fixpoint}, first reached after st_steps: {st_steps}')
|
||||
tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_identity_function(net)}')
|
||||
|
||||
#mean_batch_loss.append(loss.detach())
|
||||
mean_self_tain_loss.append(self_train_loss.detach())
|
||||
|
||||
train_frame.loc[train_frame.shape[0]] = dict(Epoch=epoch, Batch=batch,
|
||||
Metric='Self Train Loss', Score=np.average(mean_self_tain_loss))
|
||||
train_frame.loc[train_frame.shape[0]] = dict(Epoch=epoch, Batch=batch,
|
||||
Metric='Batch Loss', Score=np.average(mean_batch_loss))
|
||||
|
||||
counter = defaultdict(lambda: 0)
|
||||
functionalities_test.test_for_fixpoints(counter, nets=[net])
|
||||
print(dict(counter), self_train_loss)
|
||||
sanity = net(torch.Tensor([0,0,0,0,1])).detach()
|
||||
print(sanity)
|
||||
print(abs(sanity - multiplication_target))
|
||||
sns.lineplot(data=train_frame, x='Epoch', y='Score', hue='Metric')
|
||||
outpath = Path('output') / 'sanity' / 'test.png'
|
||||
outpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
plt.savefig(outpath)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
85
minimal_net_search.py
Normal file
@ -0,0 +1,85 @@
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import Flatten
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision.datasets import MNIST, CIFAR10
|
||||
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Grayscale
|
||||
import torchmetrics
|
||||
import pickle
|
||||
|
||||
from network import MetaNetCompareBaseline
|
||||
|
||||
WORKER = 0
|
||||
BATCHSIZE = 500
|
||||
EPOCH = 10
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
MNIST_TRANSFORM = Compose([ Resize((10, 10)), ToTensor(), Normalize((0.1307,), (0.3081,)), Flatten(start_dim=0)])
|
||||
CIFAR10_TRANSFORM = Compose([ Grayscale(num_output_channels=1), Resize((10, 10)), ToTensor(), Normalize((0.48,), (0.25,)), Flatten(start_dim=0)])
|
||||
|
||||
|
||||
def train_and_test(testnet, optimizer, loss, trainset, testset):
|
||||
d_train = DataLoader(trainset, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
|
||||
d_test = DataLoader(testset, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
|
||||
|
||||
# train
|
||||
for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epoch'):
|
||||
for batch, (batch_x, batch_y) in enumerate(d_train):
|
||||
optimizer.zero_grad()
|
||||
batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
|
||||
y = testnet(batch_x)
|
||||
loss = loss_fn(y, batch_y)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# test
|
||||
testnet.eval()
|
||||
metric = torchmetrics.Accuracy()
|
||||
with tqdm(desc='Test Batch: ') as pbar:
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(d_test), total=len(d_test), desc='MetaNet Test - Batch'):
|
||||
y = testnet(batch_x)
|
||||
loss = loss_fn(y, batch_y)
|
||||
acc = metric(y.cpu(), batch_y.cpu())
|
||||
pbar.set_postfix_str(f'Acc: {acc}')
|
||||
pbar.update()
|
||||
|
||||
acc = metric.compute()
|
||||
tqdm.write(f"Avg. accuracy on all data: {acc}")
|
||||
return acc
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.manual_seed(42)
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
mnist_train = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=True)
|
||||
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
|
||||
cifar10_train = CIFAR10(str(data_path), transform=CIFAR10_TRANSFORM, download=True, train=True)
|
||||
cifar10_test = CIFAR10(str(data_path), transform=CIFAR10_TRANSFORM, download=True, train=False)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
frame = pd.DataFrame(columns=['Dataset', 'Neurons', 'Layers', 'Parameters', 'Accuracy'])
|
||||
|
||||
for name, trainset, testset in [("MNIST",mnist_train,mnist_test), ("CIFAR10",cifar10_train,cifar10_test)]:
|
||||
best_acc = 0
|
||||
neuron_count = 0
|
||||
layer_count = 0
|
||||
|
||||
# find upper bound (in steps of 10, neurons/layer > 200 will start back from 10 with layers+1)
|
||||
while best_acc <= 0.95:
|
||||
neuron_count += 10
|
||||
if neuron_count >= 210:
|
||||
neuron_count = 10
|
||||
layer_count += 1
|
||||
net = MetaNetCompareBaseline(100, layer_count, neuron_count, out=10)
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
|
||||
acc = train_and_test(net, optimizer, loss_fn, trainset, testset)
|
||||
if acc > best_acc:
|
||||
best_acc = acc
|
||||
|
||||
num_params = sum(p.numel() for p in net._meta_layer_list.parameters())
|
||||
frame.loc[frame.shape[0]] = dict(Dataset=name, Neurons=neuron_count, Layers=layer_count, Parameters=num_params, Accuracy=acc)
|
||||
print(f"> {name}\t| {neuron_count} neurons\t| {layer_count} h.-layer(s)\t| {num_params} params\n")
|
||||
|
||||
print(frame)
|
||||
pickle.dump(frame, "min_net_search_df.pkl")
|
353
network.py
@ -1,7 +1,6 @@
|
||||
# from __future__ import annotations
|
||||
import copy
|
||||
import random
|
||||
from math import sqrt
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
@ -11,11 +10,35 @@ import torch.nn.functional as F
|
||||
from torch import optim, Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
def xavier_init(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight.data)
|
||||
|
||||
|
||||
def prng():
|
||||
return random.random()
|
||||
|
||||
|
||||
class FixTypes:
|
||||
|
||||
divergent = 'Divergend'
|
||||
fix_zero = 'All Zero'
|
||||
identity_func = 'Self-Replicator'
|
||||
fix_sec = 'Self-Replicator 2nd'
|
||||
other_func = 'Other'
|
||||
|
||||
@classmethod
|
||||
def all_types(cls):
|
||||
return [val for key, val in cls.__dict__.items() if isinstance(val, str) and not key.startswith('_')]
|
||||
|
||||
|
||||
class NetworkLevel:
|
||||
|
||||
all = 'All'
|
||||
layer = 'Layer'
|
||||
cell = 'Cell'
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
@ -28,7 +51,8 @@ class Net(nn.Module):
|
||||
# target_weight_matrix[i] = input_weight_matrix[i][0]
|
||||
|
||||
# Fast and simple
|
||||
return input_weight_matrix[:, 0].unsqueeze(-1)
|
||||
target_weights = input_weight_matrix[:, 0].detach().unsqueeze(-1)
|
||||
return target_weights
|
||||
|
||||
|
||||
@staticmethod
|
||||
@ -49,17 +73,15 @@ class Net(nn.Module):
|
||||
|
||||
def apply_weights(self, new_weights: Tensor):
|
||||
""" Changing the weights of a network to new given values. """
|
||||
# TODO: Change this to 'parameters' version
|
||||
with torch.no_grad():
|
||||
i = 0
|
||||
for layer_id, layer_name in enumerate(self.state_dict()):
|
||||
for line_id, line_values in enumerate(self.state_dict()[layer_name]):
|
||||
for weight_id, weight_value in enumerate(self.state_dict()[layer_name][line_id]):
|
||||
self.state_dict()[layer_name][line_id][weight_id] = new_weights[i]
|
||||
i += 1
|
||||
|
||||
for parameters in self.parameters():
|
||||
size = parameters.numel()
|
||||
parameters[:] = new_weights[i:i+size].view(parameters.shape)[:]
|
||||
i += size
|
||||
return self
|
||||
|
||||
def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1) -> None:
|
||||
def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1, lr=0.004) -> None:
|
||||
super().__init__()
|
||||
self.start_time = start_time
|
||||
|
||||
@ -79,7 +101,7 @@ class Net(nn.Module):
|
||||
self.trained = False
|
||||
self.number_trained = 0
|
||||
|
||||
self.is_fixpoint = ""
|
||||
self.is_fixpoint = FixTypes.other_func
|
||||
self.layers = nn.ModuleList(
|
||||
[nn.Linear(i_size, h_size, False),
|
||||
nn.Linear(h_size, h_size, False),
|
||||
@ -87,27 +109,32 @@ class Net(nn.Module):
|
||||
)
|
||||
|
||||
self._weight_pos_enc_and_mask = None
|
||||
|
||||
self.apply(xavier_init)
|
||||
self.optimizer = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
|
||||
|
||||
@property
|
||||
def _weight_pos_enc(self):
|
||||
if self._weight_pos_enc_and_mask is None:
|
||||
d = next(self.parameters()).device
|
||||
weight_matrix = []
|
||||
with torch.no_grad():
|
||||
for layer_id, layer in enumerate(self.layers):
|
||||
x = next(layer.parameters())
|
||||
weight_matrix.append(
|
||||
torch.cat(
|
||||
(
|
||||
# Those are the weights
|
||||
torch.full((x.numel(), 1), 0, device=d),
|
||||
torch.full((x.numel(), 1), 0, device=d, requires_grad=False),
|
||||
# Layer enumeration
|
||||
torch.full((x.numel(), 1), layer_id, device=d),
|
||||
torch.full((x.numel(), 1), layer_id, device=d, requires_grad=False),
|
||||
# Cell Enumeration
|
||||
torch.arange(layer.out_features, device=d).repeat_interleave(layer.in_features).view(-1, 1),
|
||||
torch.arange(layer.out_features, device=d, requires_grad=False
|
||||
).repeat_interleave(layer.in_features).view(-1, 1),
|
||||
# Weight Enumeration within the Cells
|
||||
torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1),
|
||||
*(torch.full((x.numel(), 1), 0, device=d) for _ in range(self.input_size-4))
|
||||
torch.arange(layer.in_features, device=d, requires_grad=False
|
||||
).view(-1, 1).repeat(layer.out_features, 1),
|
||||
*(torch.full((x.numel(), 1), 0, device=d, requires_grad=False
|
||||
) for _ in range(self.input_size-4))
|
||||
), dim=1)
|
||||
)
|
||||
# Finalize
|
||||
@ -115,16 +142,17 @@ class Net(nn.Module):
|
||||
|
||||
# Normalize 1,2,3 column of dim 1
|
||||
last_pos_idx = self.input_size - 4
|
||||
norm2 = weight_matrix[:, 1:-last_pos_idx].pow(2).sum(keepdim=True, dim=0).sqrt()
|
||||
weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / norm2) + 1e-8
|
||||
max_per_col, _ = weight_matrix[:, 1:-last_pos_idx].max(keepdim=True, dim=0)
|
||||
max_per_col += 1e-8
|
||||
weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / max_per_col) + 1e-8
|
||||
|
||||
# computations
|
||||
# create a mask where pos is 0 if it is to be replaced
|
||||
mask = torch.ones_like(weight_matrix)
|
||||
mask = torch.ones_like(weight_matrix, requires_grad=False)
|
||||
mask[:, 0] = 0
|
||||
|
||||
self._weight_pos_enc_and_mask = weight_matrix, mask
|
||||
return tuple(x.clone() for x in self._weight_pos_enc_and_mask)
|
||||
self._weight_pos_enc_and_mask = weight_matrix.detach(), mask.detach()
|
||||
return self._weight_pos_enc_and_mask
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
@ -144,30 +172,33 @@ class Net(nn.Module):
|
||||
|
||||
def input_weight_matrix(self) -> Tensor:
|
||||
""" Calculating the input tensor formed from the weights of the net """
|
||||
with torch.no_grad():
|
||||
weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
|
||||
pos_enc, mask = self._weight_pos_enc
|
||||
weight_matrix = pos_enc * mask + weight_matrix.expand(-1, pos_enc.shape[-1]) * (1 - mask)
|
||||
return weight_matrix.detach()
|
||||
|
||||
def target_weight_matrix(self) -> Tensor:
|
||||
weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
|
||||
return weight_matrix
|
||||
|
||||
def self_train(self,
|
||||
training_steps: int,
|
||||
log_step_size: int = 0,
|
||||
learning_rate: float = 0.0004,
|
||||
save_history: bool = True
|
||||
save_history: bool = False,
|
||||
reduction: str = 'mean'
|
||||
) -> (Tensor, list):
|
||||
""" Training a network to predict its own weights in order to self-replicate. """
|
||||
|
||||
optimizer = optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9)
|
||||
|
||||
for training_step in range(training_steps):
|
||||
self.number_trained += 1
|
||||
optimizer.zero_grad()
|
||||
self.optimizer.zero_grad()
|
||||
input_data = self.input_weight_matrix()
|
||||
target_data = self.create_target_weights(input_data)
|
||||
output = self(input_data)
|
||||
loss = F.mse_loss(output, target_data)
|
||||
loss = F.mse_loss(output, target_data, reduction=reduction)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
self.optimizer.step()
|
||||
|
||||
if save_history:
|
||||
# Saving the history of the weights after a certain amount of steps (aka log_step_size) for research.
|
||||
@ -184,15 +215,15 @@ class Net(nn.Module):
|
||||
self.s_train_weights_history.append(weights.T.detach().numpy())
|
||||
self.loss_history.append(loss.item())
|
||||
|
||||
weights = self.create_target_weights(self.input_weight_matrix())
|
||||
# Saving weights only at the end of a soup/mixed exp. epoch.
|
||||
if save_history:
|
||||
if "soup" in self.name or "mixed" in self.name:
|
||||
weights = self.create_target_weights(self.input_weight_matrix())
|
||||
self.s_train_weights_history.append(weights.T.detach().numpy())
|
||||
self.loss_history.append(loss.item())
|
||||
|
||||
self.trained = True
|
||||
return loss, self.loss_history
|
||||
return loss.detach(), self.loss_history
|
||||
|
||||
def self_application(self, SA_steps: int, log_step_size: Union[int, None] = None):
|
||||
""" Inputting the weights of a network to itself for a number of steps, without backpropagation. """
|
||||
@ -291,15 +322,14 @@ class SecondaryNet(Net):
|
||||
|
||||
|
||||
class MetaCell(nn.Module):
|
||||
def __init__(self, name, interface):
|
||||
def __init__(self, name, interface, weight_interface=5, weight_hidden_size=2, weight_output_size=1):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.interface = interface
|
||||
self.weight_interface = 5
|
||||
self.net_hidden_size = 4
|
||||
self.net_ouput_size = 1
|
||||
self.meta_weight_list = nn.ModuleList()
|
||||
self.meta_weight_list.extend(
|
||||
self.weight_interface = weight_interface
|
||||
self.net_hidden_size = weight_hidden_size
|
||||
self.net_ouput_size = weight_output_size
|
||||
self.meta_weight_list = nn.ModuleList(
|
||||
[Net(self.weight_interface, self.net_hidden_size,
|
||||
self.net_ouput_size, name=f'{self.name}_W{weight_idx}'
|
||||
) for weight_idx in range(self.interface)]
|
||||
@ -310,20 +340,21 @@ class MetaCell(nn.Module):
|
||||
def _bed_mask(self):
|
||||
if self.__bed_mask is None:
|
||||
d = next(self.parameters()).device
|
||||
embedding = torch.zeros(1, self.weight_interface, device=d)
|
||||
embedding = torch.zeros(1, self.weight_interface, device=d, requires_grad=False)
|
||||
|
||||
# computations
|
||||
# create a mask where pos is 0 if it is to be replaced
|
||||
mask = torch.ones_like(embedding)
|
||||
mask = torch.ones_like(embedding, requires_grad=False, device=d)
|
||||
mask[:, -1] = 0
|
||||
|
||||
self.__bed_mask = embedding, mask
|
||||
return tuple(x.clone() for x in self.__bed_mask)
|
||||
return self.__bed_mask
|
||||
|
||||
def forward(self, x):
|
||||
embedding, mask = self._bed_mask
|
||||
expanded_mask = mask.expand(*x.shape, embedding.shape[-1])
|
||||
embedding = embedding.repeat(*x.shape, 1)
|
||||
embedding = embedding.expand(*x.shape, embedding.shape[-1])
|
||||
# embedding = embedding.repeat(*x.shape, 1)
|
||||
|
||||
# Row-wise
|
||||
# xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1]).swapdims(0, 1)
|
||||
@ -340,18 +371,40 @@ class MetaCell(nn.Module):
|
||||
def particles(self):
|
||||
return (net for net in self.meta_weight_list)
|
||||
|
||||
def make_particles_attack(self, ratio=0.01):
|
||||
random_particle_list = list(self.particles)
|
||||
random.shuffle(random_particle_list)
|
||||
for idx, particle in enumerate(self.particles):
|
||||
if random.random() <= ratio:
|
||||
other = random_particle_list[idx]
|
||||
if other != particle:
|
||||
particle.attack(other)
|
||||
|
||||
def make_particles_melt(self, ratio=0.01):
|
||||
random_particle_list = list(self.particles)
|
||||
random.shuffle(random_particle_list)
|
||||
for idx, particle in enumerate(self.particles):
|
||||
if random.random() <= ratio:
|
||||
other = random_particle_list[idx]
|
||||
if other != particle:
|
||||
new_particle = particle.melt(other)
|
||||
particle.apply_weights(new_particle.target_weight_matrix())
|
||||
|
||||
|
||||
class MetaLayer(nn.Module):
|
||||
def __init__(self, name, interface=4, width=4, residual_skip=True):
|
||||
def __init__(self, name, interface=4, width=4, # residual_skip=False,
|
||||
weight_interface=5, weight_hidden_size=2, weight_output_size=1):
|
||||
super().__init__()
|
||||
self.residual_skip = residual_skip
|
||||
self.residual_skip = False
|
||||
self.name = name
|
||||
self.interface = interface
|
||||
self.width = width
|
||||
|
||||
self.meta_cell_list = nn.ModuleList()
|
||||
self.meta_cell_list.extend([MetaCell(name=f'{self.name}_C{cell_idx}',
|
||||
interface=interface
|
||||
self.meta_cell_list = nn.ModuleList([
|
||||
MetaCell(name=f'{self.name}_C{cell_idx}',
|
||||
interface=interface,
|
||||
weight_interface=weight_interface, weight_hidden_size=weight_hidden_size,
|
||||
weight_output_size=weight_output_size,
|
||||
) for cell_idx in range(self.width)]
|
||||
)
|
||||
|
||||
@ -371,26 +424,41 @@ class MetaLayer(nn.Module):
|
||||
|
||||
class MetaNet(nn.Module):
|
||||
|
||||
def __init__(self, interface=4, depth=3, width=4, out=1, activation=None):
|
||||
def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True, dropout=0,
|
||||
weight_interface=5, weight_hidden_size=2, weight_output_size=1,):
|
||||
super().__init__()
|
||||
self.residual_skip = residual_skip
|
||||
self.dropout = dropout
|
||||
self.activation = activation
|
||||
self.out = out
|
||||
self.interface = interface
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
|
||||
self._meta_layer_list = nn.ModuleList()
|
||||
self._meta_layer_list.append(MetaLayer(name=f'L{0}',
|
||||
self.weight_interface = weight_interface
|
||||
self.weight_hidden_size = weight_hidden_size
|
||||
self.weight_output_size = weight_output_size
|
||||
self._meta_layer_first = MetaLayer(name=f'L{0}',
|
||||
interface=self.interface,
|
||||
width=self.width)
|
||||
)
|
||||
self._meta_layer_list.extend([MetaLayer(name=f'L{layer_idx + 1}',
|
||||
interface=self.width, width=self.width
|
||||
width=self.width,
|
||||
weight_interface=weight_interface,
|
||||
weight_hidden_size=weight_hidden_size,
|
||||
weight_output_size=weight_output_size)
|
||||
|
||||
self._meta_layer_list = nn.ModuleList([MetaLayer(name=f'L{layer_idx + 1}',
|
||||
interface=self.width, width=self.width,
|
||||
weight_interface=weight_interface,
|
||||
weight_hidden_size=weight_hidden_size,
|
||||
weight_output_size=weight_output_size,
|
||||
|
||||
) for layer_idx in range(self.depth - 2)]
|
||||
)
|
||||
self._meta_layer_list.append(MetaLayer(name=f'L{len(self._meta_layer_list)}',
|
||||
interface=self.width, width=self.out)
|
||||
self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list) + 1}',
|
||||
interface=self.width, width=self.out,
|
||||
weight_interface=weight_interface,
|
||||
weight_hidden_size=weight_hidden_size,
|
||||
weight_output_size=weight_output_size,
|
||||
)
|
||||
self.dropout_layer = nn.Dropout(p=self.dropout)
|
||||
|
||||
def replace_with_zero(self, ident_key):
|
||||
replaced_particles = 0
|
||||
@ -400,41 +468,184 @@ class MetaNet(nn.Module):
|
||||
{key: torch.zeros_like(state) for key, state in particle.state_dict().items()}
|
||||
)
|
||||
replaced_particles += 1
|
||||
if replaced_particles != 0:
|
||||
tqdm.write(f'Particle Parameters replaced: {str(replaced_particles)}')
|
||||
return self
|
||||
|
||||
def forward(self, x):
|
||||
tensor = x
|
||||
for meta_layer in self._meta_layer_list:
|
||||
tensor = self._meta_layer_first(x)
|
||||
residual = None
|
||||
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
|
||||
if idx % 2 == 1 and self.residual_skip:
|
||||
# if self.residual_skip:
|
||||
residual = tensor
|
||||
tensor = meta_layer(tensor)
|
||||
if idx % 2 == 0 and self.residual_skip:
|
||||
# if self.residual_skip:
|
||||
tensor = tensor + residual
|
||||
tensor = self._meta_layer_last(tensor)
|
||||
return tensor
|
||||
|
||||
@property
|
||||
def particles(self):
|
||||
return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)
|
||||
return (cell for metalayer in self.all_layers for cell in metalayer.particles)
|
||||
|
||||
def combined_self_train(self, n_st_steps, reduction='mean', per_particle=True, alpha=1):
|
||||
|
||||
def combined_self_train(self, external_optimizer):
|
||||
losses = []
|
||||
|
||||
if per_particle:
|
||||
for particle in self.particles:
|
||||
loss, _ = particle.self_train(n_st_steps, reduction=reduction)
|
||||
losses.append(loss.detach())
|
||||
else:
|
||||
optim = torch.optim.SGD(self.parameters(), lr=0.004, momentum=0.9)
|
||||
for _ in range(n_st_steps):
|
||||
optim.zero_grad()
|
||||
train_losses = []
|
||||
for particle in self.particles:
|
||||
# Zero your gradients for every batch!
|
||||
external_optimizer.zero_grad()
|
||||
# Intergrate optimizer and backward function
|
||||
input_data = particle.input_weight_matrix()
|
||||
target_data = particle.create_target_weights(input_data)
|
||||
output = particle(input_data)
|
||||
loss = F.mse_loss(output, target_data)
|
||||
losses.append(loss.detach)
|
||||
loss.backward()
|
||||
# Adjust learning weights
|
||||
external_optimizer.step()
|
||||
# return torch.hstack(losses).sum(dim=-1, keepdim=True)
|
||||
return sum(losses)
|
||||
loss = F.mse_loss(output, target_data, reduction=reduction)
|
||||
|
||||
train_losses.append(loss)
|
||||
train_losses = torch.hstack(train_losses).sum(dim=-1, keepdim=True)
|
||||
if alpha not in [0, 1]:
|
||||
train_losses *= alpha
|
||||
train_losses.backward()
|
||||
optim.step()
|
||||
losses.append(train_losses.detach())
|
||||
losses = torch.hstack(losses).sum(dim=-1, keepdim=True)
|
||||
return losses
|
||||
|
||||
@property
|
||||
def hyperparams(self):
|
||||
return {key: val for key, val in self.__dict__.items() if not key.startswith('_')}
|
||||
|
||||
def replace_particles(self, particle_weights_list):
|
||||
for layer in self.all_layers:
|
||||
for cell in layer.meta_cell_list:
|
||||
# Individual replacement on cell lvl
|
||||
for weight in cell.meta_weight_list:
|
||||
weight.apply_weights(next(particle_weights_list).detach())
|
||||
return self
|
||||
|
||||
def make_particles_attack(self, ratio=0.01, level=NetworkLevel.cell, reduction='mean'):
|
||||
if level == NetworkLevel.all:
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
elif level == NetworkLevel.layer:
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
elif level == NetworkLevel.cell:
|
||||
for layer in self.all_layers:
|
||||
for cell in layer.meta_cell_list:
|
||||
cell.make_particles_attack(ratio)
|
||||
pass
|
||||
|
||||
else:
|
||||
raise ValueError(f'level has to be any of: {[level]}')
|
||||
# Self Train Loss after attack:
|
||||
with torch.no_grad():
|
||||
sa_losses = []
|
||||
for particle in self.particles:
|
||||
# Intergrate optimizer and backward function
|
||||
input_data = particle.input_weight_matrix()
|
||||
target_data = particle.create_target_weights(input_data)
|
||||
output = particle(input_data)
|
||||
loss = F.mse_loss(output, target_data, reduction=reduction)
|
||||
|
||||
sa_losses.append(loss)
|
||||
after_attack_loss = torch.hstack(sa_losses).sum(dim=-1, keepdim=True)
|
||||
return after_attack_loss
|
||||
|
||||
def make_particles_melt(self, ratio=0.01, level=NetworkLevel.cell, reduction='mean'):
|
||||
if level == NetworkLevel.all:
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
elif level == NetworkLevel.layer:
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
elif level == NetworkLevel.cell:
|
||||
for layer in self.all_layers:
|
||||
for cell in layer.meta_cell_list:
|
||||
cell.make_particles_melt(ratio)
|
||||
pass
|
||||
|
||||
else:
|
||||
raise ValueError(f'level has to be any of: {[level]}')
|
||||
# Self Train Loss after attack:
|
||||
with torch.no_grad():
|
||||
sa_losses = []
|
||||
for particle in self.particles:
|
||||
# Intergrate optimizer and backward function
|
||||
input_data = particle.input_weight_matrix()
|
||||
target_data = particle.create_target_weights(input_data)
|
||||
output = particle(input_data)
|
||||
loss = F.mse_loss(output, target_data, reduction=reduction)
|
||||
|
||||
sa_losses.append(loss)
|
||||
after_melt_loss = torch.hstack(sa_losses).sum(dim=-1, keepdim=True)
|
||||
return after_melt_loss
|
||||
|
||||
|
||||
@property
|
||||
def all_layers(self):
|
||||
return (x for x in (self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last))
|
||||
|
||||
@property
|
||||
def particle_parameter_count(self):
|
||||
return sum(p.numel() for p in next(self.particles).parameters())
|
||||
|
||||
def count_fixpoints(self, fix_type=FixTypes.identity_func):
|
||||
return sum(x.is_fixpoint == fix_type for x in self.particles)
|
||||
|
||||
def reset_diverged_particles(self):
|
||||
for particle in self.particles:
|
||||
if particle.is_fixpoint == FixTypes.divergent:
|
||||
particle.apply(xavier_init)
|
||||
|
||||
|
||||
class MetaNetCompareBaseline(nn.Module):
|
||||
|
||||
def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True):
|
||||
super().__init__()
|
||||
self.residual_skip = residual_skip
|
||||
self.activation = activation
|
||||
self.out = out
|
||||
self.interface = interface
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self._first_layer = nn.Linear(self.interface, self.width, bias=False)
|
||||
self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False
|
||||
) for _ in range(self.depth - 2)])
|
||||
self._last_layer = nn.Linear(self.width, self.out, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self._first_layer(x)
|
||||
residual = None
|
||||
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
|
||||
if idx % 2 == 1 and self.residual_skip:
|
||||
# if self.residual_skip:
|
||||
residual = tensor
|
||||
tensor = meta_layer(tensor)
|
||||
if idx % 2 == 0 and self.residual_skip:
|
||||
# if self.residual_skip:
|
||||
tensor = tensor + residual
|
||||
tensor = self._last_layer(tensor)
|
||||
return tensor
|
||||
|
||||
@property
|
||||
def all_layers(self):
|
||||
return (x for x in (self._first_layer, *self._meta_layer_list, self._last_layer))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
metanet = MetaNet(interface=3, depth=5, width=3, out=1)
|
||||
metanet = MetaNet(interface=3, depth=5, width=3, out=1, residual_skip=True)
|
||||
next(metanet.particles).input_weight_matrix()
|
||||
metanet(torch.hstack([torch.full((2, 1), x) for x in range(metanet.interface)]))
|
||||
metanet(torch.hstack([torch.full((2, 1), 1.0) for _ in range(metanet.interface)]))
|
||||
a = metanet.particles
|
||||
print('Test')
|
||||
print('Test')
|
||||
|
119
plot_3d_trajectories.py
Normal file
@ -0,0 +1,119 @@
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from network import FixTypes
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
|
||||
def plot_single_3d_trajectories_by_layer(model_path, all_weights_path, status_type: FixTypes):
|
||||
"""
|
||||
This plots one PCA for every net (over its n epochs) as one trajectory
|
||||
and then combines all of them in one plot
|
||||
"""
|
||||
model = torch.load(model_path, map_location=torch.device('cpu')).eval()
|
||||
all_weights = pd.read_csv(all_weights_path, index_col=False)
|
||||
save_path = model_path.parent / 'trajec_plots'
|
||||
|
||||
all_epochs = all_weights.Epoch.unique()
|
||||
pca = PCA(n_components=2, whiten=True)
|
||||
save_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
for layer_idx, model_layer in enumerate(model.all_layers):
|
||||
|
||||
fixpoint_statuses = [net.is_fixpoint for net in model_layer.particles]
|
||||
num_status_of_layer = sum([net.is_fixpoint == status_type for net in model_layer.particles])
|
||||
if num_status_of_layer != 0:
|
||||
layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
|
||||
weight_batches = [np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
|
||||
for name in layer.Weight.unique()]
|
||||
plt.clf()
|
||||
fig = plt.figure()
|
||||
ax = plt.axes(projection='3d')
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(12)
|
||||
plt.tight_layout()
|
||||
|
||||
for weights_of_net, status in zip(weight_batches, fixpoint_statuses):
|
||||
if status == status_type:
|
||||
pca.fit(weights_of_net)
|
||||
transformed_trajectory = pca.transform(weights_of_net)
|
||||
xdata = transformed_trajectory[:, 0]
|
||||
ydata = transformed_trajectory[:, 1]
|
||||
zdata = all_epochs
|
||||
ax.plot3D(xdata, ydata, zdata)
|
||||
ax.scatter(xdata, ydata, zdata, s=7)
|
||||
|
||||
ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
|
||||
ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
|
||||
ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
|
||||
ax.set_zlabel('Epochs', fontsize=30, rotation=0)
|
||||
file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}.png"
|
||||
plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
|
||||
plt.clf()
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def plot_grouped_3d_trajectories_by_layer(model_path, all_weights_path, status_type: FixTypes):
|
||||
""" This computes the PCA over all the net-weights at once and then plots that."""
|
||||
|
||||
model = torch.load(model_path, map_location=torch.device('cpu')).eval()
|
||||
save_path = model_path.parent / 'trajec_plots'
|
||||
all_weights = pd.read_csv(all_weights_path, index_col=False)
|
||||
all_epochs = all_weights.Epoch.unique()
|
||||
pca = PCA(n_components=2, whiten=True)
|
||||
save_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
for layer_idx, model_layer in enumerate(model.all_layers):
|
||||
|
||||
fixpoint_statuses = [net.is_fixpoint for net in model_layer.particles]
|
||||
num_status_of_layer = sum([net.is_fixpoint == status_type for net in model_layer.particles])
|
||||
if num_status_of_layer != 0:
|
||||
layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
|
||||
weight_batches = np.vstack([np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
|
||||
for name in layer.Weight.unique()])
|
||||
plt.clf()
|
||||
fig = plt.figure()
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(12)
|
||||
ax = plt.axes(projection='3d')
|
||||
plt.tight_layout()
|
||||
|
||||
pca.fit(weight_batches)
|
||||
w_transformed = pca.transform(weight_batches)
|
||||
for transformed_trajectory, status in zip(
|
||||
np.split(w_transformed, len(layer.Weight.unique())), fixpoint_statuses):
|
||||
if status == status_type:
|
||||
xdata = transformed_trajectory[:, 0]
|
||||
ydata = transformed_trajectory[:, 1]
|
||||
zdata = all_epochs
|
||||
ax.plot3D(xdata, ydata, zdata)
|
||||
ax.scatter(xdata, ydata, zdata, s=7)
|
||||
|
||||
ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
|
||||
ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
|
||||
ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
|
||||
ax.set_zlabel('Epochs', fontsize=30, rotation=0)
|
||||
file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}_grouped.png"
|
||||
plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
|
||||
plt.clf()
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise (NotImplementedError('Get out of here'))
|
||||
"""
|
||||
weight_path = Path("weight_store.csv")
|
||||
model_path = Path("trained_model_ckpt_e100.tp")
|
||||
save_path = Path("figures/3d_trajectories/")
|
||||
|
||||
weight_df = pd.read_csv(weight_path)
|
||||
weight_df = weight_df.drop_duplicates(subset=['Weight','Epoch'])
|
||||
model = torch.load(model_path, map_location=torch.device('cpu'))
|
||||
|
||||
plot_single_3d_trajectories_by_layer(model, weight_df, save_path, status_type=FixTypes.identity_func)
|
||||
plot_single_3d_trajectories_by_layer(model, weight_df, save_path, status_type=FixTypes.other_func)
|
||||
plot_grouped_3d_trajectories_by_layer(model, weight_df, save_path, FixTypes.identity_func)
|
||||
#plot_grouped_3d_trajectories_by_layer(model, weight_df, save_path, FixTypes.other_func)
|
||||
"""
|
22
sanity_check_particle_weight_swap.py
Normal file
@ -0,0 +1,22 @@
|
||||
import torch
|
||||
|
||||
from network import MetaNet
|
||||
from sparse_net import SparseNetwork
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dense_metanet = MetaNet(30, depth=5, width=6, out=10, residual_skip=True,
|
||||
weight_hidden_size=3, )
|
||||
sparse_metanet = SparseNetwork(30, depth=5, width=6, out=10, residual_skip=True,
|
||||
weight_hidden_size=3,)
|
||||
|
||||
particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles]
|
||||
|
||||
# Transfer weights
|
||||
sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)
|
||||
|
||||
# Transfer weights
|
||||
dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
|
||||
new_particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles]
|
||||
|
||||
print(f' Particles are same: {all([(x==y).all() for x,y in zip(particles, new_particles) ])}')
|
71
sanity_check_weights.py
Normal file
@ -0,0 +1,71 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import Flatten
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision.datasets import MNIST, CIFAR10
|
||||
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Grayscale
|
||||
import torchmetrics
|
||||
|
||||
from functionalities_test import epsilon_error_margin as e
|
||||
from network import MetaNet, MetaNetCompareBaseline
|
||||
|
||||
|
||||
def extract_weights_from_model(model: MetaNet) -> dict:
|
||||
inpt = torch.zeros(5, device=next(model.parameters()).device, dtype=torch.float)
|
||||
inpt[-1] = 1
|
||||
|
||||
weights = defaultdict(list)
|
||||
layers = [layer.particles for layer in model.all_layers]
|
||||
for i, layer in enumerate(layers):
|
||||
for net in layer:
|
||||
weights[i].append(net(inpt).detach())
|
||||
return dict(weights)
|
||||
|
||||
|
||||
def test_weights_as_model(meta_net, new_weights, data, metric_class=torchmetrics.Accuracy):
|
||||
meta_net_device = next(meta_net.parameters()).device
|
||||
transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth,
|
||||
width=meta_net.width, out=meta_net.out,
|
||||
residual_skip=meta_net.residual_skip).to(meta_net_device)
|
||||
with torch.no_grad():
|
||||
new_weight_values = list(new_weights.values())
|
||||
old_parameters = list(transfer_net.parameters())
|
||||
assert len(new_weight_values) == len(old_parameters)
|
||||
for weights, parameters in zip(new_weights.values(), transfer_net.parameters()):
|
||||
parameters[:] = torch.Tensor(weights).view(parameters.shape)[:]
|
||||
|
||||
transfer_net.eval()
|
||||
results = dict()
|
||||
for net in [meta_net, transfer_net]:
|
||||
net.eval()
|
||||
metric = metric_class()
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='Test Batch: '):
|
||||
y = net(batch_x.to(meta_net_device))
|
||||
metric(y.cpu(), batch_y.cpu())
|
||||
|
||||
# metric on all batches using custom accumulation
|
||||
measure = metric.compute()
|
||||
results[net.__class__.__name__] = measure.item()
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
WORKER = 0
|
||||
BATCHSIZE = 500
|
||||
MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Flatten(start_dim=0)])
|
||||
torch.manual_seed(42)
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
|
||||
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
|
||||
|
||||
model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval()
|
||||
weights = extract_weights_from_model(model)
|
||||
test_weights_as_model(model, weights, d_test)
|
||||
|
385
sparse_net.py
Normal file
@ -0,0 +1,385 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import pandas as pd
|
||||
from matplotlib import pyplot as plt
|
||||
import seaborn as sns
|
||||
from torch import nn
|
||||
|
||||
import functionalities_test
|
||||
from network import Net
|
||||
from functionalities_test import is_identity_function, test_for_fixpoints, epsilon_error_margin
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torch.nn import Flatten
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision.transforms import ToTensor, Compose, Resize
|
||||
|
||||
|
||||
def xavier_init(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
return nn.init.xavier_uniform_(m.weight.data)
|
||||
if isinstance(m, torch.Tensor):
|
||||
return nn.init.xavier_uniform_(m)
|
||||
|
||||
|
||||
class SparseLayer(nn.Module):
|
||||
def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):
|
||||
super(SparseLayer, self).__init__()
|
||||
|
||||
self.nr_nets = nr_nets
|
||||
self.interface_dim = interface
|
||||
self.depth_dim = depth
|
||||
self.hidden_dim = width
|
||||
self.out_dim = out
|
||||
dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)
|
||||
self.dummy_net_shapes = [list(x.shape) for x in dummy_net.parameters()]
|
||||
self.dummy_net_weight_pos_enc = dummy_net._weight_pos_enc
|
||||
|
||||
self.sparse_sub_layer = list()
|
||||
self.indices = list()
|
||||
self.diag_shapes = list()
|
||||
self.weights = nn.ParameterList()
|
||||
self._particles = None
|
||||
|
||||
for layer_id in range(self.depth_dim):
|
||||
indices, weights, diag_shape = self.coo_sparse_layer(layer_id)
|
||||
self.indices.append(indices)
|
||||
self.diag_shapes.append(diag_shape)
|
||||
self.weights.append(weights)
|
||||
self.apply(xavier_init)
|
||||
|
||||
def coo_sparse_layer(self, layer_id):
|
||||
with torch.no_grad():
|
||||
layer_shape = self.dummy_net_shapes[layer_id]
|
||||
sparse_diagonal = np.eye(self.nr_nets).repeat(layer_shape[0], axis=-2).repeat(layer_shape[1], axis=-1)
|
||||
indices = torch.Tensor(np.argwhere(sparse_diagonal == 1).T, )
|
||||
values = torch.nn.Parameter(torch.randn((np.prod((*layer_shape, self.nr_nets)).item())), requires_grad=True)
|
||||
|
||||
return indices, values, sparse_diagonal.shape
|
||||
|
||||
def get_self_train_inputs_and_targets(self):
|
||||
# view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN
|
||||
# i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2
|
||||
# and first hidden*out weights of layer3 = first net
|
||||
# [nr_layers*[nr_net*nr_weights_layer_i]]
|
||||
with torch.no_grad():
|
||||
weights = [layer.view(-1, int(len(layer)/self.nr_nets)).detach() for layer in self.weights]
|
||||
# [nr_net*[nr_weights]]
|
||||
weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1, 1) for i in range(self.nr_nets)]
|
||||
# (16, 25)
|
||||
|
||||
encoding_matrix, mask = self.dummy_net_weight_pos_enc
|
||||
weight_device = weights_per_net[0].device
|
||||
if weight_device != encoding_matrix.device or weight_device != mask.device:
|
||||
encoding_matrix, mask = encoding_matrix.to(weight_device), mask.to(weight_device)
|
||||
self.dummy_net_weight_pos_enc = encoding_matrix, mask
|
||||
|
||||
inputs = torch.hstack(
|
||||
[encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask)
|
||||
for i in range(self.nr_nets)]
|
||||
)
|
||||
targets = torch.hstack(weights_per_net)
|
||||
return inputs.T, targets.T
|
||||
|
||||
@property
|
||||
def particles(self):
|
||||
if self._particles is None:
|
||||
self._particles = [Net(self.interface_dim, self.hidden_dim, self.out_dim) for _ in range(self.nr_nets)]
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
# Particle Weight Update
|
||||
all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
|
||||
weights_per_net = [torch.cat([layer[i] for layer in all_weights]).view(-1, 1) for i in
|
||||
range(self.nr_nets)] # [nr_net*[nr_weights]]
|
||||
for particles, weights in zip(self._particles, weights_per_net):
|
||||
particles.apply_weights(weights)
|
||||
return self._particles
|
||||
|
||||
def reset_diverged_particles(self):
|
||||
for weights in self.weights:
|
||||
if torch.isinf(weights).any() or torch.isnan(weights).any():
|
||||
with torch.no_grad():
|
||||
where_nan = torch.nan_to_num(weights, -99, -99, -99)
|
||||
mask = torch.where(where_nan == -99, 0, 1)
|
||||
weights[:] = (where_nan * mask + torch.randn_like(weights) * (1 - mask))[:]
|
||||
|
||||
@property
|
||||
def particle_weights(self):
|
||||
all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
|
||||
weights_per_net = [torch.cat([layer[i] for layer in all_weights]).view(-1, 1) for i in
|
||||
range(self.nr_nets)] # [nr_net*[nr_weights]]
|
||||
return weights_per_net
|
||||
|
||||
def replace_weights_by_particles(self, particles):
|
||||
assert len(particles) == self.nr_nets
|
||||
with torch.no_grad():
|
||||
# Particle Weight Update
|
||||
all_weights = [list(particle.parameters()) for particle in particles]
|
||||
all_weights = [torch.cat(x).view(-1) for x in zip(*all_weights)]
|
||||
# [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
|
||||
for weights, parameters in zip(all_weights, self.parameters()):
|
||||
parameters[:] = weights[:]
|
||||
return self
|
||||
|
||||
def __call__(self, x):
|
||||
for indices, diag_shapes, weights in zip(self.indices, self.diag_shapes, self.weights):
|
||||
s = torch.sparse_coo_tensor(indices, weights, diag_shapes, requires_grad=True, device=x.device)
|
||||
x = torch.sparse.mm(s, x)
|
||||
return x
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super(SparseLayer, self).to(*args, **kwargs)
|
||||
self.sparse_sub_layer = [sparse_sub_layer.to(*args, **kwargs) for sparse_sub_layer in self.sparse_sub_layer]
|
||||
return self
|
||||
|
||||
|
||||
def test_sparse_layer():
|
||||
net = SparseLayer(1000)
|
||||
loss_fn = torch.nn.MSELoss(reduction='mean')
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9)
|
||||
# optimizer = torch.optim.SGD([layer.coalesce().values() for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)
|
||||
df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count'])
|
||||
train_iterations = 20000
|
||||
|
||||
for train_iteration in trange(train_iterations):
|
||||
optimizer.zero_grad()
|
||||
X, Y = net.get_self_train_inputs_and_targets()
|
||||
output = net(X)
|
||||
|
||||
loss = loss_fn(output, Y) * 100
|
||||
|
||||
# loss = sum([loss_fn(out, target) for out, target in zip(output, Y)]) / len(output) * 10
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if train_iteration % 500 == 0:
|
||||
counter = defaultdict(lambda: 0)
|
||||
id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
|
||||
counter = dict(counter)
|
||||
tqdm.write(f"identity_fn after {train_iteration + 1} self-train epochs: {counter}")
|
||||
for key, value in counter.items():
|
||||
df.loc[df.shape[0]] = (train_iteration, key, value)
|
||||
|
||||
counter = defaultdict(lambda: 0)
|
||||
id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
|
||||
counter = dict(counter)
|
||||
tqdm.write(f"identity_fn after {train_iterations} self-train epochs: {counter}")
|
||||
for key, value in counter.items():
|
||||
df.loc[df.shape[0]] = (train_iterations, key, value)
|
||||
df.to_csv('counter.csv', mode='w')
|
||||
|
||||
c = pd.read_csv('counter.csv', index_col=0)
|
||||
sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type')
|
||||
plt.savefig('counter.png', dpi=300)
|
||||
|
||||
|
||||
def embed_batch(x, repeat_dim):
|
||||
# x of shape (batchsize, flat_img_dim)
|
||||
|
||||
# (batchsize, flat_img_dim, 1)
|
||||
x = x.unsqueeze(-1)
|
||||
# (batchsize, flat_img_dim, encoding_dim*repeat_dim)
|
||||
# torch.sparse_coo_tensor(indices, weights, diag_shapes, requires_grad=True, device=x.device)
|
||||
return torch.cat((torch.zeros(x.shape[0], x.shape[1], 4, device=x.device), x), dim=2).repeat(1, 1, repeat_dim)
|
||||
|
||||
def embed_vector(x, repeat_dim):
|
||||
# x of shape [flat_img_dim]
|
||||
x = x.unsqueeze(-1) # (flat_img_dim, 1)
|
||||
# (flat_img_dim, encoding_dim*repeat_dim)
|
||||
return torch.cat((torch.zeros(x.shape[0], 4), x), dim=1).repeat(1,repeat_dim)
|
||||
|
||||
|
||||
class SparseNetwork(nn.Module):
|
||||
|
||||
@property
|
||||
def nr_nets(self):
|
||||
return sum(x.nr_nets for x in self.sparselayers)
|
||||
|
||||
def __init__(self, input_dim, depth, width, out, residual_skip=True, activation=None,
|
||||
weight_interface=5, weight_hidden_size=2, weight_output_size=1
|
||||
):
|
||||
super(SparseNetwork, self).__init__()
|
||||
self.residual_skip = residual_skip
|
||||
self.input_dim = input_dim
|
||||
self.depth_dim = depth
|
||||
self.hidden_dim = width
|
||||
self.out_dim = out
|
||||
self.activation = activation
|
||||
self.first_layer = SparseLayer(self.input_dim * self.hidden_dim,
|
||||
interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
|
||||
self.last_layer = SparseLayer(self.hidden_dim * self.out_dim,
|
||||
interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
|
||||
self.hidden_layers = nn.ModuleList([
|
||||
SparseLayer(self.hidden_dim * self.hidden_dim,
|
||||
interface=weight_interface, width=weight_hidden_size, out=weight_output_size
|
||||
) for _ in range(self.depth_dim - 2)])
|
||||
|
||||
def __call__(self, x):
|
||||
|
||||
tensor = self.sparse_layer_forward(x, self.first_layer)
|
||||
if self.activation:
|
||||
tensor = self.activation(tensor)
|
||||
for nl_idx, network_layer in enumerate(self.hidden_layers):
|
||||
# if idx % 2 == 1 and self.residual_skip:
|
||||
if self.residual_skip:
|
||||
residual = tensor
|
||||
tensor = self.sparse_layer_forward(tensor, network_layer)
|
||||
# if idx % 2 == 0 and self.residual_skip:
|
||||
if self.residual_skip:
|
||||
tensor = tensor + residual
|
||||
tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim)
|
||||
return tensor
|
||||
|
||||
def sparse_layer_forward(self, x, sparse_layer, view_dim=None):
|
||||
view_dim = view_dim if view_dim else self.hidden_dim
|
||||
# batch pass (one by one, sparse bmm doesn't support grad)
|
||||
if len(x.shape) > 1:
|
||||
embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
|
||||
# [batchsize, hidden*inpt_dim, feature_dim]
|
||||
x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(view_dim, x.shape[1]).sum(dim=1) for inpt in
|
||||
embedded_inpt])
|
||||
# vector
|
||||
else:
|
||||
embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
|
||||
x = sparse_layer(embedded_inpt.T).sum(dim=1).view(view_dim, x.shape[1]).sum(dim=1)
|
||||
return x
|
||||
|
||||
@property
|
||||
def particles(self):
|
||||
#particles = []
|
||||
#particles.extend(self.first_layer.particles)
|
||||
#for layer in self.hidden_layers:
|
||||
# particles.extend(layer.particles)
|
||||
#particles.extend(self.last_layer.particles)
|
||||
return (x for y in (self.first_layer.particles,
|
||||
*(l.particles for l in self.hidden_layers),
|
||||
self.last_layer.particles) for x in y)
|
||||
|
||||
@property
|
||||
def particle_weights(self):
|
||||
return (x for y in self.sparselayers for x in y.particle_weights)
|
||||
|
||||
def reset_diverged_particles(self):
|
||||
for layer in self.sparselayers:
|
||||
layer.reset_diverged_particles()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super(SparseNetwork, self).to(*args, **kwargs)
|
||||
self.first_layer = self.first_layer.to(*args, **kwargs)
|
||||
self.last_layer = self.last_layer.to(*args, **kwargs)
|
||||
self.hidden_layers = nn.ModuleList([hidden_layer.to(*args, **kwargs) for hidden_layer in self.hidden_layers])
|
||||
return self
|
||||
|
||||
@property
|
||||
def sparselayers(self):
|
||||
return (x for x in (self.first_layer, *self.hidden_layers, self.last_layer))
|
||||
|
||||
def combined_self_train(self, optimizer, reduction='mean'):
|
||||
losses = []
|
||||
loss_fn = nn.MSELoss(reduction=reduction)
|
||||
for layer in self.sparselayers:
|
||||
optimizer.zero_grad()
|
||||
x, target_data = layer.get_self_train_inputs_and_targets()
|
||||
output = layer(x)
|
||||
# loss = sum([loss_fn(out, target) for out, target in zip(output, target_data)]) / len(output)
|
||||
|
||||
loss = loss_fn(output, target_data) * layer.nr_nets
|
||||
|
||||
losses.append(loss.detach())
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
return sum(losses)
|
||||
|
||||
def replace_weights_by_particles(self, particles):
|
||||
particles = list(particles)
|
||||
for layer in self.sparselayers:
|
||||
layer.replace_weights_by_particles(particles[:layer.nr_nets])
|
||||
del particles[:layer.nr_nets]
|
||||
return self
|
||||
|
||||
|
||||
def test_sparse_net():
|
||||
utility_transforms = Compose([ Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])
|
||||
data_path = Path('data')
|
||||
WORKER = 8
|
||||
BATCHSIZE = 10
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
dataset = MNIST(str(data_path), transform=utility_transforms)
|
||||
d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
|
||||
|
||||
data_dim = np.prod(dataset[0][0].shape)
|
||||
metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)
|
||||
batchx, batchy = next(iter(d))
|
||||
out = metanet(batchx)
|
||||
|
||||
result = sum([torch.allclose(out[i], batchy[i], rtol=0, atol=epsilon_error_margin) for i in range(metanet.nr_nets)])
|
||||
# print(f"identity_fn after {train_iteration+1} self-train iterations: {result} /{net.nr_nets}")
|
||||
|
||||
|
||||
def test_sparse_net_sef_train():
|
||||
sparse_metanet = SparseNetwork(15*15, 5, 6, 10).to('cuda')
|
||||
init_st_store_path = Path('counter.csv')
|
||||
optimizer = torch.optim.SGD(sparse_metanet.parameters(), lr=0.004, momentum=0.9)
|
||||
init_st_epochs = 10000
|
||||
init_st_df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count'])
|
||||
|
||||
for st_epoch in trange(init_st_epochs):
|
||||
_ = sparse_metanet.combined_self_train(optimizer)
|
||||
|
||||
if st_epoch % 500 == 0:
|
||||
counter = defaultdict(lambda: 0)
|
||||
id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
|
||||
counter = dict(counter)
|
||||
tqdm.write(f"identity_fn after {st_epoch} self-train epochs: {counter}")
|
||||
for key, value in counter.items():
|
||||
init_st_df.loc[init_st_df.shape[0]] = (st_epoch, key, value)
|
||||
sparse_metanet.reset_diverged_particles()
|
||||
|
||||
counter = defaultdict(lambda: 0)
|
||||
id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
|
||||
counter = dict(counter)
|
||||
tqdm.write(f"identity_fn after {init_st_epochs} self-train epochs: {counter}")
|
||||
for key, value in counter.items():
|
||||
init_st_df.loc[init_st_df.shape[0]] = (init_st_epochs, key, value)
|
||||
init_st_df.to_csv(init_st_store_path, mode='w', index=False)
|
||||
|
||||
c = pd.read_csv(init_st_store_path)
|
||||
sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type')
|
||||
plt.savefig(init_st_store_path, dpi=300)
|
||||
|
||||
|
||||
def test_manual_for_loop():
|
||||
nr_nets = 500
|
||||
nets = [Net(5,2,1) for _ in range(nr_nets)]
|
||||
loss_fn = torch.nn.MSELoss(reduction="sum")
|
||||
rounds = 1000
|
||||
|
||||
for net in tqdm(nets):
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
|
||||
for i in range(rounds):
|
||||
optimizer.zero_grad()
|
||||
input_data = net.input_weight_matrix()
|
||||
target_data = net.create_target_weights(input_data)
|
||||
output = net(input_data)
|
||||
loss = loss_fn(output, target_data)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
sum([is_identity_function(net) for net in nets])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# test_sparse_layer()
|
||||
test_sparse_net_sef_train()
|
||||
# test_sparse_net()
|
||||
# for comparison
|
||||
# test_manual_for_loop()
|
498
sparse_tensor_combined.ipynb
Normal file
@ -0,0 +1,498 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from network import Net\n",
|
||||
"import torch\n",
|
||||
"from typing import List\n",
|
||||
"from functionalities_test import is_identity_function\n",
|
||||
"from tqdm import tqdm,trange\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def construct_sparse_COO_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
|
||||
" assert layer_idx <= len(list(nets[0].parameters()))\n",
|
||||
" values = []\n",
|
||||
" indices = []\n",
|
||||
" for net_idx,net in enumerate(nets):\n",
|
||||
" layer = list(net.parameters())[layer_idx]\n",
|
||||
" \n",
|
||||
" for cell_idx,cell in enumerate(layer):\n",
|
||||
" # E.g., position of cell weights (with 2 cells per hidden layer) in first sparse layer of N nets: \n",
|
||||
" \n",
|
||||
" # [4x2 weights_net0] [4x2x(n-1) 0s]\n",
|
||||
" # [4x2 weights] [4x2 weights_net0] [4x2x(n-2) 0s]\n",
|
||||
" # ... etc\n",
|
||||
" # [4x2x(n-1) 0s] [4x2 weights_netN]\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # -> 4x2 weights on the diagonal = [shifted Nr_cellss*B down for AxB cells, and Nr_nets(*A weights)to the right] \n",
|
||||
" for i in range(len(cell)):\n",
|
||||
" indices.append([len(layer)*net_idx + cell_idx, net_idx*len(cell) + i ])\n",
|
||||
" #indices.append([2*net_idx + cell_idx, net_idx*len(cell) + i ])\n",
|
||||
"\n",
|
||||
" [values.append(weight) for weight in cell]\n",
|
||||
"\n",
|
||||
" # for i in range(4):\n",
|
||||
" # indices.append([idx+idx+1, i+(idx*4)])\n",
|
||||
" #for l in next(net.parameters()):\n",
|
||||
" #[values.append(w) for w in l]\n",
|
||||
" #print(indices, values)\n",
|
||||
"\n",
|
||||
" #s = torch.sparse_coo_tensor(list(zip(*indices)), values, (2*nr_nets, 4*nr_nets))\n",
|
||||
" # sparse tensor dimension = (nr_cells*nr_nets , nr_weights/cell * nr_nets), i.e.,\n",
|
||||
" # layer 1: (2x4) -> (2*N, 4*N)\n",
|
||||
" # layer 2: (2x2) -> (2*N, 2*N)\n",
|
||||
" # layer 3: (1x2) -> (2*N, 1*N)\n",
|
||||
" s = torch.sparse_coo_tensor(list(zip(*indices)), values, (len(layer)*nr_nets, len(cell)*nr_nets))\n",
|
||||
" #print(s.to_dense())\n",
|
||||
" #print(s.to_dense().shape)\n",
|
||||
" return s\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# for each net append to the combined sparse tensor\n",
|
||||
"# construct sparse tensor for each layer, with Nets of (4,2,1), each net appends\n",
|
||||
"# - [4x2] weights in the first (input) layer\n",
|
||||
"# - [2x2] weights in the second (hidden) layer\n",
|
||||
"# - [2x1] weights in the third (output) layer\n",
|
||||
"#modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
|
||||
"#modules\n",
|
||||
"#for layer_idx in range(len(list(nets[0].parameters()))):\n",
|
||||
"# sparse_tensor = construct_sparse_tensor_layer(nets, layer_idx)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nr_nets = 50\n",
|
||||
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
|
||||
"print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
|
||||
"\n",
|
||||
"modules = [ construct_sparse_COO_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
|
||||
"print( id(list(nets[0].parameters())[0][0,0]) == id(modules[0][0,0]))\n",
|
||||
"\n",
|
||||
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
||||
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
|
||||
"#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for train_iteration in range(1000):\n",
|
||||
" optimizer.zero_grad() \n",
|
||||
" X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
|
||||
" Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
|
||||
" #print(\"X \", X.shape, \"Y\", Y.shape)\n",
|
||||
"\n",
|
||||
" modules = [ construct_sparse_COO_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
|
||||
"\n",
|
||||
" X1 = torch.sparse.mm(modules[0], X)\n",
|
||||
" #print(\"X1\", X1.shape, X1)\n",
|
||||
"\n",
|
||||
" X2 = torch.sparse.mm(modules[1], X1)\n",
|
||||
" #print(\"X2\", X2.shape)\n",
|
||||
"\n",
|
||||
" X3 = torch.sparse.mm(modules[2], X2)\n",
|
||||
" #print(\"X3\", X3.shape)\n",
|
||||
"\n",
|
||||
" loss = loss_fn(X3, Y)\n",
|
||||
" #print(loss)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
"print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nr_nets = 500\n",
|
||||
"nets = [Net(5,2,1) for _ in range(nr_nets)]\n",
|
||||
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
||||
"rounds = 1000\n",
|
||||
"\n",
|
||||
"for net in tqdm(nets):\n",
|
||||
" optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)\n",
|
||||
" for i in range(rounds):\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" input_data = net.input_weight_matrix()\n",
|
||||
" target_data = net.create_target_weights(input_data)\n",
|
||||
" output = net(input_data)\n",
|
||||
" loss = loss_fn(output, target_data)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
"sum([is_identity_function(net) for net in nets])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def construct_sparse_CRS_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
|
||||
" assert layer_idx <= len(list(nets[0].parameters()))\n",
|
||||
" \n",
|
||||
" s = torch.cat( [\n",
|
||||
" torch.cat(\n",
|
||||
" (\n",
|
||||
" torch.zeros(( len(list(net.parameters())[layer_idx]) ,len(list(net.parameters())[layer_idx][0])*net_idx)), \n",
|
||||
" list(net.parameters())[layer_idx], \n",
|
||||
" torch.zeros((len(list(net.parameters())[layer_idx]), len(list(net.parameters())[layer_idx][0])*(len(nets)-(net_idx+1))))\n",
|
||||
" )\n",
|
||||
" , dim=1) for net_idx, net in enumerate(nets)\n",
|
||||
" ]).to_sparse_csr()\n",
|
||||
"\n",
|
||||
" print(s.shape)\n",
|
||||
" return s"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nr_nets = 5\n",
|
||||
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
|
||||
"print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
|
||||
"\n",
|
||||
"#modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
|
||||
"print( id(list(nets[0].parameters())[0][0,0]) == id(modules[0][0,0]))\n",
|
||||
"\n",
|
||||
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
||||
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
|
||||
"#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for train_iteration in range(1):\n",
|
||||
" optimizer.zero_grad() \n",
|
||||
" X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
|
||||
" Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
|
||||
" #print(\"X \", X.shape, \"Y\", Y.shape)\n",
|
||||
"\n",
|
||||
" num_layers = len(list(nets[0].parameters()))\n",
|
||||
" modules = [ construct_sparse_CRS_layer(nets, layer_idx) for layer_idx in range(num_layers)]\n",
|
||||
"\n",
|
||||
" X1 = modules[0].matmul(X)\n",
|
||||
" print(\"X1\", X1.shape, X1.is_sparse)\n",
|
||||
"\n",
|
||||
" X2 = modules[1].matmul(X1)\n",
|
||||
" print(\"X2\", X2.shape, X2.is_sparse)\n",
|
||||
"\n",
|
||||
" X3 = modules[2].matmul(X2)\n",
|
||||
" print(\"X3\", X3.shape, X3.is_sparse)\n",
|
||||
"\n",
|
||||
" loss = loss_fn(X3, Y)\n",
|
||||
" #print(loss)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
"print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nr_nets = 2\n",
|
||||
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
|
||||
"\n",
|
||||
"def cat_COO_layer(nets, layer_idx):\n",
|
||||
" i = [[0,i] for i in range(nr_nets*len(list(net.parameters())[layer_idx]))]\n",
|
||||
" v = torch.cat( [\n",
|
||||
" torch.cat(\n",
|
||||
" (\n",
|
||||
" torch.zeros(( len(list(net.parameters())[layer_idx]) ,len(list(net.parameters())[layer_idx][0])*net_idx)), \n",
|
||||
" list(net.parameters())[layer_idx], \n",
|
||||
" torch.zeros((len(list(net.parameters())[layer_idx]), len(list(net.parameters())[layer_idx][0])*(len(nets)-(net_idx+1))))\n",
|
||||
" )\n",
|
||||
" , dim=1) for net_idx, net in enumerate(nets)\n",
|
||||
" ])\n",
|
||||
" #print(i,v)\n",
|
||||
" s = torch.sparse_coo_tensor(list(zip(*i)), v)\n",
|
||||
" print(s[0].to_dense().shape, s[0].is_sparse)\n",
|
||||
" return s[0]\n",
|
||||
"\n",
|
||||
"cat_COO_layer(nets, 0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nr_nets = 5\n",
|
||||
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
|
||||
"print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
|
||||
"\n",
|
||||
"nr_layers = len(list(nets[0].parameters()))\n",
|
||||
"modules = [ cat_COO_layer(nets, layer_idx) for layer_idx in range(nr_layers) ]\n",
|
||||
"\n",
|
||||
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
||||
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
|
||||
"#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for train_iteration in range(1):\n",
|
||||
" optimizer.zero_grad() \n",
|
||||
" X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
|
||||
" Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
|
||||
" print(\"X \", X.shape, \"Y\", Y.shape)\n",
|
||||
"\n",
|
||||
" X1 = torch.sparse.mm(modules[0], X)\n",
|
||||
" print(\"X1\", X1.shape)\n",
|
||||
"\n",
|
||||
" X2 = torch.sparse.mm(modules[1], X1)\n",
|
||||
" print(\"X2\", X2.shape)\n",
|
||||
"\n",
|
||||
" X3 = torch.sparse.mm(modules[2], X2)\n",
|
||||
" print(\"X3\", X3.shape)\n",
|
||||
"\n",
|
||||
" loss = loss_fn(X3, Y)\n",
|
||||
" #print(loss)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
"print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SparseLayer():\n",
|
||||
" def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):\n",
|
||||
" self.nr_nets = nr_nets\n",
|
||||
" self.interface_dim = interface\n",
|
||||
" self.depth_dim = depth\n",
|
||||
" self.hidden_dim = width\n",
|
||||
" self.out_dim = out\n",
|
||||
" self.dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)\n",
|
||||
" \n",
|
||||
" self.sparse_sub_layer = []\n",
|
||||
" self.weights = []\n",
|
||||
" for layer_id in range(depth):\n",
|
||||
" layer, weights = self.coo_sparse_layer(layer_id)\n",
|
||||
" self.sparse_sub_layer.append(layer)\n",
|
||||
" self.weights.append(weights)\n",
|
||||
" \n",
|
||||
" def coo_sparse_layer(self, layer_id):\n",
|
||||
" layer_shape = list(self.dummy_net.parameters())[layer_id].shape\n",
|
||||
" #print(layer_shape) #(out_cells, in_cells) -> (2,5), (2,2), (1,2)\n",
|
||||
"\n",
|
||||
" sparse_diagonal = np.eye(self.nr_nets).repeat(layer_shape[0], axis=-2).repeat(layer_shape[1], axis=-1)\n",
|
||||
" indices = np.argwhere(sparse_diagonal == 1).T\n",
|
||||
" values = torch.nn.Parameter(torch.randn((self.nr_nets * (layer_shape[0]*layer_shape[1]) )))\n",
|
||||
" #values = torch.randn((self.nr_nets * layer_shape[0]*layer_shape[1] ))\n",
|
||||
" s = torch.sparse_coo_tensor(indices, values, sparse_diagonal.shape, requires_grad=True)\n",
|
||||
" print(f\"L{layer_id}:\", s.shape)\n",
|
||||
" return s, values\n",
|
||||
"\n",
|
||||
" def get_self_train_inputs_and_targets(self):\n",
|
||||
" encoding_matrix, mask = self.dummy_net._weight_pos_enc\n",
|
||||
"\n",
|
||||
" # view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN\n",
|
||||
" # i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2 and first hidden*out weights of layer3 = first net\n",
|
||||
" weights = [layer.view(-1, int(len(layer)/self.nr_nets)) for layer in self.weights] #[nr_layers*[nr_net*nr_weights_layer_i]]\n",
|
||||
" weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1,1) for i in range(self.nr_nets)] #[nr_net*[nr_weights]]\n",
|
||||
" inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(self.nr_nets)]) #(16, 25)\n",
|
||||
" targets = torch.hstack(weights_per_net)\n",
|
||||
" return inputs.T, targets.T\n",
|
||||
"\n",
|
||||
" def __call__(self, x):\n",
|
||||
" X1 = torch.sparse.mm(self.sparse_sub_layer[0], x)\n",
|
||||
" #print(\"X1\", X1.shape)\n",
|
||||
"\n",
|
||||
" X2 = torch.sparse.mm(self.sparse_sub_layer[1], X1)\n",
|
||||
" #print(\"X2\", X2.shape)\n",
|
||||
"\n",
|
||||
" X3 = torch.sparse.mm(self.sparse_sub_layer[2], X2)\n",
|
||||
" #print(\"X3\", X3.shape)\n",
|
||||
" \n",
|
||||
" return X3\n",
|
||||
"\n",
|
||||
"net = SparseLayer(5)\n",
|
||||
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
||||
"optimizer = torch.optim.SGD([weight for weight in net.weights], lr=0.004, momentum=0.9)\n",
|
||||
"#optimizer = torch.optim.SGD([layer for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)\n",
|
||||
"\n",
|
||||
"for train_iteration in trange(10):\n",
|
||||
" optimizer.zero_grad() \n",
|
||||
" X,Y = net.get_self_train_inputs_and_targets()\n",
|
||||
" out = net(X)\n",
|
||||
" \n",
|
||||
" loss = loss_fn(out, Y)\n",
|
||||
"\n",
|
||||
" # print(\"X:\", X.shape, \"Y:\", Y.shape)\n",
|
||||
" # print(\"OUT\", out.shape)\n",
|
||||
" # print(\"LOSS\", loss.item())\n",
|
||||
" \n",
|
||||
" loss.backward(retain_graph=True)\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"epsilon=pow(10, -5)\n",
|
||||
"# is the (the whole layer) self-replicating? -> wrong\n",
|
||||
"#print(torch.allclose(out, Y,rtol=0, atol=epsilon))\n",
|
||||
"\n",
|
||||
"# is each of the networks self-replicating?\n",
|
||||
"print(f\"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# for layer in net.weights:\n",
|
||||
"# n=int(len(layer)/net.nr_nets)\n",
|
||||
"# print( [layer[i:i+n] for i in range(0, len(layer), n)])\n",
|
||||
"\n",
|
||||
"encoding_matrix, mask = Net(5,2,1)._weight_pos_enc\n",
|
||||
"print(encoding_matrix, mask)\n",
|
||||
"# view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN\n",
|
||||
"# i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2 and first hidden*out weights of layer3 = first net\n",
|
||||
"weights = [layer.view(-1, int(len(layer)/net.nr_nets)) for layer in net.weights]\n",
|
||||
"weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1,1) for i in range(net.nr_nets)]\n",
|
||||
"\n",
|
||||
"inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(net.nr_nets)]) #16, 25\n",
|
||||
"\n",
|
||||
"targets = torch.hstack(weights_per_net)\n",
|
||||
"targets.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"from pathlib import Path\n",
|
||||
"import torch\n",
|
||||
"from torch.nn import Flatten\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"from torchvision.transforms import ToTensor, Compose, Resize\n",
|
||||
"from tqdm import tqdm, trange\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"utility_transforms = Compose([ Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])\n",
|
||||
"data_path = Path('data')\n",
|
||||
"WORKER = 8\n",
|
||||
"BATCHSIZE = 10\n",
|
||||
"EPOCH = 1\n",
|
||||
"DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
"\n",
|
||||
"dataset = MNIST(str(data_path), transform=utility_transforms)\n",
|
||||
"d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def embed_batch(x, repeat_dim):\n",
|
||||
" # x of shape (batchsize, flat_img_dim)\n",
|
||||
" x = x.unsqueeze(-1) #(batchsize, flat_img_dim, 1)\n",
|
||||
" return torch.cat( (torch.zeros( x.shape[0], x.shape[1], 4), x), dim=2).repeat(1,1,repeat_dim) #(batchsize, flat_img_dim, encoding_dim*repeat_dim)\n",
|
||||
"\n",
|
||||
"def embed_vector(x, repeat_dim):\n",
|
||||
" # x of shape [flat_img_dim]\n",
|
||||
" x = x.unsqueeze(-1) #(flat_img_dim, 1)\n",
|
||||
" return torch.cat( (torch.zeros( x.shape[0], 4), x), dim=1).repeat(1,repeat_dim) #(flat_img_dim, encoding_dim*repeat_dim)\n",
|
||||
"\n",
|
||||
"class SparseNetwork():\n",
|
||||
" def __init__(self, input_dim, depth, width, out):\n",
|
||||
" self.input_dim = input_dim\n",
|
||||
" self.depth_dim = depth\n",
|
||||
" self.hidden_dim = width\n",
|
||||
" self.out_dim = out\n",
|
||||
" self.sparse_layers = []\n",
|
||||
" self.sparse_layers.append( SparseLayer( self.input_dim * self.hidden_dim ))\n",
|
||||
" self.sparse_layers.extend([ SparseLayer( self.hidden_dim * self.hidden_dim ) for layer_idx in range(self.depth_dim - 2)])\n",
|
||||
" self.sparse_layers.append( SparseLayer( self.hidden_dim * self.out_dim ))\n",
|
||||
"\n",
|
||||
" def __call__(self, x):\n",
|
||||
" \n",
|
||||
" for sparse_layer in self.sparse_layers[:-1]:\n",
|
||||
" # batch pass (one by one, sparse bmm doesn't support grad)\n",
|
||||
" if len(x.shape) > 1:\n",
|
||||
" embedded_inpt = embed_batch(x, sparse_layer.nr_nets)\n",
|
||||
" x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt]) #[batchsize, hidden*inpt_dim, feature_dim]\n",
|
||||
" # vector\n",
|
||||
" else:\n",
|
||||
" embedded_inpt = embed_vector(x, sparse_layer.nr_nets)\n",
|
||||
" x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1)\n",
|
||||
" print(\"out\", x.shape)\n",
|
||||
" \n",
|
||||
" # output layer\n",
|
||||
" sparse_layer = self.sparse_layers[-1]\n",
|
||||
" if len(x.shape) > 1:\n",
|
||||
" embedded_inpt = embed_batch(x, sparse_layer.nr_nets)\n",
|
||||
" x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt]) #[batchsize, hidden*inpt_dim, feature_dim]\n",
|
||||
" else:\n",
|
||||
" embedded_inpt = embed_vector(x, sparse_layer.nr_nets)\n",
|
||||
" x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1)\n",
|
||||
" print(\"out\", x.shape)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"data_dim = np.prod(dataset[0][0].shape)\n",
|
||||
"metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)\n",
|
||||
"batchx, batchy = next(iter(d))\n",
|
||||
"batchx.shape, batchy.shape\n",
|
||||
"metanet(batchx)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "8bcba732c17ca4dacffea8ad1176c852d4229b36b9060a5f633fff752e5396ea"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.8.12 64-bit ('masterthesis': conda)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|