Compare commits
95 Commits
SHA1:
38a4af2baa, 0ba3994325, dd2458da4a, ce5a36c8f4, b3d4987cb8, f3ff4c9239, 69c904e156, b6c8859081, c52b398819, 16c08d04d4,
e167cc78c5, 926b27b4ef, 78a919395b, a3a587476c, c0db8e19a3, e1a5383c04, 9d8496a725, 5b2b5b5beb, 3da00c793b, ebf133414c,
0bc3b62340, f0ad875e79, bb12176f72, 78e9c4d520, 7166fa86ca, f033a2448e, 2a710b40d7, f25cee5203, 52081d176e, 5aea4f9f55,
4f99251e68, 28b5e85697, a4d1ee86dd, 62e640e1f0, 8546cc7ddf, 14768ffc0a, 7ae3e96ec9, 594bbaa3dd, d4c25872c6, 3746c7d26e,
df182d652a, 21c3e75177, 382afa8642, 6b1efd0c49, 3fe4f49bca, eb3b9b8958, 1b7581e656, 246d825bb4, 49c0d8a621, 5f1f5833d8,
21dd572969, 5f6c658068, e51d7ad0b9, 6c1a964f31, b22a7ac427, 6c2d544f7c, 5a7dad2363, 14d9a533cb, f7a0d360b3, cf6eec639f,
1da5bd95d6, b40b534d5b, 27d763f1fb, 0ba109c083, e156540e2c, 0e2289344a, 987d7b95f3, 7e231b5b50, 2077d800ae, b57d3d32fd,
61ae8c2ee5, 800a2c8f6b, 9abde030af, 0320957b85, 32ebb729e8, c9efe0a31b, 5e5511caf8, 55bdd706b6, 74d618774a, 54590eb147,
b1dc574f5b, e9f6620b60, bcfe5807a7, 1e8ccd2b8b, f5ca3d1115, c1f58f2675, 36377ee27d, b1472479cb, 042188f15a, 5074100b71,
4b5c36f6c0, 56ea007f2b, 22d34d4e75, 9bf37486a0, e176d05cf5
.gitignore (new file, vendored, 1 line)
@@ -0,0 +1 @@
/output/
README.md (84 lines changed)
@@ -1,2 +1,84 @@

# cristian_lenta - BA code

# Bureaucratic Cohort Swarms
### Pruning Networks by SRNN
###### Deadline: 28.02.22

## Experiments

### Fixpoint Tests:

- [X] Dropout Test
  - (Does the particle still contribute to the goal task, or is it only SRN?)
  - Zero_ident diff = -00.04999637603759766 %

- [ ] gnf(1) -> Approx. weight
  - Translation into a weight scalar
  - Embedding into a regular network

- [ ] Translation into an explainable-AI framework
  - Drawing conclusions about the micro networks

- [ ] Visualization
  - of particle membership
  - of connectivity

- [ ] PCA() (see the sketch after this README)
  - Dataframe: Epoch, Weight, dim_1, ..., dim_n
  - Visualization as a trajectory cube

- [ ] Research on macro/micro network structures
  - Does this already exist?
  - Hypernetwork?
    - arxiv: 1905.02898
  - Sparse networks
  - Pruning

---

### Tasks for Steffen:

- [x] Sanity check:
  - [x] Can neurons learn to multiply an input value by x?

| SRNN x*n 3 Neurons Identity_Func | SRNN x*n 4 Neurons Identity_Func |
|---|---|
|  |  |
| SRNN x*n 6 Neurons Other_Func | SRNN x*n 10 Neurons Other_Func |
|  |  |

- [ ] Connectivity
  - The network really does thin itself out.

| 200 Epochs - 4 Neurons - \alpha 100 RES | |
|---|---|
|  |  |
| OTHER FUNCTIONS | IDENTITY FUNCTIONS |
|  |  |

- [ ] Training with smaller GNs

- [ ] Continue training -> 500 epochs?
- [x] Training without residual skip connections
  - It is different: self-training is prioritized at first, then the actual task slowly comes through:

| No Residual Skip connections 8 Neurons in SRNN Alpha=100 | Residual Skip connections 8 Neurons in SRNN Alpha=100 |
|---|---|
|  |  |
|  |  |

- [ ] Test against a baseline dense network
  - [ ] with a comparable neuron count
  - [ ] with the same total weight count

- [ ] Task/goal training instead of the SRNN task

---

### For people with too much time:
- [ ] Sparse-network training of the self-replication
  - (Just for the lulz and speeeeeeed)
  - (Joking aside, this would matter for faster research)

  <https://pytorch.org/docs/stable/sparse.html>
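The PCA item above could be implemented roughly as follows. This is a minimal sketch only: it assumes the `weight_store.csv` layout produced by the experiment scripts in this diff (columns `Epoch`, `Weight`, `weight_0` ... `weight_n`) and uses scikit-learn's `PCA`, which is not a stated dependency of this repository; the path and the 3-component choice are illustrative, not the repository's actual plotting code.

```python
import pandas as pd
from sklearn.decomposition import PCA

# Illustrative path; any weight_store.csv written by the experiment scripts works here.
weights = pd.read_csv('output/some_experiment/0/weight_store.csv')
dim_cols = [c for c in weights.columns if c.startswith('weight_')]

# One row per (particle, epoch); project the flattened weight vectors down to 3 dimensions.
projected = PCA(n_components=3).fit_transform(weights[dim_cols].values)

trajectories = weights[['Epoch', 'Weight']].copy()
trajectories[['dim_1', 'dim_2', 'dim_3']] = projected
# 'Weight' holds the particle name, so grouping by it yields one 3D trajectory
# per particle: the "trajectory cube" data described in the list above.
```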
experiments/__init__.py (new file, 0 lines)
experiments/helpers.py (new file, 59 lines)
@@ -0,0 +1,59 @@

""" -------------------------------- Methods for summarizing the experiments --------------------------------- """
from pathlib import Path

from visualization import line_chart_fixpoints, bar_chart_fixpoints


def summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, directory,
                                summary_pre_title):
    avg_fixpoint_counters = {
        "avg_identity_func": 0,
        "avg_divergent": 0,
        "avg_fix_zero": 0,
        "avg_fix_weak": 0,
        "avg_fix_sec": 0,
        "avg_other_func": 0
    }

    for i in range(len(experiments)):
        fixpoint_counters = experiments[i].fixpoint_counters

        avg_fixpoint_counters["avg_identity_func"] += fixpoint_counters["identity_func"]
        avg_fixpoint_counters["avg_divergent"] += fixpoint_counters["divergent"]
        avg_fixpoint_counters["avg_fix_zero"] += fixpoint_counters["fix_zero"]
        avg_fixpoint_counters["avg_fix_weak"] += fixpoint_counters["fix_weak"]
        avg_fixpoint_counters["avg_fix_sec"] += fixpoint_counters["fix_sec"]
        avg_fixpoint_counters["avg_other_func"] += fixpoint_counters["other_func"]

    # Calculating the average for each fixpoint
    avg_fixpoint_counters.update((x, y / len(experiments)) for x, y in avg_fixpoint_counters.items())

    # Checking where the data is coming from to have a relevant title in the plot.
    if summary_pre_title not in ["ST", "SA", "soup", "mixed", "robustness"]:
        summary_pre_title = ""

    # Plotting the summary
    source_checker = "summary"
    exp_details = f"{summary_pre_title}: {runs} runs & {epochs} epochs each."
    bar_chart_fixpoints(avg_fixpoint_counters, population_size, directory, net_learning_rate, exp_details,
                        source_checker)


def summary_fixpoint_percentage(runs, epochs, fixpoints_percentages, ST_steps, SA_steps, directory_name,
                                population_size):
    fixpoints_percentages = [round(fixpoints_percentages[i] / runs, 1) for i in range(len(fixpoints_percentages))]

    # Plotting summary
    if "soup" in directory_name:
        line_chart_fixpoints(fixpoints_percentages, epochs / ST_steps, ST_steps, SA_steps, directory_name,
                             population_size)
    else:
        line_chart_fixpoints(fixpoints_percentages, epochs, ST_steps, SA_steps, directory_name, population_size)


""" -------------------------------------------- Miscellaneous --------------------------------------------------- """


def check_folder(experiment_folder: str):
    exp_path = Path('experiments') / experiment_folder
    exp_path.mkdir(parents=True, exist_ok=True)
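For illustration, a minimal sketch of the averaging that `summary_fixpoint_experiment` performs, using two hypothetical experiment objects; the plotting call is omitted because `bar_chart_fixpoints` comes from the repository's `visualization` module, which is not part of this diff.

```python
from types import SimpleNamespace

# Hypothetical per-run results; real experiment objects expose .fixpoint_counters the same way.
experiments = [
    SimpleNamespace(fixpoint_counters={'identity_func': 8, 'divergent': 1, 'fix_zero': 0,
                                       'fix_weak': 0, 'fix_sec': 0, 'other_func': 1}),
    SimpleNamespace(fixpoint_counters={'identity_func': 6, 'divergent': 0, 'fix_zero': 1,
                                       'fix_weak': 0, 'fix_sec': 0, 'other_func': 3}),
]

# Same arithmetic as the helper: sum each counter over runs, then divide by the number of runs.
avg = {f'avg_{key}': sum(e.fixpoint_counters[key] for e in experiments) / len(experiments)
       for key in experiments[0].fixpoint_counters}
print(avg)  # {'avg_identity_func': 7.0, 'avg_divergent': 0.5, ...}
```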
experiments/meta_task_small_utility.py (new file, 38 lines)
@@ -0,0 +1,38 @@

import torch
from torch.utils.data import Dataset

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class AddTaskDataset(Dataset):
    def __init__(self, length=int(1e3)):
        super().__init__()
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, _):
        ab = torch.randn(size=(2,)).to(torch.float32)
        return ab, ab.sum(axis=-1, keepdims=True)


def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
    # Zero your gradients for every batch!
    optimizer.zero_grad()
    btch_x, btch_y = btch_x.to(DEVICE), btch_y.to(DEVICE)
    y_prd = model(btch_x)

    loss = loss_func(y_prd, btch_y.to(torch.float))
    loss.backward()

    # Adjust learning weights
    optimizer.step()

    stp_log = dict(Metric='Task Loss', Score=loss.item())

    return stp_log, y_prd


if __name__ == '__main__':
    raise NotImplementedError('Get out of here')
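A minimal usage sketch for the two helpers above. The `nn.Linear(2, 1)` regressor is a placeholder stand-in (the experiments pass the repository's `MetaNet` instead); everything else follows the signatures defined in this file.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader

from experiments.meta_task_small_utility import AddTaskDataset, train_task

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Placeholder model; the actual experiments use a MetaNet instance here.
model = nn.Linear(2, 1).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=0.004, momentum=0.9)
loss_fn = nn.MSELoss()

loader = DataLoader(AddTaskDataset(length=1000), batch_size=50, shuffle=True, drop_last=True)
for batch_x, batch_y in loader:
    step_log, y_pred = train_task(model, optimizer, loss_fn, batch_x, batch_y)
print(step_log)  # {'Metric': 'Task Loss', 'Score': <loss of the last batch>}
```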
experiments/meta_task_utility.py (new file, 405 lines)
@@ -0,0 +1,405 @@
|
||||
import pickle
|
||||
import re
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchmetrics
|
||||
from matplotlib import pyplot as plt
|
||||
import seaborn as sns
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from functionalities_test import test_for_fixpoints, FixTypes as ft
|
||||
from sanity_check_weights import test_weights_as_model, extract_weights_from_model
|
||||
|
||||
WORKER = 10
|
||||
BATCHSIZE = 500
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
DATA_PATH = Path('data')
|
||||
DATA_PATH.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
PALETTE = sns.color_palette()
|
||||
PALETTE.insert(0, PALETTE.pop(1)) # Orange First
|
||||
|
||||
FINAL_CHECKPOINT_NAME = f'trained_model_ckpt_FINAL.tp'
|
||||
|
||||
|
||||
class AddGaussianNoise(object):
|
||||
def __init__(self, ratio=1e-4):
|
||||
self.ratio = ratio
|
||||
|
||||
def __call__(self, tensor: torch.Tensor):
|
||||
return tensor + (torch.randn_like(tensor, device=tensor.device) * self.ratio)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(ratio={self.ratio}'
|
||||
|
||||
|
||||
class ToFloat:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, x):
|
||||
return x.to(torch.float32)
|
||||
|
||||
|
||||
class AddTaskDataset(Dataset):
|
||||
def __init__(self, length=int(5e5)):
|
||||
super().__init__()
|
||||
self.length = length
|
||||
self.prng = np.random.default_rng()
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, _):
|
||||
ab = self.prng.normal(size=(2,)).astype(np.float32)
|
||||
return ab, ab.sum(axis=-1, keepdims=True)
|
||||
|
||||
|
||||
def set_checkpoint(model, out_path, epoch_n, final_model=False):
|
||||
if not final_model:
|
||||
epoch_n = str(epoch_n)
|
||||
ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
|
||||
else:
|
||||
if isinstance(epoch_n, str):
|
||||
ckpt_path = Path(out_path) / f'{Path(FINAL_CHECKPOINT_NAME).stem}_{epoch_n}.tp'
|
||||
else:
|
||||
ckpt_path = Path(out_path) / FINAL_CHECKPOINT_NAME
|
||||
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||
py_store_path = Path(out_path) / 'exp_py.txt'
|
||||
if not py_store_path.exists():
|
||||
shutil.copy(__file__, py_store_path)
|
||||
return ckpt_path
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def validate(checkpoint_path, valid_loader, metric_class=torchmetrics.Accuracy):
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
|
||||
# initialize metric
|
||||
validmetric = metric_class()
|
||||
model = torch.load(checkpoint_path, map_location=DEVICE).eval()
|
||||
|
||||
with tqdm(total=len(valid_loader), desc='Validation Run: ') as pbar:
|
||||
for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_loader):
|
||||
valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
|
||||
y_valid = model(valid_batch_x)
|
||||
|
||||
# metric on current batch
|
||||
measure = validmetric(y_valid.cpu(), valid_batch_y.cpu())
|
||||
pbar.set_postfix_str(f'Measure: {measure}')
|
||||
pbar.update()
|
||||
|
||||
# metric on all batches using custom accumulation
|
||||
measure = validmetric.compute()
|
||||
tqdm.write(f"Avg. {validmetric._get_name()} on all data: {measure}")
|
||||
return measure
|
||||
|
||||
|
||||
def new_storage_df(identifier, weight_count):
|
||||
if identifier == 'train':
|
||||
return pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
|
||||
elif identifier == 'weights':
|
||||
return pd.DataFrame(columns=['Epoch', 'Weight', *(f'weight_{x}' for x in range(weight_count))])
|
||||
|
||||
|
||||
def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, keep_n=5, final_model=False,
|
||||
validation_metric=torchmetrics.Accuracy):
|
||||
out_path = Path(out_path)
|
||||
ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
|
||||
# Clean up Checkpoints
|
||||
if keep_n > 0:
|
||||
all_ckpts = sorted(list(ckpt_path.parent.iterdir()))
|
||||
while len(all_ckpts) > keep_n:
|
||||
all_ckpts.pop(0).unlink()
|
||||
elif keep_n == 0:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f'"keep_n" cannot be negative, but was: {keep_n}')
|
||||
|
||||
result = validate(ckpt_path, valid_loader, metric_class=validation_metric)
|
||||
return result
|
||||
|
||||
|
||||
def plot_training_particle_types(path_to_dataframe):
|
||||
plt.close('all')
|
||||
plt.clf()
|
||||
# load from Drive
|
||||
df = pd.read_csv(path_to_dataframe, index_col=False).sort_values('Metric')
|
||||
# Set up figure
|
||||
fig, ax = plt.subplots() # initializes figure and plots
|
||||
data = df.loc[df['Metric'].isin(ft.all_types())]
|
||||
fix_types = data['Metric'].unique()
|
||||
data = data.pivot(index='Epoch', columns='Metric', values='Score').reset_index().fillna(0)
|
||||
_ = plt.stackplot(data['Epoch'], *[data[fixtype] for fixtype in fix_types],
|
||||
labels=fix_types.tolist(), colors=PALETTE)
|
||||
|
||||
ax.set(ylabel='Particle Count', xlabel='Epoch')
|
||||
ax.yaxis.get_major_locator().set_params(integer=True)
|
||||
# ax.set_title('Particle Type Count')
|
||||
|
||||
fig.legend(loc="center right", title='Particle Type', bbox_to_anchor=(0.85, 0.5))
|
||||
plt.tight_layout()
|
||||
plt.savefig(Path(path_to_dataframe.parent / 'training_particle_type_lp.png'), dpi=300)
|
||||
|
||||
|
||||
def plot_training_result(path_to_dataframe, metric_name='Accuracy', plot_name=None):
|
||||
plt.clf()
|
||||
# load from Drive
|
||||
df = pd.read_csv(path_to_dataframe, index_col=False).sort_values('Metric')
|
||||
|
||||
# Check if this is a single lineplot or if aggregated
|
||||
group = ['Epoch', 'Metric']
|
||||
if 'Seed' in df.columns:
|
||||
group.append('Seed')
|
||||
|
||||
# Set up figure
|
||||
fig, ax1 = plt.subplots() # initializes figure and plots
|
||||
ax2 = ax1.twinx() # applies twinx to ax2, which is the second y-axis.
|
||||
|
||||
# plots the first set of data
|
||||
data = df[(df['Metric'] == 'Task Loss') | (df['Metric'] == 'Self Train Loss')].groupby(['Epoch', 'Metric']).mean()
|
||||
grouped_for_lineplot = data.groupby(group).mean()
|
||||
palette_len_1 = len(grouped_for_lineplot.droplevel(0).reset_index().Metric.unique())
|
||||
|
||||
sns.lineplot(data=grouped_for_lineplot, x='Epoch', y='Score', hue='Metric',
|
||||
palette=PALETTE[:palette_len_1], ax=ax1, ci='sd')
|
||||
|
||||
# plots the second set of data
|
||||
data = df[(df['Metric'] == f'Test {metric_name}') | (df['Metric'] == f'Train {metric_name}')]
|
||||
palette_len_2 = len(data.Metric.unique())
|
||||
sns.lineplot(data=data, x='Epoch', y='Score', hue='Metric',
|
||||
palette=PALETTE[palette_len_1:palette_len_2+palette_len_1], ci='sd')
|
||||
|
||||
ax1.set(yscale='log', ylabel='Losses')
|
||||
# ax1.set_title('Training Lineplot')
|
||||
ax2.set(ylabel=metric_name)
|
||||
if metric_name != 'Accuracy':
|
||||
ax2.set(yscale='log')
|
||||
|
||||
fig.legend(loc="center right", title='Metric', bbox_to_anchor=(0.85, 0.5))
|
||||
for ax in [ax1, ax2]:
|
||||
if legend := ax.get_legend():
|
||||
legend.remove()
|
||||
plt.tight_layout()
|
||||
plt.savefig(Path(path_to_dataframe.parent / ('training_lineplot.png' if plot_name is None else plot_name)), dpi=300)
|
||||
|
||||
|
||||
def plot_network_connectivity_by_fixtype(path_to_trained_model):
|
||||
m = torch.load(path_to_trained_model, map_location=DEVICE).eval()
|
||||
# noinspection PyProtectedMember
|
||||
particles = list(m.particles)
|
||||
df = pd.DataFrame(columns=['Type', 'Layer', 'Neuron', 'Name'])
|
||||
|
||||
for prtcl in particles:
|
||||
l, c, w = [float(x) for x in re.sub("[^0-9|_]", "", prtcl.name).split('_')]
|
||||
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l-1, w, prtcl.name)
|
||||
df.loc[df.shape[0]] = (prtcl.is_fixpoint, l, c, prtcl.name)
|
||||
for layer in list(df['Layer'].unique()):
|
||||
# Rescale
|
||||
divisor = df.loc[(df['Layer'] == layer), 'Neuron'].max()
|
||||
df.loc[(df['Layer'] == layer), 'Neuron'] /= divisor
|
||||
|
||||
tqdm.write(f'Connectivity Data gathered')
|
||||
df = df.sort_values('Type')
|
||||
n = 0
|
||||
for fixtype in ft.all_types():
|
||||
if df[df['Type'] == fixtype].shape[0] > 0:
|
||||
plt.clf()
|
||||
ax = sns.lineplot(y='Neuron', x='Layer', hue='Name', data=df[df['Type'] == fixtype],
|
||||
legend=False, estimator=None, lw=1)
|
||||
_ = sns.lineplot(y=[0, 1], x=[-1, df['Layer'].max()], legend=False, estimator=None, lw=0)
|
||||
ax.set_title(fixtype)
|
||||
ax.yaxis.get_major_locator().set_params(integer=True)
|
||||
ax.xaxis.get_major_locator().set_params(integer=True)
|
||||
ax.set_ylabel('Normalized Neuron Position (1/n)') # XAXIS Label
|
||||
lines = ax.get_lines()
|
||||
for line in lines:
|
||||
line.set_color(PALETTE[n])
|
||||
plt.savefig(Path(path_to_trained_model.parent / f'net_connectivity_{fixtype}.png'), dpi=300)
|
||||
tqdm.write(f'Connectivity plottet: {fixtype} - n = {df[df["Type"] == fixtype].shape[0] // 2}')
|
||||
n += 1
|
||||
else:
|
||||
# tqdm.write(f'No Connectivity {fixtype}')
|
||||
pass
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def run_particle_dropout_test(model_path, valid_loader, metric_class=torchmetrics.Accuracy):
|
||||
diff_store_path = model_path.parent / 'diff_store.csv'
|
||||
latest_model = torch.load(model_path, map_location=DEVICE).eval()
|
||||
prtcl_dict = defaultdict(lambda: 0)
|
||||
_ = test_for_fixpoints(prtcl_dict, list(latest_model.particles))
|
||||
tqdm.write(str(dict(prtcl_dict)))
|
||||
diff_df = pd.DataFrame(columns=['Particle Type', metric_class()._get_name(), 'Diff'])
|
||||
|
||||
acc_pre = validate(model_path, valid_loader, metric_class=metric_class).item()
|
||||
diff_df.loc[diff_df.shape[0]] = ('All Organism', acc_pre, 0)
|
||||
|
||||
for fixpoint_type in ft.all_types():
|
||||
new_model = torch.load(model_path, map_location=DEVICE).eval().replace_with_zero(fixpoint_type)
|
||||
if [x for x in new_model.particles if x.is_fixpoint == fixpoint_type]:
|
||||
new_ckpt = set_checkpoint(new_model, model_path.parent, fixpoint_type, final_model=True)
|
||||
acc_post = validate(new_ckpt, valid_loader, metric_class=metric_class).item()
|
||||
acc_diff = abs(acc_post - acc_pre)
|
||||
tqdm.write(f'Zero_ident diff = {acc_diff}')
|
||||
diff_df.loc[diff_df.shape[0]] = (fixpoint_type, acc_post, acc_diff)
|
||||
|
||||
diff_df.to_csv(diff_store_path, mode='w', header=True, index=False)
|
||||
return diff_store_path
|
||||
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
def plot_dropout_stacked_barplot(mdl_path, diff_store_path, metric_class=torchmetrics.Accuracy):
|
||||
metric_name = metric_class()._get_name()
|
||||
diff_df = pd.read_csv(diff_store_path).sort_values('Particle Type')
|
||||
particle_dict = defaultdict(lambda: 0)
|
||||
latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
|
||||
_ = test_for_fixpoints(particle_dict, list(latest_model.particles))
|
||||
particle_dict = dict(particle_dict)
|
||||
sorted_particle_dict = dict(sorted(particle_dict.items()))
|
||||
tqdm.write(str(sorted_particle_dict))
|
||||
plt.clf()
|
||||
fig, ax = plt.subplots(ncols=2)
|
||||
colors = PALETTE.copy()
|
||||
colors.insert(0, colors.pop(-1))
|
||||
palette_len = len(diff_df['Particle Type'].unique())
|
||||
_ = sns.barplot(data=diff_df, y=metric_name, x='Particle Type', ax=ax[0], palette=colors[:palette_len], ci=None)
|
||||
|
||||
ax[0].set_title(f'{metric_name} after particle dropout')
|
||||
# ax[0].set_xlabel('Particle Type') # XAXIS Label
|
||||
ax[0].set_xticklabels(ax[0].get_xticklabels(), rotation=30)
|
||||
|
||||
ax[1].pie(sorted_particle_dict.values(), labels=sorted_particle_dict.keys(),
|
||||
colors=PALETTE[:len(sorted_particle_dict)])
|
||||
ax[1].set_title('Particle Count')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(Path(diff_store_path.parent / 'dropout_stacked_barplot.png'), dpi=300)
|
||||
|
||||
|
||||
def run_particle_dropout_and_plot(model_path, valid_loader, metric_class=torchmetrics.Accuracy):
|
||||
diff_store_path = run_particle_dropout_test(model_path, valid_loader=valid_loader, metric_class=metric_class)
|
||||
plot_dropout_stacked_barplot(model_path, diff_store_path, metric_class=metric_class)
|
||||
|
||||
|
||||
def flat_for_store(parameters):
|
||||
return (x.item() for y in parameters for x in y.detach().flatten())
|
||||
|
||||
|
||||
def train_self_replication(model, st_stps, **kwargs) -> dict:
|
||||
self_train_loss = model.combined_self_train(st_stps, **kwargs)
|
||||
# noinspection PyUnboundLocalVariable
|
||||
stp_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
|
||||
return stp_log
|
||||
|
||||
|
||||
def train_task(model, optimizer, loss_func, btch_x, btch_y) -> (dict, torch.Tensor):
|
||||
# Zero your gradients for every batch!
|
||||
optimizer.zero_grad()
|
||||
btch_x, btch_y = btch_x.to(DEVICE), btch_y.to(DEVICE)
|
||||
y_prd = model(btch_x)
|
||||
|
||||
loss = loss_func(y_prd, btch_y.to(torch.long))
|
||||
loss.backward()
|
||||
|
||||
# Adjust learning weights
|
||||
optimizer.step()
|
||||
|
||||
stp_log = dict(Metric='Task Loss', Score=loss.item())
|
||||
|
||||
return stp_log, y_prd
|
||||
|
||||
|
||||
def highlight_fixpoints_vs_mnist_mean(mdl_path, dataloader):
|
||||
latest_model = torch.load(mdl_path, map_location=DEVICE).eval()
|
||||
activation_vector = torch.as_tensor([[0, 0, 0, 0, 1]], dtype=torch.float32, device=DEVICE)
|
||||
binary_images = []
|
||||
real_images = []
|
||||
with torch.no_grad():
|
||||
# noinspection PyProtectedMember
|
||||
for cell in latest_model._meta_layer_first.meta_cell_list:
|
||||
cell_image_binary = torch.zeros((len(cell.meta_weight_list)), device=DEVICE)
|
||||
cell_image_real = torch.zeros((len(cell.meta_weight_list)), device=DEVICE)
|
||||
for idx, particle in enumerate(cell.particles):
|
||||
if particle.is_fixpoint == ft.identity_func:
|
||||
cell_image_binary[idx] += 1
|
||||
cell_image_real[idx] = particle(activation_vector).abs().squeeze().item()
|
||||
binary_images.append(cell_image_binary.reshape((15, 15)))
|
||||
real_images.append(cell_image_real.reshape((15, 15)))
|
||||
|
||||
binary_images = torch.stack(binary_images)
|
||||
real_images = torch.stack(real_images)
|
||||
|
||||
binary_image = torch.sum(binary_images, keepdim=True, dim=0)
|
||||
real_image = torch.sum(real_images, keepdim=True, dim=0)
|
||||
|
||||
mnist_images = [x for x, _ in dataloader]
|
||||
mnist_mean = torch.cat(mnist_images).reshape(10000, 15, 15).abs().sum(dim=0)
|
||||
|
||||
fig, axs = plt.subplots(1, 3)
|
||||
|
||||
for idx, (image, title) in enumerate(zip([binary_image, real_image, mnist_mean],
|
||||
["Particle Count", "Particle Value", "MNIST mean"])):
|
||||
img = axs[idx].imshow(image.squeeze().detach().cpu())
|
||||
img.axes.axis('off')
|
||||
img.axes.set_title('Random Noise')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(mdl_path.parent / 'heatmap.png', dpi=300)
|
||||
plt.clf()
|
||||
plt.close('all')
|
||||
|
||||
|
||||
def plot_training_results_over_n_seeds(exp_path, df_train_store_name='train_store.csv', metric_name='Accuracy'):
|
||||
combined_df_store_path = exp_path / f'comb_train_{exp_path.stem[:-1]}n.csv'
|
||||
# noinspection PyUnboundLocalVariable
|
||||
found_train_stores = exp_path.rglob(df_train_store_name)
|
||||
train_dfs = []
|
||||
for found_train_store in found_train_stores:
|
||||
train_store_df = pd.read_csv(found_train_store, index_col=False)
|
||||
train_store_df['Seed'] = int(found_train_store.parent.name)
|
||||
train_dfs.append(train_store_df)
|
||||
combined_train_df = pd.concat(train_dfs)
|
||||
combined_train_df.to_csv(combined_df_store_path, index=False)
|
||||
plot_training_result(combined_df_store_path, metric_name=metric_name,
|
||||
plot_name=f"{combined_df_store_path.stem}.png"
|
||||
)
|
||||
plt.clf()
|
||||
plt.close('all')
|
||||
|
||||
|
||||
def sanity_weight_swap(exp_path, dataloader, metric_class=torchmetrics.Accuracy):
|
||||
# noinspection PyProtectedMember
|
||||
metric_name = metric_class()._get_name()
|
||||
found_models = exp_path.rglob(f'*{FINAL_CHECKPOINT_NAME}')
|
||||
df = pd.DataFrame(columns=['Seed', 'Model', metric_name])
|
||||
for model_idx, found_model in enumerate(found_models):
|
||||
model = torch.load(found_model, map_location=DEVICE).eval()
|
||||
weights = extract_weights_from_model(model)
|
||||
|
||||
results = test_weights_as_model(model, weights, dataloader, metric_class=metric_class)
|
||||
for model_name, measurement in results.items():
|
||||
df.loc[df.shape[0]] = (model_idx, model_name, measurement)
|
||||
df.loc[df.shape[0]] = (model_idx, 'Difference', np.abs(np.subtract(*results.values())))
|
||||
|
||||
df.to_csv(exp_path / 'sanity_weight_swap.csv', index=False)
|
||||
_ = sns.boxplot(data=df, x='Model', y=metric_name)
|
||||
plt.tight_layout()
|
||||
plt.savefig(exp_path / 'sanity_weight_swap.png', dpi=300)
|
||||
plt.clf()
|
||||
plt.close('all')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise NotImplementedError('Test this here!!!')
|
experiments/robustness_tester.py (new file, 118 lines)
@@ -0,0 +1,118 @@

import pandas as pd
import torch
import random
import copy

from tqdm import tqdm
from functionalities_test import (is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent,
                                  FixTypes as FT)
from network import Net
from torch.nn import functional as F
import seaborn as sns
from matplotlib import pyplot as plt


def prng():
    return random.random()


def generate_perfekt_synthetic_fixpoint_weights():
    return torch.tensor([[1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
                         [1.0], [0.0], [0.0], [0.0],
                         [1.0], [0.0]
                         ], dtype=torch.float32)


PALETTE = 10 * (
    "#377eb8",
    "#4daf4a",
    "#984ea3",
    "#e41a1c",
    "#ff7f00",
    "#a65628",
    "#f781bf",
    "#888888",
    "#a6cee3",
    "#b2df8a",
    "#cab2d6",
    "#fb9a99",
    "#fdbf6f",
)


def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
    model = torch.load(model_path, map_location='cpu')
    networks = [x for x in model.particles if x.is_fixpoint == FT.identity_func]
    time_to_vergence = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
    time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
    row_headers = []
    df = pd.DataFrame(columns=['setting', 'Noise Level', 'Self Train Steps', 'absolute_loss',
                               'Time to convergence', 'Time as fixpoint'])
    with tqdm(total=(seeds * noise_levels * len(networks)), desc='Per Particle Robustness') as pbar:
        for setting, fixpoint in enumerate(networks):  # 1 / n
            row_headers.append(fixpoint.name)
            for seed in range(seeds):  # n / 1
                for noise_level in range(noise_levels):
                    steps = 0
                    clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
                                f"{fixpoint.name}_clone_noise_1e-{noise_level}")
                    clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
                    clone = clone.apply_noise(pow(10, -noise_level))

                    while not is_zero_fixpoint(clone) and not is_divergent(clone):
                        # -> before
                        clone_weight_pre_application = clone.input_weight_matrix()
                        target_data_pre_application = clone.create_target_weights(clone_weight_pre_application)

                        clone.self_application(1, log_step_size)
                        time_to_vergence[setting][noise_level] += 1
                        # -> after
                        clone_weight_post_application = clone.input_weight_matrix()
                        target_data_post_application = clone.create_target_weights(clone_weight_post_application)

                        absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item()

                        if is_identity_function(clone):
                            time_as_fixpoint[setting][noise_level] += 1
                            # When this raises a Type Error, we found a second order fixpoint!
                        steps += 1

                    df.loc[df.shape[0]] = [f'{setting}_{seed}', fr'$\mathregular{{10^{{-{noise_level}}}}}$',
                                           steps, absolute_loss,
                                           time_to_vergence[setting][noise_level],
                                           time_as_fixpoint[setting][noise_level]]
                    pbar.update(1)

    # Get the measurements at the highest time_to_vergence
    df_sorted = df.sort_values('Self Train Steps', ascending=False).drop_duplicates(['setting', 'Noise Level'])
    df_melted = df_sorted.reset_index().melt(id_vars=['setting', 'Noise Level', 'Self Train Steps'],
                                             value_vars=['Time to convergence', 'Time as fixpoint'],
                                             var_name="Measurement",
                                             value_name="Steps").sort_values('Noise Level')

    df_melted.to_csv(model_path.parent / 'robustness_boxplot.csv', index=False)

    # Plotting
    # plt.rcParams.update({
    #     "text.usetex": True,
    #     "font.family": "sans-serif",
    #     "font.size": 12,
    #     "font.weight": 'bold',
    #     "font.sans-serif": ["Helvetica"]})
    plt.clf()
    sns.set(style='whitegrid', font_scale=1)
    _ = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE)
    plt.tight_layout()

    # sns.set(rc={'figure.figsize': (10, 50)})
    # bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box',
    #                  col='noise_level', col_wrap=3, showfliers=False)

    filename = f"robustness_boxplot.png"
    filepath = model_path.parent / filename
    plt.savefig(str(filepath))
    plt.close('all')
    return time_as_fixpoint, time_to_vergence


if __name__ == "__main__":
    raise NotImplementedError('Get out of here!')
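A minimal call sketch, assuming one of the final checkpoints written by the training scripts in this diff already exists (the experiment folder name below is illustrative). `test_robustness` loads the model, perturbs every identity-fixpoint particle with noise of magnitude 10^-k, and writes `robustness_boxplot.csv` and `robustness_boxplot.png` next to the checkpoint while returning the two step-count tables.

```python
from pathlib import Path
from experiments.robustness_tester import test_robustness

# Illustrative experiment folder; point this at whatever the training run actually produced.
ckpt = Path('output') / 'add_st_60_w3wh3d3' / '0' / 'trained_model_ckpt_FINAL.tp'
time_as_fixpoint, time_to_vergence = test_robustness(ckpt, noise_levels=10, seeds=3)
print(f'{len(time_as_fixpoint)} identity particles tested')
```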
New binary files:
- figures/connectivity/identity.png (98 KiB)
- figures/connectivity/other.png (97 KiB)
- figures/connectivity/training_lineplot.png (198 KiB)
- figures/connectivity/training_particle_type_lp.png (91 KiB)
- (path not shown, 187 KiB)
- (path not shown, 93 KiB)
- figures/res_no_res/mn_st_200_8_alpha_100_training_lineplot.png (176 KiB)
- (path not shown, 94 KiB)
- figures/sanity/sanity_10hidden_xtimesn.png (18 KiB)
- figures/sanity/sanity_2hidden_xtimesn.png (18 KiB)
- figures/sanity/sanity_3hidden_xtimesn.png (22 KiB)
- figures/sanity/sanity_4hidden_xtimesn.png (24 KiB)
- figures/sanity/sanity_6hidden_xtimesn.png (23 KiB)
functionalities_test.py (new file, 97 lines)
@@ -0,0 +1,97 @@

import copy
from typing import Dict, List
import torch
from tqdm import tqdm

from network import FixTypes, Net


epsilon_error_margin = pow(10, -5)


def is_divergent(network: Net) -> bool:
    return network.input_weight_matrix().isinf().any().item() or network.input_weight_matrix().isnan().any().item()


def is_identity_function(network: Net, epsilon=epsilon_error_margin) -> bool:

    input_data = network.input_weight_matrix()
    target_data = network.create_target_weights(input_data)
    predicted_values = network(input_data)

    return torch.allclose(target_data.detach(), predicted_values.detach(),
                          rtol=0, atol=epsilon)


def is_zero_fixpoint(network: Net, epsilon=epsilon_error_margin) -> bool:
    target_data = network.create_target_weights(network.input_weight_matrix().detach())
    result = torch.allclose(target_data, torch.zeros_like(target_data), rtol=0, atol=epsilon)
    # result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
    return result


def is_secondary_fixpoint(network: Net, epsilon: float = epsilon_error_margin) -> bool:
    """ Secondary fixpoint check is done like this: compare first INPUT with second OUTPUT.
        If they are within the boundaries, then is secondary fixpoint. """

    input_data = network.input_weight_matrix()
    target_data = network.create_target_weights(input_data)

    # Calculating first output
    first_output = network(input_data)

    # Getting the second output by initializing a new net with the weights of the original net.
    net_copy = copy.deepcopy(network)
    net_copy.apply_weights(first_output)
    input_data_2 = net_copy.input_weight_matrix()

    # Calculating second output
    second_output = network(input_data_2)

    # Perform the Check: all(epsilon > abs(input_data - second_output))
    check_abs_within_epsilon = torch.allclose(target_data.detach(), second_output.detach(),
                                              rtol=0, atol=epsilon)
    return check_abs_within_epsilon


def test_for_fixpoints(fixpoint_counter: Dict, nets: List, id_functions=None):
    id_functions = id_functions or list()

    for net in tqdm(nets, desc='Fixpoint Tester', total=len(nets)):
        if is_divergent(net):
            fixpoint_counter[FixTypes.divergent] += 1
            net.is_fixpoint = FixTypes.divergent
        elif is_zero_fixpoint(net):
            fixpoint_counter[FixTypes.fix_zero] += 1
            net.is_fixpoint = FixTypes.fix_zero
        elif is_identity_function(net):  # is default value
            fixpoint_counter[FixTypes.identity_func] += 1
            net.is_fixpoint = FixTypes.identity_func
            id_functions.append(net)
        elif is_secondary_fixpoint(net):
            fixpoint_counter[FixTypes.fix_sec] += 1
            net.is_fixpoint = FixTypes.fix_sec
        else:
            fixpoint_counter[FixTypes.other_func] += 1
            net.is_fixpoint = FixTypes.other_func
    return id_functions


def changing_rate(x_new, x_old):
    return x_new - x_old


def test_status(net: Net) -> Net:

    if is_divergent(net):
        net.is_fixpoint = FixTypes.divergent
    elif is_identity_function(net):  # is default value
        net.is_fixpoint = FixTypes.identity_func
    elif is_zero_fixpoint(net):
        net.is_fixpoint = FixTypes.fix_zero
    elif is_secondary_fixpoint(net):
        net.is_fixpoint = FixTypes.fix_sec
    else:
        net.is_fixpoint = FixTypes.other_func

    return net
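A minimal sketch of how these checks are consumed by the experiment scripts later in this diff, assuming a trained `MetaNet` named `metanet` is in scope; the counter keys are the `FixTypes` labels.

```python
from collections import defaultdict
from functionalities_test import test_for_fixpoints

counter = defaultdict(lambda: 0)
# Classifies every particle, tags it via net.is_fixpoint, and returns the identity functions.
identity_nets = test_for_fixpoints(counter, list(metanet.particles))
print(dict(counter))       # e.g. {'identity_func': 12, 'other_func': 3, ...}
print(len(identity_nets))  # particles classified as identity fixpoints
```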
meta_task_exp.py (new file, 273 lines)
@@ -0,0 +1,273 @@
|
||||
# # # Imports
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import platform
|
||||
|
||||
import torchmetrics
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import Flatten
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision.transforms import ToTensor, Compose, Resize
|
||||
from tqdm import tqdm
|
||||
|
||||
# noinspection DuplicatedCode
|
||||
from experiments.meta_task_utility import (ToFloat, new_storage_df, train_task, checkpoint_and_validate, flat_for_store,
|
||||
plot_training_result, plot_training_particle_types,
|
||||
plot_network_connectivity_by_fixtype, run_particle_dropout_and_plot,
|
||||
highlight_fixpoints_vs_mnist_mean, AddGaussianNoise,
|
||||
plot_training_results_over_n_seeds, sanity_weight_swap,
|
||||
FINAL_CHECKPOINT_NAME)
|
||||
from experiments.robustness_tester import test_robustness
|
||||
from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer
|
||||
|
||||
|
||||
if platform.node() == 'CarbonX':
|
||||
debug = True
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
print("@ Warning, Debugging Config@!!!!!! @")
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
else:
|
||||
debug = False
|
||||
|
||||
|
||||
from network import MetaNet, FixTypes
|
||||
from functionalities_test import test_for_fixpoints
|
||||
|
||||
|
||||
utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)]) # , AddGaussianNoise()])
|
||||
WORKER = 10 if not debug else 2
|
||||
debug = False
|
||||
BATCHSIZE = 2000 if not debug else 50
|
||||
EPOCH = 200
|
||||
VALIDATION_FRQ = 3 if not debug else 1
|
||||
VAL_METRIC_CLASS = torchmetrics.Accuracy
|
||||
# noinspection PyProtectedMember
|
||||
VAL_METRIC_NAME = VAL_METRIC_CLASS()._get_name()
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
DATA_PATH = Path('data')
|
||||
DATA_PATH.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
try:
|
||||
plot_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False)
|
||||
except RuntimeError:
|
||||
plot_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False, download=True)
|
||||
plot_loader = DataLoader(plot_dataset, batch_size=BATCHSIZE, shuffle=False,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
training = True
|
||||
plotting = True
|
||||
robustnes = False
|
||||
n_st = 300 # per batch !!
|
||||
activation = None # nn.ReLU()
|
||||
|
||||
train_to_task_first = True
|
||||
min_task_acc = 0.85
|
||||
|
||||
residual_skip = True
|
||||
add_gauss = False
|
||||
|
||||
alpha_st_modulator = 0
|
||||
|
||||
for weight_hidden_size in [5]:
|
||||
|
||||
weight_hidden_size = weight_hidden_size
|
||||
n_seeds = 3
|
||||
depth = 3
|
||||
width = 5
|
||||
out = 10
|
||||
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
|
||||
a_str = f'_aStM_{alpha_st_modulator}' if alpha_st_modulator not in [1, 0] else ''
|
||||
res_str = '_no_res' if not residual_skip else ''
|
||||
st_str = f'_nst_{n_st}'
|
||||
tsk_str = f'_tsktr_{min_task_acc}' if train_to_task_first else ''
|
||||
w_str = f'_w{width}wh{weight_hidden_size}d{depth}'
|
||||
|
||||
config_str = f'{res_str}{ac_str}{st_str}{tsk_str}{a_str}{w_str}'
|
||||
exp_path = Path('output') / f'mn_st_{EPOCH}{config_str}'
|
||||
last_accuracy = 0
|
||||
|
||||
for seed in range(0, n_seeds):
|
||||
seed_path = exp_path / str(seed)
|
||||
|
||||
df_store_path = seed_path / 'train_store.csv'
|
||||
weight_store_path = seed_path / 'weight_store.csv'
|
||||
srnn_parameters = dict()
|
||||
|
||||
if training:
|
||||
# Check if files do exist on project location, warn and break.
|
||||
for path in [df_store_path, weight_store_path]:
|
||||
assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
|
||||
try:
|
||||
train_dataset = MNIST(str(DATA_PATH), transform=utility_transforms)
|
||||
except RuntimeError:
|
||||
train_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, download=True)
|
||||
train_loader = DataLoader(train_dataset, batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
try:
|
||||
valid_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False)
|
||||
except RuntimeError:
|
||||
valid_dataset = MNIST(str(DATA_PATH), transform=utility_transforms, train=False, download=True)
|
||||
valid_loader = DataLoader(valid_dataset, batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
interface = np.prod(train_dataset[0][0].shape)
|
||||
metanet = MetaNet(interface, depth=depth, width=width, out=out,
|
||||
residual_skip=residual_skip, weight_hidden_size=weight_hidden_size,
|
||||
activation=activation
|
||||
).to(DEVICE)
|
||||
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
|
||||
|
||||
train_store = new_storage_df('train', None)
|
||||
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
|
||||
|
||||
for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'):
|
||||
is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
|
||||
metanet = metanet.train()
|
||||
|
||||
# Init metrics, even we do not need:
|
||||
metric = VAL_METRIC_CLASS()
|
||||
n_st_per_batch = max(n_st // len(train_loader), 1)
|
||||
|
||||
if is_validation_epoch:
|
||||
for particle in metanet.particles:
|
||||
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
|
||||
weight_store.loc[weight_store.shape[0]] = weight_log
|
||||
|
||||
do_self_train = not train_to_task_first or last_accuracy >= min_task_acc
|
||||
train_to_task_first = train_to_task_first if not do_self_train else False
|
||||
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(train_loader),
|
||||
total=len(train_loader), desc='MetaNet Train - Batch'
|
||||
):
|
||||
# Self Train
|
||||
if do_self_train:
|
||||
self_train_loss = metanet.combined_self_train(n_st_per_batch, alpha=alpha_st_modulator,
|
||||
reduction='mean', per_particle=False)
|
||||
# noinspection PyUnboundLocalVariable
|
||||
st_step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
|
||||
st_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = st_step_log
|
||||
|
||||
# Task Train
|
||||
tsk_step_log, y_pred = train_task(metanet, optimizer, loss_fn, batch_x, batch_y)
|
||||
tsk_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = tsk_step_log
|
||||
metric(y_pred.cpu(), batch_y.cpu())
|
||||
|
||||
last_accuracy = metric.compute().item()
|
||||
if is_validation_epoch:
|
||||
metanet = metanet.eval()
|
||||
try:
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric=f'Train {VAL_METRIC_NAME}', Score=last_accuracy)
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, epoch, keep_n=5).item()
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy)
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
|
||||
if is_validation_epoch:
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
# This returns ID-functions
|
||||
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
|
||||
counter_dict = dict(counter_dict)
|
||||
for key, value in counter_dict.items():
|
||||
val_step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = val_step_log
|
||||
tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
|
||||
|
||||
# FLUSH to disk
|
||||
if is_validation_epoch:
|
||||
train_store.to_csv(df_store_path, mode='a',
|
||||
header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a',
|
||||
header=not weight_store_path.exists(), index=False)
|
||||
train_store = new_storage_df('train', None)
|
||||
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
|
||||
|
||||
###########################################################
|
||||
# EPOCHS endet
|
||||
metanet = metanet.eval()
|
||||
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
# This returns ID-functions
|
||||
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
|
||||
for key, value in dict(counter_dict).items():
|
||||
step_log = dict(Epoch=int(EPOCH)+1, Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = step_log
|
||||
accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, EPOCH, final_model=True)
|
||||
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy.item())
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
for particle in metanet.particles:
|
||||
weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
|
||||
weight_store.loc[weight_store.shape[0]] = weight_log
|
||||
|
||||
# FLUSH to disk
|
||||
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
|
||||
|
||||
try:
|
||||
model_path = next(seed_path.glob(f'{FINAL_CHECKPOINT_NAME}'))
|
||||
except StopIteration:
|
||||
print('Model pattern did not trigger.')
|
||||
print(f'Search path was: {seed_path}:')
|
||||
print(f'Found Models are: {list(seed_path.rglob(".tp"))}')
|
||||
exit(1)
|
||||
|
||||
if plotting:
|
||||
highlight_fixpoints_vs_mnist_mean(model_path, plot_loader)
|
||||
|
||||
plot_training_result(df_store_path)
|
||||
plot_training_particle_types(df_store_path)
|
||||
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
run_particle_dropout_and_plot(model_path, valid_loader=plot_loader, metric_class=VAL_METRIC_CLASS)
|
||||
except (ValueError, NameError) as e:
|
||||
print(e)
|
||||
try:
|
||||
plot_network_connectivity_by_fixtype(model_path)
|
||||
except (ValueError, NameError) as e:
|
||||
print(e)
|
||||
|
||||
highlight_fixpoints_vs_mnist_mean(model_path, plot_loader)
|
||||
|
||||
try:
|
||||
for fixtype in FixTypes.all_types():
|
||||
plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=fixtype)
|
||||
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=fixtype)
|
||||
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
if robustnes:
|
||||
try:
|
||||
test_robustness(model_path, seeds=1)
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
|
||||
if 2 <= n_seeds <= sum(list(x.is_dir() for x in exp_path.iterdir())):
|
||||
if plotting:
|
||||
|
||||
plot_training_results_over_n_seeds(exp_path, metric_name=VAL_METRIC_NAME)
|
||||
|
||||
sanity_weight_swap(exp_path, plot_loader, VAL_METRIC_CLASS)
|
meta_task_exp_small.py (new file, 244 lines)
@@ -0,0 +1,244 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchmetrics
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from experiments.meta_task_small_utility import AddTaskDataset, train_task
|
||||
from experiments.robustness_tester import test_robustness
|
||||
from network import MetaNet
|
||||
from functionalities_test import test_for_fixpoints, FixTypes as ft
|
||||
from experiments.meta_task_utility import new_storage_df, flat_for_store, plot_training_result, \
|
||||
plot_training_particle_types, run_particle_dropout_and_plot, plot_network_connectivity_by_fixtype, \
|
||||
checkpoint_and_validate, plot_training_results_over_n_seeds, sanity_weight_swap, FINAL_CHECKPOINT_NAME
|
||||
from plot_3d_trajectories import plot_single_3d_trajectories_by_layer, plot_grouped_3d_trajectories_by_layer
|
||||
|
||||
WORKER = 0
|
||||
BATCHSIZE = 50
|
||||
EPOCH = 60
|
||||
VALIDATION_FRQ = 3
|
||||
VAL_METRIC_CLASS = torchmetrics.MeanAbsoluteError
|
||||
# noinspection PyProtectedMember
|
||||
VAL_METRIC_NAME = VAL_METRIC_CLASS()._get_name()
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
plot_loader = DataLoader(AddTaskDataset(), batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
training = False
|
||||
plotting = True
|
||||
robustness = True
|
||||
attack = False
|
||||
attack_ratio = 0.01
|
||||
melt = False
|
||||
melt_ratio = 0.01
|
||||
n_st = 200
|
||||
activation = None # nn.ReLU()
|
||||
|
||||
for weight_hidden_size in [3]:
|
||||
|
||||
weight_hidden_size = weight_hidden_size
|
||||
residual_skip = True
|
||||
n_seeds = 10
|
||||
depth = 3
|
||||
width = 3
|
||||
out = 1
|
||||
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
ac_str = f'_{activation.__class__.__name__}' if activation is not None else ''
|
||||
res_str = f'{"" if residual_skip else "_no_res"}'
|
||||
att_str = f'_att_{attack_ratio}' if attack else ''
|
||||
mlt_str = f'_mlt_{melt_ratio}' if melt else ''
|
||||
w_str = f'_w{width}wh{weight_hidden_size}d{depth}'
|
||||
# dr_str = f'{f"_dr_{dropout}" if dropout != 0 else ""}'
|
||||
|
||||
config_str = f'{res_str}{att_str}{ac_str}{mlt_str}{w_str}'
|
||||
exp_path = Path('output') / f'add_st_{EPOCH}{config_str}'
|
||||
|
||||
# if not training:
|
||||
# # noinspection PyRedeclaration
|
||||
# exp_path = Path('output') / f'add_st_{n_st}_{weight_hidden_size}'
|
||||
|
||||
for seed in range(n_seeds):
|
||||
seed_path = exp_path / str(seed)
|
||||
|
||||
df_store_path = seed_path / 'train_store.csv'
|
||||
weight_store_path = seed_path / 'weight_store.csv'
|
||||
srnn_parameters = dict()
|
||||
|
||||
valid_data = AddTaskDataset()
|
||||
vali_load = DataLoader(valid_data, batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
if training:
|
||||
# Check if files do exist on project location, warn and break.
|
||||
for path in [df_store_path, weight_store_path]:
|
||||
assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'
|
||||
|
||||
train_data = AddTaskDataset()
|
||||
train_load = DataLoader(train_data, batch_size=BATCHSIZE, shuffle=True,
|
||||
drop_last=True, num_workers=WORKER)
|
||||
|
||||
interface = np.prod(train_data[0][0].shape)
|
||||
metanet = MetaNet(interface, depth=depth, width=width, out=out,
|
||||
residual_skip=residual_skip, weight_hidden_size=weight_hidden_size,
|
||||
activation=activation
|
||||
).to(DEVICE)
|
||||
|
||||
loss_fn = nn.MSELoss()
|
||||
optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
|
||||
|
||||
train_store = new_storage_df('train', None)
|
||||
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
|
||||
|
||||
for epoch in tqdm(range(EPOCH), desc=f'Train - Epochs'):
|
||||
is_validation_epoch = epoch % VALIDATION_FRQ == 0
|
||||
metanet = metanet.train()
|
||||
|
||||
# Init metrics, even we do not need:
|
||||
metric = VAL_METRIC_CLASS()
|
||||
n_st_per_batch = max(1, (n_st // len(train_load)))
|
||||
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(train_load),
|
||||
total=len(train_load), desc='MetaNet Train - Batch'
|
||||
):
|
||||
# Self Train
|
||||
self_train_loss = metanet.combined_self_train(n_st_per_batch,
|
||||
reduction='mean', per_particle=False)
|
||||
# noinspection PyUnboundLocalVariable
|
||||
st_step_log = dict(Metric='Self Train Loss', Score=self_train_loss.item())
|
||||
st_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = st_step_log
|
||||
|
||||
# Attack
|
||||
if attack:
|
||||
after_attack_loss = metanet.make_particles_attack(attack_ratio)
|
||||
st_step_log = dict(Metric='After Attack Loss', Score=after_attack_loss.item())
|
||||
st_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = st_step_log
|
||||
|
||||
# Melt
|
||||
if melt:
|
||||
after_melt_loss = metanet.make_particles_melt(melt_ratio)
|
||||
st_step_log = dict(Metric='After Melt Loss', Score=after_melt_loss.item())
|
||||
st_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = st_step_log
|
||||
|
||||
# Task Train
|
||||
tsk_step_log, y_pred = train_task(metanet, optimizer, loss_fn, batch_x, batch_y)
|
||||
tsk_step_log.update(dict(Epoch=epoch, Batch=batch))
|
||||
train_store.loc[train_store.shape[0]] = tsk_step_log
|
||||
metric(y_pred.cpu(), batch_y.cpu())
|
||||
|
||||
if is_validation_epoch:
|
||||
metanet = metanet.eval()
|
||||
if metric.total.item():
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric=f'Train {VAL_METRIC_NAME}', Score=metric.compute().item())
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
|
||||
mae = checkpoint_and_validate(metanet, vali_load, seed_path, epoch,
|
||||
validation_metric=VAL_METRIC_CLASS).item()
|
||||
validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=mae)
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
|
||||
if is_validation_epoch:
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
# This returns ID-functions
|
||||
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
|
||||
counter_dict = dict(counter_dict)
|
||||
for key, value in counter_dict.items():
|
||||
val_step_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = val_step_log
|
||||
tqdm.write(f'Fixpoint Tester Results: {counter_dict}')
|
||||
|
||||
# FLUSH to disk
|
||||
if is_validation_epoch:
|
||||
for particle in metanet.particles:
|
||||
weight_log = (epoch, particle.name, *flat_for_store(particle.parameters()))
|
||||
weight_store.loc[weight_store.shape[0]] = weight_log
|
||||
train_store.to_csv(df_store_path, mode='a',
|
||||
header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a',
|
||||
header=not weight_store_path.exists(), index=False)
|
||||
train_store = new_storage_df('train', None)
|
||||
weight_store = new_storage_df('weights', metanet.particle_parameter_count)
|
||||
|
||||
###########################################################
|
||||
# EPOCHS endet
|
||||
metanet = metanet.eval()
|
||||
|
||||
counter_dict = defaultdict(lambda: 0)
|
||||
# This returns ID-functions
|
||||
_ = test_for_fixpoints(counter_dict, list(metanet.particles))
|
||||
for key, value in dict(counter_dict).items():
|
||||
step_log = dict(Epoch=int(EPOCH), Batch=BATCHSIZE, Metric=key, Score=value)
|
||||
train_store.loc[train_store.shape[0]] = step_log
|
||||
accuracy = checkpoint_and_validate(metanet, vali_load, seed_path, EPOCH, final_model=True,
|
||||
validation_metric=VAL_METRIC_CLASS)
|
||||
validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
|
||||
Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy.item())
|
||||
for particle in metanet.particles:
|
||||
weight_log = (EPOCH, particle.name, *(flat_for_store(particle.parameters())))
|
||||
weight_store.loc[weight_store.shape[0]] = weight_log
|
||||
|
||||
train_store.loc[train_store.shape[0]] = validation_log
|
||||
train_store.to_csv(df_store_path, mode='a', header=not df_store_path.exists(), index=False)
|
||||
weight_store.to_csv(weight_store_path, mode='a', header=not weight_store_path.exists(), index=False)
|
||||
if plotting:
|
||||
|
||||
plot_training_result(df_store_path, metric_name=VAL_METRIC_NAME)
|
||||
plot_training_particle_types(df_store_path)
|
||||
|
||||
try:
|
||||
model_path = next(seed_path.glob(f'*{FINAL_CHECKPOINT_NAME}'))
|
||||
except StopIteration:
|
||||
print('####################################################')
|
||||
print('ERROR: Model pattern did not trigger.')
|
||||
print(f'INFO: Search path was: {seed_path}:')
|
||||
print(f'INFO: Found Models are: {list(seed_path.rglob("*.tp"))}')
|
||||
print('####################################################')
|
||||
exit(1)
|
||||
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
run_particle_dropout_and_plot(model_path, valid_loader=plot_loader, metric_class=VAL_METRIC_CLASS)
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
try:
|
||||
plot_network_connectivity_by_fixtype(model_path)
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
try:
|
||||
tqdm.write('Trajectory plotting ...')
|
||||
plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.identity_func)
|
||||
plot_single_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
|
||||
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.identity_func)
|
||||
plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
|
||||
tqdm.write('Trajectory plotting Done')
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
if robustness:
|
||||
try:
|
||||
test_robustness(model_path, seeds=10)
|
||||
pass
|
||||
except ValueError as e:
|
||||
print('ERROR:', e)
|
||||
|
||||
if 2 <= n_seeds == sum(list(x.is_dir() for x in exp_path.iterdir())):
|
||||
if plotting:
|
||||
|
||||
plot_training_results_over_n_seeds(exp_path, metric_name=VAL_METRIC_NAME)
|
||||
|
||||
sanity_weight_swap(exp_path, plot_loader, VAL_METRIC_CLASS)
|
105
meta_task_sanity_exp.py
Normal file
@ -0,0 +1,105 @@
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import platform
|
||||
|
||||
import pandas as pd
|
||||
import torch.optim
|
||||
from matplotlib import pyplot as plt
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
from tqdm import trange, tqdm
|
||||
from tqdm.contrib import tenumerate
|
||||
|
||||
|
||||
if platform.node() == 'CarbonX':
|
||||
debug = True
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
print("@ Warning, Debugging Config@!!!!!! @")
|
||||
print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
|
||||
else:
|
||||
debug = False
|
||||
try:
|
||||
# noinspection PyUnboundLocalVariable
|
||||
if __package__ is None:
|
||||
DIR = Path(__file__).resolve().parent
|
||||
sys.path.insert(0, str(DIR.parent))
|
||||
__package__ = DIR.name
|
||||
else:
|
||||
DIR = None
|
||||
except NameError:
|
||||
DIR = None
|
||||
pass
|
||||
|
||||
import functionalities_test
|
||||
from network import Net
|
||||
|
||||
|
||||
class MultiplyByXTaskDataset(Dataset):
|
||||
def __init__(self, x=0.23, length=int(5e5)):
|
||||
super().__init__()
|
||||
self.length = length
|
||||
self.x = x
|
||||
self.prng = np.random.default_rng()
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, _):
|
||||
ab = self.prng.normal(size=(1,)).astype(np.float32)
|
||||
return ab, ab * self.x
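

def _dataset_usage_sketch():
    """Illustrative sketch only (length and batch size are assumed demo values, not part of
    the original experiment): the dataset yields (x, x * multiplication_target) pairs, so a
    network that solves the task multiplies its input by the target scalar."""
    demo_loader = DataLoader(MultiplyByXTaskDataset(x=0.23, length=100), batch_size=10)
    demo_x, demo_y = next(iter(demo_loader))
    assert torch.allclose(demo_y, demo_x * 0.23)
    return demo_x, demo_y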
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = Net(5, 4, 1, lr=0.004)
|
||||
multiplication_target = 0.03
|
||||
st_steps = 0
|
||||
|
||||
loss_fn = nn.MSELoss()
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
|
||||
|
||||
train_frame = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
|
||||
|
||||
dataset = MultiplyByXTaskDataset(x=multiplication_target, length=1000000)
|
||||
dataloader = DataLoader(dataset=dataset, batch_size=8000, num_workers=0)
|
||||
for epoch in trange(30):
|
||||
mean_batch_loss = []
|
||||
mean_self_train_loss = []
|
||||
|
||||
for batch, (batch_x, batch_y) in tenumerate(dataloader):
|
||||
self_train_loss, _ = net.self_train(1000 // 20, save_history=False)
|
||||
is_fixpoint = functionalities_test.is_identity_function(net)
|
||||
if not is_fixpoint:
|
||||
st_steps += 2
|
||||
|
||||
if is_fixpoint:
|
||||
tqdm.write(f'is fixpoint after st : {is_fixpoint}, first reached after st_steps: {st_steps}')
|
||||
tqdm.write(f'is fixpoint after tsk: {functionalities_test.is_identity_function(net)}')
|
||||
|
||||
#mean_batch_loss.append(loss.detach())
|
||||
mean_self_train_loss.append(self_train_loss.detach())
|
||||
|
||||
train_frame.loc[train_frame.shape[0]] = dict(Epoch=epoch, Batch=batch,
|
||||
Metric='Self Train Loss', Score=np.average(mean_self_train_loss))
|
||||
train_frame.loc[train_frame.shape[0]] = dict(Epoch=epoch, Batch=batch,
|
||||
Metric='Batch Loss', Score=np.average(mean_batch_loss))
|
||||
|
||||
counter = defaultdict(lambda: 0)
|
||||
functionalities_test.test_for_fixpoints(counter, nets=[net])
|
||||
print(dict(counter), self_train_loss)
|
||||
sanity = net(torch.Tensor([0,0,0,0,1])).detach()
|
||||
print(sanity)
|
||||
print(abs(sanity - multiplication_target))
|
||||
sns.lineplot(data=train_frame, x='Epoch', y='Score', hue='Metric')
|
||||
outpath = Path('output') / 'sanity' / 'test.png'
|
||||
outpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
plt.savefig(outpath)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
85
minimal_net_search.py
Normal file
@ -0,0 +1,85 @@
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import Flatten
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision.datasets import MNIST, CIFAR10
|
||||
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Grayscale
|
||||
import torchmetrics
|
||||
import pickle
|
||||
|
||||
from network import MetaNetCompareBaseline
|
||||
|
||||
WORKER = 0
|
||||
BATCHSIZE = 500
|
||||
EPOCH = 10
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
MNIST_TRANSFORM = Compose([ Resize((10, 10)), ToTensor(), Normalize((0.1307,), (0.3081,)), Flatten(start_dim=0)])
|
||||
CIFAR10_TRANSFORM = Compose([ Grayscale(num_output_channels=1), Resize((10, 10)), ToTensor(), Normalize((0.48,), (0.25,)), Flatten(start_dim=0)])
|
||||
|
||||
|
||||
def train_and_test(testnet, optimizer, loss_fn, trainset, testset):
    d_train = DataLoader(trainset, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
    d_test = DataLoader(testset, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
    testnet = testnet.to(DEVICE)

    # train
    for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epoch'):
        for batch, (batch_x, batch_y) in enumerate(d_train):
            optimizer.zero_grad()
            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
            y = testnet(batch_x)
            loss = loss_fn(y, batch_y)
            loss.backward()
            optimizer.step()

    # test
    testnet.eval()
    metric = torchmetrics.Accuracy()
    with torch.no_grad(), tqdm(desc='Test Batch: ') as pbar:
        for batch, (batch_x, batch_y) in tqdm(enumerate(d_test), total=len(d_test), desc='MetaNet Test - Batch'):
            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
            y = testnet(batch_x)
            acc = metric(y.cpu(), batch_y.cpu())
            pbar.set_postfix_str(f'Acc: {acc}')
            pbar.update()

    # metric on all test batches using the accumulated state
    acc = metric.compute()
    tqdm.write(f"Avg. accuracy on all data: {acc}")
    return acc
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.manual_seed(42)
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
mnist_train = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=True)
|
||||
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
|
||||
cifar10_train = CIFAR10(str(data_path), transform=CIFAR10_TRANSFORM, download=True, train=True)
|
||||
cifar10_test = CIFAR10(str(data_path), transform=CIFAR10_TRANSFORM, download=True, train=False)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
frame = pd.DataFrame(columns=['Dataset', 'Neurons', 'Layers', 'Parameters', 'Accuracy'])
|
||||
|
||||
for name, trainset, testset in [("MNIST",mnist_train,mnist_test), ("CIFAR10",cifar10_train,cifar10_test)]:
|
||||
best_acc = 0
|
||||
neuron_count = 0
|
||||
layer_count = 0
|
||||
|
||||
# find upper bound (in steps of 10, neurons/layer > 200 will start back from 10 with layers+1)
|
||||
while best_acc <= 0.95:
|
||||
neuron_count += 10
|
||||
if neuron_count >= 210:
|
||||
neuron_count = 10
|
||||
layer_count += 1
|
||||
net = MetaNetCompareBaseline(100, layer_count, neuron_count, out=10)
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
|
||||
acc = train_and_test(net, optimizer, loss_fn, trainset, testset)
|
||||
if acc > best_acc:
|
||||
best_acc = acc
|
||||
|
||||
num_params = sum(p.numel() for p in net._meta_layer_list.parameters())
|
||||
frame.loc[frame.shape[0]] = dict(Dataset=name, Neurons=neuron_count, Layers=layer_count, Parameters=num_params, Accuracy=acc)
|
||||
print(f"> {name}\t| {neuron_count} neurons\t| {layer_count} h.-layer(s)\t| {num_params} params\n")
|
||||
|
||||
print(frame)
|
||||
with open('min_net_search_df.pkl', 'wb') as pkl_file:
    pickle.dump(frame, pkl_file)
|
654
network.py
Normal file
@ -0,0 +1,654 @@
|
||||
# from __future__ import annotations
|
||||
import copy
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import optim, Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
def xavier_init(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight.data)
|
||||
|
||||
|
||||
def prng():
|
||||
return random.random()
|
||||
|
||||
|
||||
class FixTypes:
|
||||
|
||||
divergent = 'Divergend'
|
||||
fix_zero = 'All Zero'
|
||||
identity_func = 'Self-Replicator'
|
||||
fix_sec = 'Self-Replicator 2nd'
|
||||
other_func = 'Other'
|
||||
|
||||
@classmethod
|
||||
def all_types(cls):
|
||||
return [val for key, val in cls.__dict__.items() if isinstance(val, str) and not key.startswith('_')]
|
||||
|
||||
|
||||
class NetworkLevel:
|
||||
|
||||
all = 'All'
|
||||
layer = 'Layer'
|
||||
cell = 'Cell'
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def create_target_weights(input_weight_matrix: Tensor) -> Tensor:
|
||||
""" Outputting a tensor with the target weights. """
|
||||
|
||||
# What kind of slow shit is this?
|
||||
# target_weight_matrix = np.arange(len(input_weight_matrix)).reshape(len(input_weight_matrix), 1).astype("f")
|
||||
# for i in range(len(input_weight_matrix)):
|
||||
# target_weight_matrix[i] = input_weight_matrix[i][0]
|
||||
|
||||
# Fast and simple
|
||||
target_weights = input_weight_matrix[:, 0].detach().unsqueeze(-1)
|
||||
return target_weights
|
||||
|
||||
|
||||
@staticmethod
|
||||
def are_weights_diverged(network_weights):
|
||||
""" Testing if the weights are eiter converging to infinity or -infinity. """
|
||||
|
||||
# Slow and shitty:
|
||||
# for layer_id, layer in enumerate(network_weights):
|
||||
# for cell_id, cell in enumerate(layer):
|
||||
# for weight_id, weight in enumerate(cell):
|
||||
# if torch.isnan(weight):
|
||||
# return True
|
||||
# if torch.isinf(weight):
|
||||
# return True
|
||||
# return False
|
||||
# Fast and modern:
|
||||
return any(x.isnan().any() or x.isinf().any() for x in network_weights.parameters())
|
||||
|
||||
def apply_weights(self, new_weights: Tensor):
|
||||
""" Changing the weights of a network to new given values. """
|
||||
with torch.no_grad():
|
||||
i = 0
|
||||
for parameters in self.parameters():
|
||||
size = parameters.numel()
|
||||
parameters[:] = new_weights[i:i+size].view(parameters.shape)[:]
|
||||
i += size
|
||||
return self
|
||||
|
||||
def __init__(self, i_size: int, h_size: int, o_size: int, name=None, start_time=1, lr=0.004) -> None:
|
||||
super().__init__()
|
||||
self.start_time = start_time
|
||||
|
||||
self.name = name
|
||||
self.child_nets = []
|
||||
|
||||
self.input_size = i_size
|
||||
self.hidden_size = h_size
|
||||
self.out_size = o_size
|
||||
|
||||
self.no_weights = h_size * (i_size + h_size * (h_size - 1) + o_size)
|
||||
|
||||
""" Data saved in self.s_train_weights_history & self.s_application_weights_history is used for experiments. """
|
||||
self.s_train_weights_history = []
|
||||
self.s_application_weights_history = []
|
||||
self.loss_history = []
|
||||
self.trained = False
|
||||
self.number_trained = 0
|
||||
|
||||
self.is_fixpoint = FixTypes.other_func
|
||||
self.layers = nn.ModuleList(
|
||||
[nn.Linear(i_size, h_size, False),
|
||||
nn.Linear(h_size, h_size, False),
|
||||
nn.Linear(h_size, o_size, False)]
|
||||
)
|
||||
|
||||
self._weight_pos_enc_and_mask = None
|
||||
self.apply(xavier_init)
|
||||
self.optimizer = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9)
|
||||
|
||||
@property
|
||||
def _weight_pos_enc(self):
|
||||
if self._weight_pos_enc_and_mask is None:
|
||||
d = next(self.parameters()).device
|
||||
weight_matrix = []
|
||||
with torch.no_grad():
|
||||
for layer_id, layer in enumerate(self.layers):
|
||||
x = next(layer.parameters())
|
||||
weight_matrix.append(
|
||||
torch.cat(
|
||||
(
|
||||
# Those are the weights
|
||||
torch.full((x.numel(), 1), 0, device=d, requires_grad=False),
|
||||
# Layer enumeration
|
||||
torch.full((x.numel(), 1), layer_id, device=d, requires_grad=False),
|
||||
# Cell Enumeration
|
||||
torch.arange(layer.out_features, device=d, requires_grad=False
|
||||
).repeat_interleave(layer.in_features).view(-1, 1),
|
||||
# Weight Enumeration within the Cells
|
||||
torch.arange(layer.in_features, device=d, requires_grad=False
|
||||
).view(-1, 1).repeat(layer.out_features, 1),
|
||||
*(torch.full((x.numel(), 1), 0, device=d, requires_grad=False
|
||||
) for _ in range(self.input_size-4))
|
||||
), dim=1)
|
||||
)
|
||||
# Finalize
|
||||
weight_matrix = torch.cat(weight_matrix).float()
|
||||
|
||||
# Normalize 1,2,3 column of dim 1
|
||||
last_pos_idx = self.input_size - 4
|
||||
max_per_col, _ = weight_matrix[:, 1:-last_pos_idx].max(keepdim=True, dim=0)
|
||||
max_per_col += 1e-8
|
||||
weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / max_per_col) + 1e-8
|
||||
|
||||
# computations
|
||||
# create a mask where pos is 0 if it is to be replaced
|
||||
mask = torch.ones_like(weight_matrix, requires_grad=False)
|
||||
mask[:, 0] = 0
|
||||
|
||||
self._weight_pos_enc_and_mask = weight_matrix.detach(), mask.detach()
|
||||
return self._weight_pos_enc_and_mask
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
def normalize(self, value, norm):
|
||||
raise NotImplementedError
|
||||
# FIXME, This is bullshit, the code does not do what the docstring explains
|
||||
# Obsolete now
|
||||
""" Normalizing the values >= 1 and adding pow(10, -8) to the values equal to 0 """
|
||||
|
||||
if norm > 1:
|
||||
return float(value) / float(norm)
|
||||
else:
|
||||
return float(value)
|
||||
|
||||
def input_weight_matrix(self) -> Tensor:
|
||||
""" Calculating the input tensor formed from the weights of the net """
|
||||
with torch.no_grad():
|
||||
weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
|
||||
pos_enc, mask = self._weight_pos_enc
|
||||
weight_matrix = pos_enc * mask + weight_matrix.expand(-1, pos_enc.shape[-1]) * (1 - mask)
|
||||
return weight_matrix.detach()
|
||||
|
||||
def target_weight_matrix(self) -> Tensor:
|
||||
weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
|
||||
return weight_matrix
|
||||
|
||||
def self_train(self,
|
||||
training_steps: int,
|
||||
log_step_size: int = 0,
|
||||
save_history: bool = False,
|
||||
reduction: str = 'mean'
|
||||
) -> (Tensor, list):
|
||||
""" Training a network to predict its own weights in order to self-replicate. """
|
||||
|
||||
for training_step in range(training_steps):
|
||||
self.number_trained += 1
|
||||
self.optimizer.zero_grad()
|
||||
input_data = self.input_weight_matrix()
|
||||
target_data = self.create_target_weights(input_data)
|
||||
output = self(input_data)
|
||||
loss = F.mse_loss(output, target_data, reduction=reduction)
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
if save_history:
|
||||
# Saving the history of the weights after a certain amount of steps (aka log_step_size) for research.
|
||||
# If it is a soup/mixed env. save weights only at the end of all training steps (aka a soup/mixed epoch)
|
||||
if "soup" not in self.name and "mixed" not in self.name:
|
||||
weights = self.create_target_weights(self.input_weight_matrix())
|
||||
# If self-training steps are lower than 10, then append weight history after each ST step.
|
||||
if self.number_trained < 10:
|
||||
self.s_train_weights_history.append(weights.T.detach().numpy())
|
||||
self.loss_history.append(loss.item())
|
||||
else:
|
||||
if log_step_size != 0:
|
||||
if self.number_trained % log_step_size == 0:
|
||||
self.s_train_weights_history.append(weights.T.detach().numpy())
|
||||
self.loss_history.append(loss.item())
|
||||
|
||||
# Saving weights only at the end of a soup/mixed exp. epoch.
|
||||
if save_history:
|
||||
if "soup" in self.name or "mixed" in self.name:
|
||||
weights = self.create_target_weights(self.input_weight_matrix())
|
||||
self.s_train_weights_history.append(weights.T.detach().numpy())
|
||||
self.loss_history.append(loss.item())
|
||||
|
||||
self.trained = True
|
||||
return loss.detach(), self.loss_history
|
||||
|
||||
def self_application(self, SA_steps: int, log_step_size: Union[int, None] = None):
|
||||
""" Inputting the weights of a network to itself for a number of steps, without backpropagation. """
|
||||
|
||||
for i in range(SA_steps):
|
||||
output = self(self.input_weight_matrix())
|
||||
|
||||
# Saving the weights history after a certain amount of steps (aka log_step_size) for research purposes.
|
||||
# If self-application steps are lower than 10, then append weight history after each SA step.
|
||||
if SA_steps < 10:
|
||||
weights = self.create_target_weights(self.input_weight_matrix())
|
||||
self.s_application_weights_history.append(weights.T.detach().numpy())
|
||||
else:
|
||||
weights = self.create_target_weights(self.input_weight_matrix())
|
||||
if i % log_step_size == 0:
|
||||
self.s_application_weights_history.append(weights.T.detach().numpy())
|
||||
|
||||
""" See after how many steps of SA is the output not changing anymore: """
|
||||
# print(f"Self-app. step {i+1}: {Experiment.changing_rate(output2, output)}")
|
||||
|
||||
_ = self.apply_weights(output)
|
||||
|
||||
return self
|
||||
|
||||
def attack(self, other_net):
|
||||
other_net_weights = other_net.input_weight_matrix()
|
||||
my_evaluation = self(other_net_weights)
|
||||
return other_net.apply_weights(my_evaluation)
|
||||
|
||||
def melt(self, other_net):
|
||||
try:
|
||||
melted_name = self.name + other_net.name
|
||||
except AttributeError:
|
||||
melted_name = None
|
||||
melted_weights = self.create_target_weights(other_net.input_weight_matrix())
|
||||
self_weights = self.create_target_weights(self.input_weight_matrix())
|
||||
weight_indxs = list(range(len(self_weights)))
|
||||
random.shuffle(weight_indxs)
|
||||
for weight_idx in weight_indxs[:len(melted_weights) // 2]:
|
||||
melted_weights[weight_idx] = self_weights[weight_idx]
|
||||
melted_net = Net(i_size=self.input_size, h_size=self.hidden_size, o_size=self.out_size, name=melted_name)
|
||||
melted_net.apply_weights(melted_weights)
|
||||
return melted_net
|
||||
|
||||
def apply_noise(self, noise_size: float):
|
||||
""" Changing the weights of a network to values + noise """
|
||||
for layer_id, layer_name in enumerate(self.state_dict()):
|
||||
for line_id, line_values in enumerate(self.state_dict()[layer_name]):
|
||||
for weight_id, weight_value in enumerate(self.state_dict()[layer_name][line_id]):
|
||||
# network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
|
||||
if prng() < 0.5:
|
||||
self.state_dict()[layer_name][line_id][weight_id] = weight_value + noise_size * prng()
|
||||
else:
|
||||
self.state_dict()[layer_name][line_id][weight_id] = weight_value - noise_size * prng()
|
||||
|
||||
return self
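

def _net_self_train_sketch():
    """Illustrative sketch only (network size and step count are assumed demo values): the
    self-train input holds one row per weight, containing the weight value plus its
    normalised positional encoding, and self_train() regresses the net's output onto the
    weight column of that matrix."""
    demo_net = Net(i_size=5, h_size=2, o_size=1, name='sketch_net', lr=0.004)
    encoded = demo_net.input_weight_matrix()
    # one row per weight, one column per input dimension of the particle net
    assert encoded.shape[1] == demo_net.input_size
    demo_loss, _ = demo_net.self_train(10, save_history=False)
    return demo_loss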
|
||||
|
||||
|
||||
class SecondaryNet(Net):
|
||||
|
||||
def self_train(self, training_steps: int, log_step_size: int, learning_rate: float) -> (pd.DataFrame, Tensor, list):
|
||||
""" Training a network to predict its own weights in order to self-replicate. """
|
||||
|
||||
optimizer = optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9)
|
||||
df = pd.DataFrame(columns=['step', 'loss', 'first_to_target_loss', 'second_to_target_loss', 'second_to_first_loss'])
|
||||
is_diverged = False
|
||||
for training_step in range(training_steps):
|
||||
self.number_trained += 1
|
||||
optimizer.zero_grad()
|
||||
input_data = self.input_weight_matrix()
|
||||
target_data = self.create_target_weights(input_data)
|
||||
|
||||
intermediate_output = self(input_data)
|
||||
second_input = copy.deepcopy(input_data)
|
||||
second_input[:, 0] = intermediate_output.squeeze()
|
||||
|
||||
output = self(second_input)
|
||||
second_to_target_loss = F.mse_loss(output, target_data)
|
||||
first_to_target_loss = F.mse_loss(intermediate_output, target_data * -1)
|
||||
second_to_first_loss = F.mse_loss(intermediate_output, output)
|
||||
if any([torch.isnan(x) or torch.isinf(x) for x in [second_to_first_loss, first_to_target_loss, second_to_target_loss]]):
|
||||
print('is nan')
|
||||
is_diverged = True
|
||||
break
|
||||
|
||||
loss = second_to_target_loss + first_to_target_loss
|
||||
df.loc[df.shape[0]] = [df.shape[0], loss.detach().numpy().item(),
|
||||
first_to_target_loss.detach().numpy().item(),
|
||||
second_to_target_loss.detach().numpy().item(),
|
||||
second_to_first_loss.detach().numpy().item()]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
self.trained = True
|
||||
return df, is_diverged
|
||||
|
||||
|
||||
class MetaCell(nn.Module):
|
||||
def __init__(self, name, interface, weight_interface=5, weight_hidden_size=2, weight_output_size=1):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.interface = interface
|
||||
self.weight_interface = weight_interface
|
||||
self.net_hidden_size = weight_hidden_size
|
||||
self.net_ouput_size = weight_output_size
|
||||
self.meta_weight_list = nn.ModuleList(
|
||||
[Net(self.weight_interface, self.net_hidden_size,
|
||||
self.net_ouput_size, name=f'{self.name}_W{weight_idx}'
|
||||
) for weight_idx in range(self.interface)]
|
||||
)
|
||||
self.__bed_mask = None
|
||||
|
||||
@property
|
||||
def _bed_mask(self):
|
||||
if self.__bed_mask is None:
|
||||
d = next(self.parameters()).device
|
||||
embedding = torch.zeros(1, self.weight_interface, device=d, requires_grad=False)
|
||||
|
||||
# computations
|
||||
# create a mask where pos is 0 if it is to be replaced
|
||||
mask = torch.ones_like(embedding, requires_grad=False, device=d)
|
||||
mask[:, -1] = 0
|
||||
|
||||
self.__bed_mask = embedding, mask
|
||||
return self.__bed_mask
|
||||
|
||||
def forward(self, x):
|
||||
embedding, mask = self._bed_mask
|
||||
expanded_mask = mask.expand(*x.shape, embedding.shape[-1])
|
||||
embedding = embedding.expand(*x.shape, embedding.shape[-1])
|
||||
# embedding = embedding.repeat(*x.shape, 1)
|
||||
|
||||
# Row-wise
|
||||
# xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1]).swapdims(0, 1)
|
||||
# Column-wise
|
||||
xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1])
|
||||
xs = embedding * expanded_mask + xs * (1 - expanded_mask)
|
||||
# ToDo Speed this up!
|
||||
tensor = torch.hstack([meta_weight(xs[:, idx, :]) for idx, meta_weight in enumerate(self.meta_weight_list)])
|
||||
|
||||
tensor = torch.sum(tensor, dim=-1, keepdim=True)
|
||||
return tensor
|
||||
|
||||
@property
|
||||
def particles(self):
|
||||
return (net for net in self.meta_weight_list)
|
||||
|
||||
def make_particles_attack(self, ratio=0.01):
|
||||
random_particle_list = list(self.particles)
|
||||
random.shuffle(random_particle_list)
|
||||
for idx, particle in enumerate(self.particles):
|
||||
if random.random() <= ratio:
|
||||
other = random_particle_list[idx]
|
||||
if other != particle:
|
||||
particle.attack(other)
|
||||
|
||||
def make_particles_melt(self, ratio=0.01):
|
||||
random_particle_list = list(self.particles)
|
||||
random.shuffle(random_particle_list)
|
||||
for idx, particle in enumerate(self.particles):
|
||||
if random.random() <= ratio:
|
||||
other = random_particle_list[idx]
|
||||
if other != particle:
|
||||
new_particle = particle.melt(other)
|
||||
particle.apply_weights(new_particle.target_weight_matrix())
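

def _meta_cell_shape_sketch():
    """Illustrative sketch only (batch size and interface are assumed demo values): a
    MetaCell routes each of its `interface` inputs through its own tiny weight-net and sums
    the results, mapping a (batch, interface) tensor to (batch, 1)."""
    demo_cell = MetaCell(name='sketch_cell', interface=3)
    demo_out = demo_cell(torch.ones(2, 3))
    assert demo_out.shape == (2, 1)
    return demo_out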
|
||||
|
||||
|
||||
class MetaLayer(nn.Module):
|
||||
def __init__(self, name, interface=4, width=4, # residual_skip=False,
|
||||
weight_interface=5, weight_hidden_size=2, weight_output_size=1):
|
||||
super().__init__()
|
||||
self.residual_skip = False
|
||||
self.name = name
|
||||
self.interface = interface
|
||||
self.width = width
|
||||
|
||||
self.meta_cell_list = nn.ModuleList([
|
||||
MetaCell(name=f'{self.name}_C{cell_idx}',
|
||||
interface=interface,
|
||||
weight_interface=weight_interface, weight_hidden_size=weight_hidden_size,
|
||||
weight_output_size=weight_output_size,
|
||||
) for cell_idx in range(self.width)]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
cell_results = []
|
||||
for metacell in self.meta_cell_list:
|
||||
cell_results.append(metacell(x))
|
||||
tensor = torch.hstack(cell_results)
|
||||
if self.residual_skip and x.shape == tensor.shape:
|
||||
tensor += x
|
||||
return tensor
|
||||
|
||||
@property
|
||||
def particles(self):
|
||||
return (weight for metacell in self.meta_cell_list for weight in metacell.particles)
|
||||
|
||||
|
||||
class MetaNet(nn.Module):
|
||||
|
||||
def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True, dropout=0,
|
||||
weight_interface=5, weight_hidden_size=2, weight_output_size=1,):
|
||||
super().__init__()
|
||||
self.residual_skip = residual_skip
|
||||
self.dropout = dropout
|
||||
self.activation = activation
|
||||
self.out = out
|
||||
self.interface = interface
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self.weight_interface = weight_interface
|
||||
self.weight_hidden_size = weight_hidden_size
|
||||
self.weight_output_size = weight_output_size
|
||||
self._meta_layer_first = MetaLayer(name=f'L{0}',
|
||||
interface=self.interface,
|
||||
width=self.width,
|
||||
weight_interface=weight_interface,
|
||||
weight_hidden_size=weight_hidden_size,
|
||||
weight_output_size=weight_output_size)
|
||||
|
||||
self._meta_layer_list = nn.ModuleList([MetaLayer(name=f'L{layer_idx + 1}',
|
||||
interface=self.width, width=self.width,
|
||||
weight_interface=weight_interface,
|
||||
weight_hidden_size=weight_hidden_size,
|
||||
weight_output_size=weight_output_size,
|
||||
|
||||
) for layer_idx in range(self.depth - 2)]
|
||||
)
|
||||
self._meta_layer_last = MetaLayer(name=f'L{len(self._meta_layer_list) + 1}',
|
||||
interface=self.width, width=self.out,
|
||||
weight_interface=weight_interface,
|
||||
weight_hidden_size=weight_hidden_size,
|
||||
weight_output_size=weight_output_size,
|
||||
)
|
||||
self.dropout_layer = nn.Dropout(p=self.dropout)
|
||||
|
||||
def replace_with_zero(self, ident_key):
|
||||
replaced_particles = 0
|
||||
for particle in self.particles:
|
||||
if particle.is_fixpoint == ident_key:
|
||||
particle.load_state_dict(
|
||||
{key: torch.zeros_like(state) for key, state in particle.state_dict().items()}
|
||||
)
|
||||
replaced_particles += 1
|
||||
if replaced_particles != 0:
|
||||
tqdm.write(f'Particle Parameters replaced: {str(replaced_particles)}')
|
||||
return self
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self._meta_layer_first(x)
|
||||
residual = None
|
||||
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
|
||||
if idx % 2 == 1 and self.residual_skip:
|
||||
# if self.residual_skip:
|
||||
residual = tensor
|
||||
tensor = meta_layer(tensor)
|
||||
if idx % 2 == 0 and self.residual_skip:
|
||||
# if self.residual_skip:
|
||||
tensor = tensor + residual
|
||||
tensor = self._meta_layer_last(tensor)
|
||||
return tensor
|
||||
|
||||
@property
|
||||
def particles(self):
|
||||
return (cell for metalayer in self.all_layers for cell in metalayer.particles)
|
||||
|
||||
def combined_self_train(self, n_st_steps, reduction='mean', per_particle=True, alpha=1):
|
||||
|
||||
losses = []
|
||||
|
||||
if per_particle:
|
||||
for particle in self.particles:
|
||||
loss, _ = particle.self_train(n_st_steps, reduction=reduction)
|
||||
losses.append(loss.detach())
|
||||
else:
|
||||
optim = torch.optim.SGD(self.parameters(), lr=0.004, momentum=0.9)
|
||||
for _ in range(n_st_steps):
|
||||
optim.zero_grad()
|
||||
train_losses = []
|
||||
for particle in self.particles:
|
||||
# Integrate optimizer and backward pass jointly over all particles
|
||||
input_data = particle.input_weight_matrix()
|
||||
target_data = particle.create_target_weights(input_data)
|
||||
output = particle(input_data)
|
||||
loss = F.mse_loss(output, target_data, reduction=reduction)
|
||||
|
||||
train_losses.append(loss)
|
||||
train_losses = torch.hstack(train_losses).sum(dim=-1, keepdim=True)
|
||||
if alpha not in [0, 1]:
|
||||
train_losses *= alpha
|
||||
train_losses.backward()
|
||||
optim.step()
|
||||
losses.append(train_losses.detach())
|
||||
losses = torch.hstack(losses).sum(dim=-1, keepdim=True)
|
||||
return losses
|
||||
|
||||
@property
|
||||
def hyperparams(self):
|
||||
return {key: val for key, val in self.__dict__.items() if not key.startswith('_')}
|
||||
|
||||
def replace_particles(self, particle_weights_list):
|
||||
for layer in self.all_layers:
|
||||
for cell in layer.meta_cell_list:
|
||||
# Individual replacement on cell lvl
|
||||
for weight in cell.meta_weight_list:
|
||||
weight.apply_weights(next(particle_weights_list).detach())
|
||||
return self
|
||||
|
||||
def make_particles_attack(self, ratio=0.01, level=NetworkLevel.cell, reduction='mean'):
|
||||
if level == NetworkLevel.all:
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
elif level == NetworkLevel.layer:
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
elif level == NetworkLevel.cell:
|
||||
for layer in self.all_layers:
|
||||
for cell in layer.meta_cell_list:
|
||||
cell.make_particles_attack(ratio)
|
||||
pass
|
||||
|
||||
else:
|
||||
raise ValueError(f'level has to be any of: {[NetworkLevel.all, NetworkLevel.layer, NetworkLevel.cell]}')
|
||||
# Self Train Loss after attack:
|
||||
with torch.no_grad():
|
||||
sa_losses = []
|
||||
for particle in self.particles:
|
||||
# Evaluate each particle's self-train loss (no optimizer step)
|
||||
input_data = particle.input_weight_matrix()
|
||||
target_data = particle.create_target_weights(input_data)
|
||||
output = particle(input_data)
|
||||
loss = F.mse_loss(output, target_data, reduction=reduction)
|
||||
|
||||
sa_losses.append(loss)
|
||||
after_attack_loss = torch.hstack(sa_losses).sum(dim=-1, keepdim=True)
|
||||
return after_attack_loss
|
||||
|
||||
def make_particles_melt(self, ratio=0.01, level=NetworkLevel.cell, reduction='mean'):
|
||||
if level == NetworkLevel.all:
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
elif level == NetworkLevel.layer:
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
elif level == NetworkLevel.cell:
|
||||
for layer in self.all_layers:
|
||||
for cell in layer.meta_cell_list:
|
||||
cell.make_particles_melt(ratio)
|
||||
pass
|
||||
|
||||
else:
|
||||
raise ValueError(f'level has to be any of: {[NetworkLevel.all, NetworkLevel.layer, NetworkLevel.cell]}')
|
||||
# Self Train Loss after melt:
|
||||
with torch.no_grad():
|
||||
sa_losses = []
|
||||
for particle in self.particles:
|
||||
# Evaluate each particle's self-train loss (no optimizer step)
|
||||
input_data = particle.input_weight_matrix()
|
||||
target_data = particle.create_target_weights(input_data)
|
||||
output = particle(input_data)
|
||||
loss = F.mse_loss(output, target_data, reduction=reduction)
|
||||
|
||||
sa_losses.append(loss)
|
||||
after_melt_loss = torch.hstack(sa_losses).sum(dim=-1, keepdim=True)
|
||||
return after_melt_loss
|
||||
|
||||
|
||||
@property
|
||||
def all_layers(self):
|
||||
return (x for x in (self._meta_layer_first, *self._meta_layer_list, self._meta_layer_last))
|
||||
|
||||
@property
|
||||
def particle_parameter_count(self):
|
||||
return sum(p.numel() for p in next(self.particles).parameters())
|
||||
|
||||
def count_fixpoints(self, fix_type=FixTypes.identity_func):
|
||||
return sum(x.is_fixpoint == fix_type for x in self.particles)
|
||||
|
||||
def reset_diverged_particles(self):
|
||||
for particle in self.particles:
|
||||
if particle.is_fixpoint == FixTypes.divergent:
|
||||
particle.apply(xavier_init)
|
||||
|
||||
|
||||
class MetaNetCompareBaseline(nn.Module):
|
||||
|
||||
def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True):
|
||||
super().__init__()
|
||||
self.residual_skip = residual_skip
|
||||
self.activation = activation
|
||||
self.out = out
|
||||
self.interface = interface
|
||||
self.width = width
|
||||
self.depth = depth
|
||||
self._first_layer = nn.Linear(self.interface, self.width, bias=False)
|
||||
self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False
|
||||
) for _ in range(self.depth - 2)])
|
||||
self._last_layer = nn.Linear(self.width, self.out, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
tensor = self._first_layer(x)
|
||||
residual = None
|
||||
for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
|
||||
if idx % 2 == 1 and self.residual_skip:
|
||||
# if self.residual_skip:
|
||||
residual = tensor
|
||||
tensor = meta_layer(tensor)
|
||||
if idx % 2 == 0 and self.residual_skip:
|
||||
# if self.residual_skip:
|
||||
tensor = tensor + residual
|
||||
tensor = self._last_layer(tensor)
|
||||
return tensor
|
||||
|
||||
@property
|
||||
def all_layers(self):
|
||||
return (x for x in (self._first_layer, *self._meta_layer_list, self._last_layer))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
metanet = MetaNet(interface=3, depth=5, width=3, out=1, residual_skip=True)
|
||||
next(metanet.particles).input_weight_matrix()
|
||||
metanet(torch.hstack([torch.full((2, 1), 1.0) for _ in range(metanet.interface)]))
|
||||
a = metanet.particles
|
||||
print('Test')
|
||||
print('Test')
|
||||
print('Test')
|
||||
print('Test')
|
||||
print('Test')
|
119
plot_3d_trajectories.py
Normal file
@ -0,0 +1,119 @@
|
||||
import pandas as pd
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from network import FixTypes
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
|
||||
def plot_single_3d_trajectories_by_layer(model_path, all_weights_path, status_type: FixTypes):
|
||||
"""
|
||||
This plots one PCA for every net (over its n epochs) as one trajectory
|
||||
and then combines all of them in one plot
|
||||
"""
|
||||
model = torch.load(model_path, map_location=torch.device('cpu')).eval()
|
||||
all_weights = pd.read_csv(all_weights_path, index_col=False)
|
||||
save_path = model_path.parent / 'trajec_plots'
|
||||
|
||||
all_epochs = all_weights.Epoch.unique()
|
||||
pca = PCA(n_components=2, whiten=True)
|
||||
save_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
for layer_idx, model_layer in enumerate(model.all_layers):
|
||||
|
||||
fixpoint_statuses = [net.is_fixpoint for net in model_layer.particles]
|
||||
num_status_of_layer = sum([net.is_fixpoint == status_type for net in model_layer.particles])
|
||||
if num_status_of_layer != 0:
|
||||
layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
|
||||
weight_batches = [np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
|
||||
for name in layer.Weight.unique()]
|
||||
plt.clf()
|
||||
fig = plt.figure()
|
||||
ax = plt.axes(projection='3d')
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(12)
|
||||
plt.tight_layout()
|
||||
|
||||
for weights_of_net, status in zip(weight_batches, fixpoint_statuses):
|
||||
if status == status_type:
|
||||
pca.fit(weights_of_net)
|
||||
transformed_trajectory = pca.transform(weights_of_net)
|
||||
xdata = transformed_trajectory[:, 0]
|
||||
ydata = transformed_trajectory[:, 1]
|
||||
zdata = all_epochs
|
||||
ax.plot3D(xdata, ydata, zdata)
|
||||
ax.scatter(xdata, ydata, zdata, s=7)
|
||||
|
||||
ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
|
||||
ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
|
||||
ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
|
||||
ax.set_zlabel('Epochs', fontsize=30, rotation=0)
|
||||
file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}.png"
|
||||
plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
|
||||
plt.clf()
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def plot_grouped_3d_trajectories_by_layer(model_path, all_weights_path, status_type: FixTypes):
|
||||
""" This computes the PCA over all the net-weights at once and then plots that."""
|
||||
|
||||
model = torch.load(model_path, map_location=torch.device('cpu')).eval()
|
||||
save_path = model_path.parent / 'trajec_plots'
|
||||
all_weights = pd.read_csv(all_weights_path, index_col=False)
|
||||
all_epochs = all_weights.Epoch.unique()
|
||||
pca = PCA(n_components=2, whiten=True)
|
||||
save_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
for layer_idx, model_layer in enumerate(model.all_layers):
|
||||
|
||||
fixpoint_statuses = [net.is_fixpoint for net in model_layer.particles]
|
||||
num_status_of_layer = sum([net.is_fixpoint == status_type for net in model_layer.particles])
|
||||
if num_status_of_layer != 0:
|
||||
layer = all_weights[all_weights.Weight.str.startswith(f"L{layer_idx}")]
|
||||
weight_batches = np.vstack([np.array(layer[layer.Weight == name].values.tolist())[:, 2:]
|
||||
for name in layer.Weight.unique()])
|
||||
plt.clf()
|
||||
fig = plt.figure()
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(12)
|
||||
ax = plt.axes(projection='3d')
|
||||
plt.tight_layout()
|
||||
|
||||
pca.fit(weight_batches)
|
||||
w_transformed = pca.transform(weight_batches)
|
||||
for transformed_trajectory, status in zip(
|
||||
np.split(w_transformed, len(layer.Weight.unique())), fixpoint_statuses):
|
||||
if status == status_type:
|
||||
xdata = transformed_trajectory[:, 0]
|
||||
ydata = transformed_trajectory[:, 1]
|
||||
zdata = all_epochs
|
||||
ax.plot3D(xdata, ydata, zdata)
|
||||
ax.scatter(xdata, ydata, zdata, s=7)
|
||||
|
||||
ax.set_title(f"Layer {layer_idx}: {num_status_of_layer}-{status_type}", fontsize=20)
|
||||
ax.set_xlabel('PCA Transformed x-axis', fontsize=20)
|
||||
ax.set_ylabel('PCA Transformed y-axis', fontsize=20)
|
||||
ax.set_zlabel('Epochs', fontsize=30, rotation=0)
|
||||
file_path = save_path / f"layer_{layer_idx}_{num_status_of_layer}_{status_type}_grouped.png"
|
||||
plt.savefig(file_path, bbox_inches="tight", dpi=300, format="png")
|
||||
plt.clf()
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise (NotImplementedError('Get out of here'))
|
||||
"""
|
||||
weight_path = Path("weight_store.csv")
|
||||
model_path = Path("trained_model_ckpt_e100.tp")
|
||||
save_path = Path("figures/3d_trajectories/")
|
||||
|
||||
weight_df = pd.read_csv(weight_path)
|
||||
weight_df = weight_df.drop_duplicates(subset=['Weight','Epoch'])
|
||||
model = torch.load(model_path, map_location=torch.device('cpu'))
|
||||
|
||||
plot_single_3d_trajectories_by_layer(model, weight_df, save_path, status_type=FixTypes.identity_func)
|
||||
plot_single_3d_trajectories_by_layer(model, weight_df, save_path, status_type=FixTypes.other_func)
|
||||
plot_grouped_3d_trajectories_by_layer(model, weight_df, save_path, FixTypes.identity_func)
|
||||
#plot_grouped_3d_trajectories_by_layer(model, weight_df, save_path, FixTypes.other_func)
|
||||
"""
|
14
requirements.txt
Normal file
@ -0,0 +1,14 @@
torch~=1.8.1+cpu
tqdm~=4.60.0
numpy~=1.20.3
matplotlib~=3.4.2
sklearn~=0.0
scipy
tabulate~=0.8.9

scikit-learn~=0.24.2
pandas~=1.2.4
seaborn~=0.11.1
future~=0.18.2
torchmetrics~=0.7.0
torchvision~=0.9.1+cpu
22
sanity_check_particle_weight_swap.py
Normal file
@ -0,0 +1,22 @@
import torch

from network import MetaNet
from sparse_net import SparseNetwork


if __name__ == '__main__':
    dense_metanet = MetaNet(30, depth=5, width=6, out=10, residual_skip=True,
                            weight_hidden_size=3, )
    sparse_metanet = SparseNetwork(30, depth=5, width=6, out=10, residual_skip=True,
                                   weight_hidden_size=3,)

    particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles]

    # Transfer weights from the dense to the sparse network
    sparse_metanet = sparse_metanet.replace_weights_by_particles(dense_metanet.particles)

    # Transfer weights back from the sparse to the dense network
    dense_metanet = dense_metanet.replace_particles(sparse_metanet.particle_weights)
    new_particles = [torch.cat([x.view(-1) for x in x.parameters()]) for x in dense_metanet.particles]

    print(f' Particles are same: {all([(x==y).all() for x,y in zip(particles, new_particles) ])}')
|
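# Note (added, not asserted in the original script): the dense -> sparse -> dense round trip
# is expected to leave every particle's weights unchanged, so the check above should print True.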
71
sanity_check_weights.py
Normal file
@ -0,0 +1,71 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import Flatten
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision.datasets import MNIST, CIFAR10
|
||||
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Grayscale
|
||||
import torchmetrics
|
||||
|
||||
from functionalities_test import epsilon_error_margin as e
|
||||
from network import MetaNet, MetaNetCompareBaseline
|
||||
|
||||
|
||||
def extract_weights_from_model(model: MetaNet) -> dict:
|
||||
inpt = torch.zeros(5, device=next(model.parameters()).device, dtype=torch.float)
|
||||
inpt[-1] = 1
|
||||
|
||||
weights = defaultdict(list)
|
||||
layers = [layer.particles for layer in model.all_layers]
|
||||
for i, layer in enumerate(layers):
|
||||
for net in layer:
|
||||
weights[i].append(net(inpt).detach())
|
||||
return dict(weights)
|
||||
|
||||
|
||||
def test_weights_as_model(meta_net, new_weights, data, metric_class=torchmetrics.Accuracy):
|
||||
meta_net_device = next(meta_net.parameters()).device
|
||||
transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth,
|
||||
width=meta_net.width, out=meta_net.out,
|
||||
residual_skip=meta_net.residual_skip).to(meta_net_device)
|
||||
with torch.no_grad():
|
||||
new_weight_values = list(new_weights.values())
|
||||
old_parameters = list(transfer_net.parameters())
|
||||
assert len(new_weight_values) == len(old_parameters)
|
||||
for weights, parameters in zip(new_weights.values(), transfer_net.parameters()):
|
||||
parameters[:] = torch.Tensor(weights).view(parameters.shape)[:]
|
||||
|
||||
transfer_net.eval()
|
||||
results = dict()
|
||||
for net in [meta_net, transfer_net]:
|
||||
net.eval()
|
||||
metric = metric_class()
|
||||
for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='Test Batch: '):
|
||||
y = net(batch_x.to(meta_net_device))
|
||||
metric(y.cpu(), batch_y.cpu())
|
||||
|
||||
# metric on all batches using custom accumulation
|
||||
measure = metric.compute()
|
||||
results[net.__class__.__name__] = measure.item()
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
WORKER = 0
|
||||
BATCHSIZE = 500
|
||||
MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Flatten(start_dim=0)])
|
||||
torch.manual_seed(42)
|
||||
data_path = Path('data')
|
||||
data_path.mkdir(exist_ok=True, parents=True)
|
||||
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
|
||||
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
|
||||
|
||||
model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval()
|
||||
weights = extract_weights_from_model(model)
|
||||
test_weights_as_model(model, weights, d_test)
|
||||
|
385
sparse_net.py
Normal file
@ -0,0 +1,385 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import pandas as pd
|
||||
from matplotlib import pyplot as plt
|
||||
import seaborn as sns
|
||||
from torch import nn
|
||||
|
||||
import functionalities_test
|
||||
from network import Net
|
||||
from functionalities_test import is_identity_function, test_for_fixpoints, epsilon_error_margin
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torch.nn import Flatten
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision.transforms import ToTensor, Compose, Resize
|
||||
|
||||
|
||||
def xavier_init(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
return nn.init.xavier_uniform_(m.weight.data)
|
||||
if isinstance(m, torch.Tensor):
|
||||
return nn.init.xavier_uniform_(m)
|
||||
|
||||
|
||||
class SparseLayer(nn.Module):
|
||||
def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):
|
||||
super(SparseLayer, self).__init__()
|
||||
|
||||
self.nr_nets = nr_nets
|
||||
self.interface_dim = interface
|
||||
self.depth_dim = depth
|
||||
self.hidden_dim = width
|
||||
self.out_dim = out
|
||||
dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)
|
||||
self.dummy_net_shapes = [list(x.shape) for x in dummy_net.parameters()]
|
||||
self.dummy_net_weight_pos_enc = dummy_net._weight_pos_enc
|
||||
|
||||
self.sparse_sub_layer = list()
|
||||
self.indices = list()
|
||||
self.diag_shapes = list()
|
||||
self.weights = nn.ParameterList()
|
||||
self._particles = None
|
||||
|
||||
for layer_id in range(self.depth_dim):
|
||||
indices, weights, diag_shape = self.coo_sparse_layer(layer_id)
|
||||
self.indices.append(indices)
|
||||
self.diag_shapes.append(diag_shape)
|
||||
self.weights.append(weights)
|
||||
self.apply(xavier_init)
|
||||
|
||||
def coo_sparse_layer(self, layer_id):
|
||||
with torch.no_grad():
|
||||
layer_shape = self.dummy_net_shapes[layer_id]
|
||||
sparse_diagonal = np.eye(self.nr_nets).repeat(layer_shape[0], axis=-2).repeat(layer_shape[1], axis=-1)
|
||||
indices = torch.Tensor(np.argwhere(sparse_diagonal == 1).T, )
|
||||
values = torch.nn.Parameter(torch.randn((np.prod((*layer_shape, self.nr_nets)).item())), requires_grad=True)
|
||||
|
||||
return indices, values, sparse_diagonal.shape
|
||||
|
||||
def get_self_train_inputs_and_targets(self):
|
||||
# view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN
|
||||
# i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2
|
||||
# and first hidden*out weights of layer3 = first net
|
||||
# [nr_layers*[nr_net*nr_weights_layer_i]]
|
||||
with torch.no_grad():
|
||||
weights = [layer.view(-1, int(len(layer)/self.nr_nets)).detach() for layer in self.weights]
|
||||
# [nr_net*[nr_weights]]
|
||||
weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1, 1) for i in range(self.nr_nets)]
|
||||
# (16, 25)
|
||||
|
||||
encoding_matrix, mask = self.dummy_net_weight_pos_enc
|
||||
weight_device = weights_per_net[0].device
|
||||
if weight_device != encoding_matrix.device or weight_device != mask.device:
|
||||
encoding_matrix, mask = encoding_matrix.to(weight_device), mask.to(weight_device)
|
||||
self.dummy_net_weight_pos_enc = encoding_matrix, mask
|
||||
|
||||
inputs = torch.hstack(
|
||||
[encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask)
|
||||
for i in range(self.nr_nets)]
|
||||
)
|
||||
targets = torch.hstack(weights_per_net)
|
||||
return inputs.T, targets.T
|
||||
|
||||
@property
|
||||
def particles(self):
|
||||
if self._particles is None:
|
||||
self._particles = [Net(self.interface_dim, self.hidden_dim, self.out_dim) for _ in range(self.nr_nets)]
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
# Particle Weight Update
|
||||
all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
|
||||
weights_per_net = [torch.cat([layer[i] for layer in all_weights]).view(-1, 1) for i in
|
||||
range(self.nr_nets)] # [nr_net*[nr_weights]]
|
||||
for particles, weights in zip(self._particles, weights_per_net):
|
||||
particles.apply_weights(weights)
|
||||
return self._particles
|
||||
|
||||
def reset_diverged_particles(self):
|
||||
for weights in self.weights:
|
||||
if torch.isinf(weights).any() or torch.isnan(weights).any():
|
||||
with torch.no_grad():
|
||||
where_nan = torch.nan_to_num(weights, -99, -99, -99)
|
||||
mask = torch.where(where_nan == -99, 0, 1)
|
||||
weights[:] = (where_nan * mask + torch.randn_like(weights) * (1 - mask))[:]
|
||||
|
||||
@property
|
||||
def particle_weights(self):
|
||||
all_weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
|
||||
weights_per_net = [torch.cat([layer[i] for layer in all_weights]).view(-1, 1) for i in
|
||||
range(self.nr_nets)] # [nr_net*[nr_weights]]
|
||||
return weights_per_net
|
||||
|
||||
def replace_weights_by_particles(self, particles):
|
||||
assert len(particles) == self.nr_nets
|
||||
with torch.no_grad():
|
||||
# Particle Weight Update
|
||||
all_weights = [list(particle.parameters()) for particle in particles]
|
||||
all_weights = [torch.cat(x).view(-1) for x in zip(*all_weights)]
|
||||
# [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]
|
||||
for weights, parameters in zip(all_weights, self.parameters()):
|
||||
parameters[:] = weights[:]
|
||||
return self
|
||||
|
||||
def __call__(self, x):
|
||||
for indices, diag_shapes, weights in zip(self.indices, self.diag_shapes, self.weights):
|
||||
s = torch.sparse_coo_tensor(indices, weights, diag_shapes, requires_grad=True, device=x.device)
|
||||
x = torch.sparse.mm(s, x)
|
||||
return x
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super(SparseLayer, self).to(*args, **kwargs)
|
||||
self.sparse_sub_layer = [sparse_sub_layer.to(*args, **kwargs) for sparse_sub_layer in self.sparse_sub_layer]
|
||||
return self
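

def _sparse_layer_sketch():
    """Illustrative sketch only (the number of nets is an assumed demo value): a SparseLayer
    stores many tiny self-replicating nets as block-diagonal sparse matrices; its self-train
    batch stacks the positionally encoded weights (inputs) next to the raw weights (targets)
    of every net in the layer."""
    demo_layer = SparseLayer(nr_nets=4)
    demo_inputs, demo_targets = demo_layer.get_self_train_inputs_and_targets()
    return demo_inputs.shape, demo_targets.shape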
|
||||
|
||||
|
||||
def test_sparse_layer():
|
||||
net = SparseLayer(1000)
|
||||
loss_fn = torch.nn.MSELoss(reduction='mean')
|
||||
optimizer = torch.optim.SGD(net.parameters(), lr=0.008, momentum=0.9)
|
||||
# optimizer = torch.optim.SGD([layer.coalesce().values() for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)
|
||||
df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count'])
|
||||
train_iterations = 20000
|
||||
|
||||
for train_iteration in trange(train_iterations):
|
||||
optimizer.zero_grad()
|
||||
X, Y = net.get_self_train_inputs_and_targets()
|
||||
output = net(X)
|
||||
|
||||
loss = loss_fn(output, Y) * 100
|
||||
|
||||
# loss = sum([loss_fn(out, target) for out, target in zip(output, Y)]) / len(output) * 10
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if train_iteration % 500 == 0:
|
||||
counter = defaultdict(lambda: 0)
|
||||
id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
|
||||
counter = dict(counter)
|
||||
tqdm.write(f"identity_fn after {train_iteration + 1} self-train epochs: {counter}")
|
||||
for key, value in counter.items():
|
||||
df.loc[df.shape[0]] = (train_iteration, key, value)
|
||||
|
||||
counter = defaultdict(lambda: 0)
|
||||
id_functions = functionalities_test.test_for_fixpoints(counter, list(net.particles))
|
||||
counter = dict(counter)
|
||||
tqdm.write(f"identity_fn after {train_iterations} self-train epochs: {counter}")
|
||||
for key, value in counter.items():
|
||||
df.loc[df.shape[0]] = (train_iterations, key, value)
|
||||
df.to_csv('counter.csv', mode='w')
|
||||
|
||||
c = pd.read_csv('counter.csv', index_col=0)
|
||||
sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type')
|
||||
plt.savefig('counter.png', dpi=300)
|
||||
|
||||
|
||||
def embed_batch(x, repeat_dim):
|
||||
# x of shape (batchsize, flat_img_dim)
|
||||
|
||||
# (batchsize, flat_img_dim, 1)
|
||||
x = x.unsqueeze(-1)
|
||||
# (batchsize, flat_img_dim, encoding_dim*repeat_dim)
|
||||
# torch.sparse_coo_tensor(indices, weights, diag_shapes, requires_grad=True, device=x.device)
|
||||
return torch.cat((torch.zeros(x.shape[0], x.shape[1], 4, device=x.device), x), dim=2).repeat(1, 1, repeat_dim)
|
||||
|
||||
def embed_vector(x, repeat_dim):
|
||||
# x of shape [flat_img_dim]
|
||||
x = x.unsqueeze(-1) # (flat_img_dim, 1)
|
||||
# (flat_img_dim, encoding_dim*repeat_dim)
|
||||
return torch.cat((torch.zeros(x.shape[0], 4), x), dim=1).repeat(1,repeat_dim)
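

def _embedding_shape_sketch():
    """Illustrative sketch only (sizes are assumed demo values): embed_vector pads a flat
    input of length d with four zero position columns and tiles it repeat_dim times, e.g. a
    length-6 vector with repeat_dim=3 becomes a (6, 15) matrix."""
    demo = embed_vector(torch.ones(6), repeat_dim=3)
    assert demo.shape == (6, 15)
    return demo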
|
||||
|
||||
|
||||
class SparseNetwork(nn.Module):
|
||||
|
||||
@property
|
||||
def nr_nets(self):
|
||||
return sum(x.nr_nets for x in self.sparselayers)
|
||||
|
||||
def __init__(self, input_dim, depth, width, out, residual_skip=True, activation=None,
|
||||
weight_interface=5, weight_hidden_size=2, weight_output_size=1
|
||||
):
|
||||
super(SparseNetwork, self).__init__()
|
||||
self.residual_skip = residual_skip
|
||||
self.input_dim = input_dim
|
||||
self.depth_dim = depth
|
||||
self.hidden_dim = width
|
||||
self.out_dim = out
|
||||
self.activation = activation
|
||||
self.first_layer = SparseLayer(self.input_dim * self.hidden_dim,
|
||||
interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
|
||||
self.last_layer = SparseLayer(self.hidden_dim * self.out_dim,
|
||||
interface=weight_interface, width=weight_hidden_size, out=weight_output_size)
|
||||
self.hidden_layers = nn.ModuleList([
|
||||
SparseLayer(self.hidden_dim * self.hidden_dim,
|
||||
interface=weight_interface, width=weight_hidden_size, out=weight_output_size
|
||||
) for _ in range(self.depth_dim - 2)])
|
||||
|
||||
def __call__(self, x):
|
||||
|
||||
tensor = self.sparse_layer_forward(x, self.first_layer)
|
||||
if self.activation:
|
||||
tensor = self.activation(tensor)
|
||||
for nl_idx, network_layer in enumerate(self.hidden_layers):
|
||||
# if idx % 2 == 1 and self.residual_skip:
|
||||
if self.residual_skip:
|
||||
residual = tensor
|
||||
tensor = self.sparse_layer_forward(tensor, network_layer)
|
||||
# if idx % 2 == 0 and self.residual_skip:
|
||||
if self.residual_skip:
|
||||
tensor = tensor + residual
|
||||
tensor = self.sparse_layer_forward(tensor, self.last_layer, view_dim=self.out_dim)
|
||||
return tensor
|
||||
|
||||
def sparse_layer_forward(self, x, sparse_layer, view_dim=None):
|
||||
view_dim = view_dim if view_dim else self.hidden_dim
|
||||
# batch pass (one by one, sparse bmm doesn't support grad)
|
||||
if len(x.shape) > 1:
|
||||
embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
|
||||
# [batchsize, hidden*inpt_dim, feature_dim]
|
||||
x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(view_dim, x.shape[1]).sum(dim=1) for inpt in
|
||||
embedded_inpt])
|
||||
# vector
|
||||
else:
|
||||
embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
|
||||
x = sparse_layer(embedded_inpt.T).sum(dim=1).view(view_dim, x.shape[1]).sum(dim=1)
|
||||
return x
|
||||
|
||||
@property
|
||||
def particles(self):
|
||||
#particles = []
|
||||
#particles.extend(self.first_layer.particles)
|
||||
#for layer in self.hidden_layers:
|
||||
# particles.extend(layer.particles)
|
||||
#particles.extend(self.last_layer.particles)
|
||||
return (x for y in (self.first_layer.particles,
|
||||
*(l.particles for l in self.hidden_layers),
|
||||
self.last_layer.particles) for x in y)
|
||||
|
||||
@property
|
||||
def particle_weights(self):
|
||||
return (x for y in self.sparselayers for x in y.particle_weights)
|
||||
|
||||
def reset_diverged_particles(self):
|
||||
for layer in self.sparselayers:
|
||||
layer.reset_diverged_particles()
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
super(SparseNetwork, self).to(*args, **kwargs)
|
||||
self.first_layer = self.first_layer.to(*args, **kwargs)
|
||||
self.last_layer = self.last_layer.to(*args, **kwargs)
|
||||
self.hidden_layers = nn.ModuleList([hidden_layer.to(*args, **kwargs) for hidden_layer in self.hidden_layers])
|
||||
return self
|
||||
|
||||
@property
|
||||
def sparselayers(self):
|
||||
return (x for x in (self.first_layer, *self.hidden_layers, self.last_layer))
|
||||
|
||||
    def combined_self_train(self, optimizer, reduction='mean'):
        losses = []
        loss_fn = nn.MSELoss(reduction=reduction)
        for layer in self.sparselayers:
            optimizer.zero_grad()
            x, target_data = layer.get_self_train_inputs_and_targets()
            output = layer(x)
            # loss = sum([loss_fn(out, target) for out, target in zip(output, target_data)]) / len(output)
            loss = loss_fn(output, target_data) * layer.nr_nets

            losses.append(loss.detach())
            loss.backward()
            optimizer.step()

        return sum(losses)

    def replace_weights_by_particles(self, particles):
        particles = list(particles)
        for layer in self.sparselayers:
            layer.replace_weights_by_particles(particles[:layer.nr_nets])
            del particles[:layer.nr_nets]
        return self

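
# Condensed self-train usage (a sketch; mirrors test_sparse_net_self_train below):
#   net = SparseNetwork(15 * 15, depth=5, width=6, out=10)
#   opt = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
#   for _ in range(epochs):
#       net.combined_self_train(opt)    # every particle regresses onto its own weights
#       net.reset_diverged_particles()  # re-initialise particles that blew up (done every 500 epochs below)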
def test_sparse_net():
    utility_transforms = Compose([Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])
    data_path = Path('data')
    WORKER = 8
    BATCHSIZE = 10
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    epsilon_error_margin = pow(10, -5)

    dataset = MNIST(str(data_path), transform=utility_transforms)
    d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)

    data_dim = np.prod(dataset[0][0].shape)
    metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)
    batchx, batchy = next(iter(d))
    out = metanet(batchx)

    # Smoke test over the batch dimension; the untrained metanet is not expected to match the labels.
    result = sum([torch.allclose(out[i], batchy[i], rtol=0, atol=epsilon_error_margin) for i in range(BATCHSIZE)])
    print(f"matches on one untrained forward pass: {result}/{BATCHSIZE}")

def test_sparse_net_self_train():
    sparse_metanet = SparseNetwork(15 * 15, 5, 6, 10).to('cuda')
    init_st_store_path = Path('counter.csv')
    optimizer = torch.optim.SGD(sparse_metanet.parameters(), lr=0.004, momentum=0.9)
    init_st_epochs = 10000
    init_st_df = pd.DataFrame(columns=['Epoch', 'Func Type', 'Count'])

    for st_epoch in trange(init_st_epochs):
        _ = sparse_metanet.combined_self_train(optimizer)

        if st_epoch % 500 == 0:
            counter = defaultdict(lambda: 0)
            id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
            counter = dict(counter)
            tqdm.write(f"identity_fn after {st_epoch} self-train epochs: {counter}")
            for key, value in counter.items():
                init_st_df.loc[init_st_df.shape[0]] = (st_epoch, key, value)
            sparse_metanet.reset_diverged_particles()

    counter = defaultdict(lambda: 0)
    id_functions = test_for_fixpoints(counter, list(sparse_metanet.particles))
    counter = dict(counter)
    tqdm.write(f"identity_fn after {init_st_epochs} self-train epochs: {counter}")
    for key, value in counter.items():
        init_st_df.loc[init_st_df.shape[0]] = (init_st_epochs, key, value)
    init_st_df.to_csv(init_st_store_path, mode='w', index=False)

    c = pd.read_csv(init_st_store_path)
    sns.lineplot(data=c, x='Epoch', y='Count', hue='Func Type')
    # Save the plot next to the CSV instead of overwriting the CSV with image data.
    plt.savefig(init_st_store_path.with_suffix('.png'), dpi=300)

def test_manual_for_loop():
    nr_nets = 500
    nets = [Net(5, 2, 1) for _ in range(nr_nets)]
    loss_fn = torch.nn.MSELoss(reduction="sum")
    rounds = 1000

    for net in tqdm(nets):
        optimizer = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
        for i in range(rounds):
            optimizer.zero_grad()
            input_data = net.input_weight_matrix()
            target_data = net.create_target_weights(input_data)
            output = net(input_data)
            loss = loss_fn(output, target_data)
            loss.backward()
            optimizer.step()

    print(f"identity_fns after {rounds} rounds: {sum([is_identity_function(net) for net in nets])}/{nr_nets}")

if __name__ == '__main__':
    # test_sparse_layer()
    test_sparse_net_self_train()
    # test_sparse_net()
    # for comparison:
    # test_manual_for_loop()
498
sparse_tensor_combined.ipynb
Normal file
@ -0,0 +1,498 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from network import Net\n",
|
||||
"import torch\n",
|
||||
"from typing import List\n",
|
||||
"from functionalities_test import is_identity_function\n",
|
||||
"from tqdm import tqdm,trange\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def construct_sparse_COO_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
|
||||
" assert layer_idx <= len(list(nets[0].parameters()))\n",
|
||||
" values = []\n",
|
||||
" indices = []\n",
|
||||
" for net_idx,net in enumerate(nets):\n",
|
||||
" layer = list(net.parameters())[layer_idx]\n",
|
||||
" \n",
|
||||
" for cell_idx,cell in enumerate(layer):\n",
|
||||
" # E.g., position of cell weights (with 2 cells per hidden layer) in first sparse layer of N nets: \n",
|
||||
" \n",
|
||||
" # [4x2 weights_net0] [4x2x(n-1) 0s]\n",
|
||||
" # [4x2 weights] [4x2 weights_net0] [4x2x(n-2) 0s]\n",
|
||||
" # ... etc\n",
|
||||
" # [4x2x(n-1) 0s] [4x2 weights_netN]\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # -> 4x2 weights on the diagonal = [shifted Nr_cellss*B down for AxB cells, and Nr_nets(*A weights)to the right] \n",
|
||||
" for i in range(len(cell)):\n",
|
||||
" indices.append([len(layer)*net_idx + cell_idx, net_idx*len(cell) + i ])\n",
|
||||
" #indices.append([2*net_idx + cell_idx, net_idx*len(cell) + i ])\n",
|
||||
"\n",
|
||||
" [values.append(weight) for weight in cell]\n",
|
||||
"\n",
|
||||
" # for i in range(4):\n",
|
||||
" # indices.append([idx+idx+1, i+(idx*4)])\n",
|
||||
" #for l in next(net.parameters()):\n",
|
||||
" #[values.append(w) for w in l]\n",
|
||||
" #print(indices, values)\n",
|
||||
"\n",
|
||||
" #s = torch.sparse_coo_tensor(list(zip(*indices)), values, (2*nr_nets, 4*nr_nets))\n",
|
||||
" # sparse tensor dimension = (nr_cells*nr_nets , nr_weights/cell * nr_nets), i.e.,\n",
|
||||
" # layer 1: (2x4) -> (2*N, 4*N)\n",
|
||||
" # layer 2: (2x2) -> (2*N, 2*N)\n",
|
||||
" # layer 3: (1x2) -> (2*N, 1*N)\n",
|
||||
" s = torch.sparse_coo_tensor(list(zip(*indices)), values, (len(layer)*nr_nets, len(cell)*nr_nets))\n",
|
||||
" #print(s.to_dense())\n",
|
||||
" #print(s.to_dense().shape)\n",
|
||||
" return s\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# for each net append to the combined sparse tensor\n",
|
||||
"# construct sparse tensor for each layer, with Nets of (4,2,1), each net appends\n",
|
||||
"# - [4x2] weights in the first (input) layer\n",
|
||||
"# - [2x2] weights in the second (hidden) layer\n",
|
||||
"# - [2x1] weights in the third (output) layer\n",
|
||||
"#modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
|
||||
"#modules\n",
|
||||
"#for layer_idx in range(len(list(nets[0].parameters()))):\n",
|
||||
"# sparse_tensor = construct_sparse_tensor_layer(nets, layer_idx)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nr_nets = 50\n",
|
||||
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
|
||||
"print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
|
||||
"\n",
|
||||
"modules = [ construct_sparse_COO_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
|
||||
"print( id(list(nets[0].parameters())[0][0,0]) == id(modules[0][0,0]))\n",
|
||||
"\n",
|
||||
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
||||
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
|
||||
"#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for train_iteration in range(1000):\n",
|
||||
" optimizer.zero_grad() \n",
|
||||
" X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
|
||||
" Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
|
||||
" #print(\"X \", X.shape, \"Y\", Y.shape)\n",
|
||||
"\n",
|
||||
" modules = [ construct_sparse_COO_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
|
||||
"\n",
|
||||
" X1 = torch.sparse.mm(modules[0], X)\n",
|
||||
" #print(\"X1\", X1.shape, X1)\n",
|
||||
"\n",
|
||||
" X2 = torch.sparse.mm(modules[1], X1)\n",
|
||||
" #print(\"X2\", X2.shape)\n",
|
||||
"\n",
|
||||
" X3 = torch.sparse.mm(modules[2], X2)\n",
|
||||
" #print(\"X3\", X3.shape)\n",
|
||||
"\n",
|
||||
" loss = loss_fn(X3, Y)\n",
|
||||
" #print(loss)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
"print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nr_nets = 500\n",
|
||||
"nets = [Net(5,2,1) for _ in range(nr_nets)]\n",
|
||||
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
||||
"rounds = 1000\n",
|
||||
"\n",
|
||||
"for net in tqdm(nets):\n",
|
||||
" optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)\n",
|
||||
" for i in range(rounds):\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" input_data = net.input_weight_matrix()\n",
|
||||
" target_data = net.create_target_weights(input_data)\n",
|
||||
" output = net(input_data)\n",
|
||||
" loss = loss_fn(output, target_data)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
"sum([is_identity_function(net) for net in nets])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def construct_sparse_CRS_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
|
||||
" assert layer_idx <= len(list(nets[0].parameters()))\n",
|
||||
" \n",
|
||||
" s = torch.cat( [\n",
|
||||
" torch.cat(\n",
|
||||
" (\n",
|
||||
" torch.zeros(( len(list(net.parameters())[layer_idx]) ,len(list(net.parameters())[layer_idx][0])*net_idx)), \n",
|
||||
" list(net.parameters())[layer_idx], \n",
|
||||
" torch.zeros((len(list(net.parameters())[layer_idx]), len(list(net.parameters())[layer_idx][0])*(len(nets)-(net_idx+1))))\n",
|
||||
" )\n",
|
||||
" , dim=1) for net_idx, net in enumerate(nets)\n",
|
||||
" ]).to_sparse_csr()\n",
|
||||
"\n",
|
||||
" print(s.shape)\n",
|
||||
" return s"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nr_nets = 5\n",
|
||||
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
|
||||
"print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
|
||||
"\n",
|
||||
"#modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
|
||||
"print( id(list(nets[0].parameters())[0][0,0]) == id(modules[0][0,0]))\n",
|
||||
"\n",
|
||||
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
||||
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
|
||||
"#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for train_iteration in range(1):\n",
|
||||
" optimizer.zero_grad() \n",
|
||||
" X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
|
||||
" Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
|
||||
" #print(\"X \", X.shape, \"Y\", Y.shape)\n",
|
||||
"\n",
|
||||
" num_layers = len(list(nets[0].parameters()))\n",
|
||||
" modules = [ construct_sparse_CRS_layer(nets, layer_idx) for layer_idx in range(num_layers)]\n",
|
||||
"\n",
|
||||
" X1 = modules[0].matmul(X)\n",
|
||||
" print(\"X1\", X1.shape, X1.is_sparse)\n",
|
||||
"\n",
|
||||
" X2 = modules[1].matmul(X1)\n",
|
||||
" print(\"X2\", X2.shape, X2.is_sparse)\n",
|
||||
"\n",
|
||||
" X3 = modules[2].matmul(X2)\n",
|
||||
" print(\"X3\", X3.shape, X3.is_sparse)\n",
|
||||
"\n",
|
||||
" loss = loss_fn(X3, Y)\n",
|
||||
" #print(loss)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
"print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nr_nets = 2\n",
|
||||
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
|
||||
"\n",
|
||||
"def cat_COO_layer(nets, layer_idx):\n",
|
||||
" i = [[0,i] for i in range(nr_nets*len(list(net.parameters())[layer_idx]))]\n",
|
||||
" v = torch.cat( [\n",
|
||||
" torch.cat(\n",
|
||||
" (\n",
|
||||
" torch.zeros(( len(list(net.parameters())[layer_idx]) ,len(list(net.parameters())[layer_idx][0])*net_idx)), \n",
|
||||
" list(net.parameters())[layer_idx], \n",
|
||||
" torch.zeros((len(list(net.parameters())[layer_idx]), len(list(net.parameters())[layer_idx][0])*(len(nets)-(net_idx+1))))\n",
|
||||
" )\n",
|
||||
" , dim=1) for net_idx, net in enumerate(nets)\n",
|
||||
" ])\n",
|
||||
" #print(i,v)\n",
|
||||
" s = torch.sparse_coo_tensor(list(zip(*i)), v)\n",
|
||||
" print(s[0].to_dense().shape, s[0].is_sparse)\n",
|
||||
" return s[0]\n",
|
||||
"\n",
|
||||
"cat_COO_layer(nets, 0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nr_nets = 5\n",
|
||||
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
|
||||
"print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
|
||||
"\n",
|
||||
"nr_layers = len(list(nets[0].parameters()))\n",
|
||||
"modules = [ cat_COO_layer(nets, layer_idx) for layer_idx in range(nr_layers) ]\n",
|
||||
"\n",
|
||||
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
||||
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
|
||||
"#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for train_iteration in range(1):\n",
|
||||
" optimizer.zero_grad() \n",
|
||||
" X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
|
||||
" Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
|
||||
" print(\"X \", X.shape, \"Y\", Y.shape)\n",
|
||||
"\n",
|
||||
" X1 = torch.sparse.mm(modules[0], X)\n",
|
||||
" print(\"X1\", X1.shape)\n",
|
||||
"\n",
|
||||
" X2 = torch.sparse.mm(modules[1], X1)\n",
|
||||
" print(\"X2\", X2.shape)\n",
|
||||
"\n",
|
||||
" X3 = torch.sparse.mm(modules[2], X2)\n",
|
||||
" print(\"X3\", X3.shape)\n",
|
||||
"\n",
|
||||
" loss = loss_fn(X3, Y)\n",
|
||||
" #print(loss)\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
"print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SparseLayer():\n",
|
||||
" def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):\n",
|
||||
" self.nr_nets = nr_nets\n",
|
||||
" self.interface_dim = interface\n",
|
||||
" self.depth_dim = depth\n",
|
||||
" self.hidden_dim = width\n",
|
||||
" self.out_dim = out\n",
|
||||
" self.dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)\n",
|
||||
" \n",
|
||||
" self.sparse_sub_layer = []\n",
|
||||
" self.weights = []\n",
|
||||
" for layer_id in range(depth):\n",
|
||||
" layer, weights = self.coo_sparse_layer(layer_id)\n",
|
||||
" self.sparse_sub_layer.append(layer)\n",
|
||||
" self.weights.append(weights)\n",
|
||||
" \n",
|
||||
" def coo_sparse_layer(self, layer_id):\n",
|
||||
" layer_shape = list(self.dummy_net.parameters())[layer_id].shape\n",
|
||||
" #print(layer_shape) #(out_cells, in_cells) -> (2,5), (2,2), (1,2)\n",
|
||||
"\n",
|
||||
" sparse_diagonal = np.eye(self.nr_nets).repeat(layer_shape[0], axis=-2).repeat(layer_shape[1], axis=-1)\n",
|
||||
" indices = np.argwhere(sparse_diagonal == 1).T\n",
|
||||
" values = torch.nn.Parameter(torch.randn((self.nr_nets * (layer_shape[0]*layer_shape[1]) )))\n",
|
||||
" #values = torch.randn((self.nr_nets * layer_shape[0]*layer_shape[1] ))\n",
|
||||
" s = torch.sparse_coo_tensor(indices, values, sparse_diagonal.shape, requires_grad=True)\n",
|
||||
" print(f\"L{layer_id}:\", s.shape)\n",
|
||||
" return s, values\n",
|
||||
"\n",
|
||||
" def get_self_train_inputs_and_targets(self):\n",
|
||||
" encoding_matrix, mask = self.dummy_net._weight_pos_enc\n",
|
||||
"\n",
|
||||
" # view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN\n",
|
||||
" # i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2 and first hidden*out weights of layer3 = first net\n",
|
||||
" weights = [layer.view(-1, int(len(layer)/self.nr_nets)) for layer in self.weights] #[nr_layers*[nr_net*nr_weights_layer_i]]\n",
|
||||
" weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1,1) for i in range(self.nr_nets)] #[nr_net*[nr_weights]]\n",
|
||||
" inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(self.nr_nets)]) #(16, 25)\n",
|
||||
" targets = torch.hstack(weights_per_net)\n",
|
||||
" return inputs.T, targets.T\n",
|
||||
"\n",
|
||||
" def __call__(self, x):\n",
|
||||
" X1 = torch.sparse.mm(self.sparse_sub_layer[0], x)\n",
|
||||
" #print(\"X1\", X1.shape)\n",
|
||||
"\n",
|
||||
" X2 = torch.sparse.mm(self.sparse_sub_layer[1], X1)\n",
|
||||
" #print(\"X2\", X2.shape)\n",
|
||||
"\n",
|
||||
" X3 = torch.sparse.mm(self.sparse_sub_layer[2], X2)\n",
|
||||
" #print(\"X3\", X3.shape)\n",
|
||||
" \n",
|
||||
" return X3\n",
|
||||
"\n",
|
||||
"net = SparseLayer(5)\n",
|
||||
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
||||
"optimizer = torch.optim.SGD([weight for weight in net.weights], lr=0.004, momentum=0.9)\n",
|
||||
"#optimizer = torch.optim.SGD([layer for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)\n",
|
||||
"\n",
|
||||
"for train_iteration in trange(10):\n",
|
||||
" optimizer.zero_grad() \n",
|
||||
" X,Y = net.get_self_train_inputs_and_targets()\n",
|
||||
" out = net(X)\n",
|
||||
" \n",
|
||||
" loss = loss_fn(out, Y)\n",
|
||||
"\n",
|
||||
" # print(\"X:\", X.shape, \"Y:\", Y.shape)\n",
|
||||
" # print(\"OUT\", out.shape)\n",
|
||||
" # print(\"LOSS\", loss.item())\n",
|
||||
" \n",
|
||||
" loss.backward(retain_graph=True)\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"epsilon=pow(10, -5)\n",
|
||||
"# is the (the whole layer) self-replicating? -> wrong\n",
|
||||
"#print(torch.allclose(out, Y,rtol=0, atol=epsilon))\n",
|
||||
"\n",
|
||||
"# is each of the networks self-replicating?\n",
|
||||
"print(f\"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# for layer in net.weights:\n",
|
||||
"# n=int(len(layer)/net.nr_nets)\n",
|
||||
"# print( [layer[i:i+n] for i in range(0, len(layer), n)])\n",
|
||||
"\n",
|
||||
"encoding_matrix, mask = Net(5,2,1)._weight_pos_enc\n",
|
||||
"print(encoding_matrix, mask)\n",
|
||||
"# view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN\n",
|
||||
"# i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2 and first hidden*out weights of layer3 = first net\n",
|
||||
"weights = [layer.view(-1, int(len(layer)/net.nr_nets)) for layer in net.weights]\n",
|
||||
"weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1,1) for i in range(net.nr_nets)]\n",
|
||||
"\n",
|
||||
"inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(net.nr_nets)]) #16, 25\n",
|
||||
"\n",
|
||||
"targets = torch.hstack(weights_per_net)\n",
|
||||
"targets.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"from pathlib import Path\n",
|
||||
"import torch\n",
|
||||
"from torch.nn import Flatten\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"from torchvision.datasets import MNIST\n",
|
||||
"from torchvision.transforms import ToTensor, Compose, Resize\n",
|
||||
"from tqdm import tqdm, trange\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"utility_transforms = Compose([ Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])\n",
|
||||
"data_path = Path('data')\n",
|
||||
"WORKER = 8\n",
|
||||
"BATCHSIZE = 10\n",
|
||||
"EPOCH = 1\n",
|
||||
"DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
"\n",
|
||||
"dataset = MNIST(str(data_path), transform=utility_transforms)\n",
|
||||
"d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def embed_batch(x, repeat_dim):\n",
|
||||
" # x of shape (batchsize, flat_img_dim)\n",
|
||||
" x = x.unsqueeze(-1) #(batchsize, flat_img_dim, 1)\n",
|
||||
" return torch.cat( (torch.zeros( x.shape[0], x.shape[1], 4), x), dim=2).repeat(1,1,repeat_dim) #(batchsize, flat_img_dim, encoding_dim*repeat_dim)\n",
|
||||
"\n",
|
||||
"def embed_vector(x, repeat_dim):\n",
|
||||
" # x of shape [flat_img_dim]\n",
|
||||
" x = x.unsqueeze(-1) #(flat_img_dim, 1)\n",
|
||||
" return torch.cat( (torch.zeros( x.shape[0], 4), x), dim=1).repeat(1,repeat_dim) #(flat_img_dim, encoding_dim*repeat_dim)\n",
|
||||
"\n",
|
||||
"class SparseNetwork():\n",
|
||||
" def __init__(self, input_dim, depth, width, out):\n",
|
||||
" self.input_dim = input_dim\n",
|
||||
" self.depth_dim = depth\n",
|
||||
" self.hidden_dim = width\n",
|
||||
" self.out_dim = out\n",
|
||||
" self.sparse_layers = []\n",
|
||||
" self.sparse_layers.append( SparseLayer( self.input_dim * self.hidden_dim ))\n",
|
||||
" self.sparse_layers.extend([ SparseLayer( self.hidden_dim * self.hidden_dim ) for layer_idx in range(self.depth_dim - 2)])\n",
|
||||
" self.sparse_layers.append( SparseLayer( self.hidden_dim * self.out_dim ))\n",
|
||||
"\n",
|
||||
" def __call__(self, x):\n",
|
||||
" \n",
|
||||
" for sparse_layer in self.sparse_layers[:-1]:\n",
|
||||
" # batch pass (one by one, sparse bmm doesn't support grad)\n",
|
||||
" if len(x.shape) > 1:\n",
|
||||
" embedded_inpt = embed_batch(x, sparse_layer.nr_nets)\n",
|
||||
" x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt]) #[batchsize, hidden*inpt_dim, feature_dim]\n",
|
||||
" # vector\n",
|
||||
" else:\n",
|
||||
" embedded_inpt = embed_vector(x, sparse_layer.nr_nets)\n",
|
||||
" x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1)\n",
|
||||
" print(\"out\", x.shape)\n",
|
||||
" \n",
|
||||
" # output layer\n",
|
||||
" sparse_layer = self.sparse_layers[-1]\n",
|
||||
" if len(x.shape) > 1:\n",
|
||||
" embedded_inpt = embed_batch(x, sparse_layer.nr_nets)\n",
|
||||
" x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt]) #[batchsize, hidden*inpt_dim, feature_dim]\n",
|
||||
" else:\n",
|
||||
" embedded_inpt = embed_vector(x, sparse_layer.nr_nets)\n",
|
||||
" x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1)\n",
|
||||
" print(\"out\", x.shape)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"data_dim = np.prod(dataset[0][0].shape)\n",
|
||||
"metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)\n",
|
||||
"batchx, batchy = next(iter(d))\n",
|
||||
"batchx.shape, batchy.shape\n",
|
||||
"metanet(batchx)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "8bcba732c17ca4dacffea8ad1176c852d4229b36b9060a5f633fff752e5396ea"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.8.12 64-bit ('masterthesis': conda)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.12"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
281
visualization.py
Normal file
@ -0,0 +1,281 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import numpy as np
|
||||
from sklearn.decomposition import PCA
|
||||
import random
|
||||
import string
|
||||
|
||||
from matplotlib import rcParams
|
||||
rcParams['axes.labelpad'] = 20
|
||||
|
||||
|
||||
def plot_output(output):
|
||||
""" Plotting the values of the final output """
|
||||
plt.figure()
|
||||
plt.imshow(output)
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_loss(loss_array, directory: Union[str, Path], batch_size=1):
|
||||
""" Plotting the evolution of the loss function."""
|
||||
|
||||
fig = plt.figure()
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(12)
|
||||
|
||||
for i in range(len(loss_array)):
|
||||
        plt.plot(loss_array[i], label=f"Last loss value: {loss_array[i][-1]}")
|
||||
|
||||
plt.legend()
|
||||
plt.xlabel("Epochs")
|
||||
plt.ylabel("Loss")
|
||||
|
||||
directory = Path(directory)
|
||||
filename = "nets_loss_function.png"
|
||||
file_path = directory / filename
|
||||
plt.savefig(str(file_path))
|
||||
|
||||
plt.clf()
|
||||
|
||||
|
||||
def bar_chart_fixpoints(fixpoint_counter: Dict, population_size: int, directory: Union[str, Path], learning_rate: float,
|
||||
exp_details: str, source_check=None):
|
||||
""" Plotting the number of fixpoints in a barchart. """
|
||||
|
||||
fig = plt.figure()
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(12)
|
||||
|
||||
legend_population_size = mpatches.Patch(color="white", label=f"No. of nets: {str(population_size)}")
|
||||
learning_rate = mpatches.Patch(color="white", label=f"Learning rate: {str(learning_rate)}")
|
||||
epochs = mpatches.Patch(color="white", label=f"{str(exp_details)}")
|
||||
|
||||
if source_check == "summary":
|
||||
plt.legend(handles=[legend_population_size, learning_rate, epochs])
|
||||
plt.ylabel("No. of nets/run")
|
||||
plt.title("Summary: avg. amount of fixpoints/run")
|
||||
else:
|
||||
plt.legend(handles=[legend_population_size, learning_rate, epochs])
|
||||
plt.ylabel("Number of networks")
|
||||
plt.title("Fixpoint count")
|
||||
|
||||
plt.bar(range(len(fixpoint_counter)), list(fixpoint_counter.values()), align='center')
|
||||
plt.xticks(range(len(fixpoint_counter)), list(fixpoint_counter.keys()))
|
||||
|
||||
directory = Path(directory)
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
filename = f"{str(population_size)}_nets_fixpoints_barchart.png"
|
||||
filepath = directory / filename
|
||||
plt.savefig(str(filepath))
|
||||
|
||||
plt.clf()
|
||||
|
||||
|
||||
def plot_3d(matrices_weights_history, directory: Union[str, Path], population_size, z_axis_legend,
|
||||
exp_name="experiment", is_trained="", batch_size=1, plot_pca_together=False, nets_array=None):
|
||||
""" Plotting the the weights of the nets in a 3d form using principal component analysis (PCA) """
|
||||
|
||||
fig = plt.figure()
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(12)
|
||||
|
||||
pca = PCA(n_components=2, whiten=True)
|
||||
ax = plt.axes(projection='3d')
|
||||
|
||||
if plot_pca_together:
|
||||
weight_histories = []
|
||||
start_times = []
|
||||
|
||||
for wh, st in matrices_weights_history:
|
||||
start_times.append(st)
|
||||
wm = np.array(wh)
|
||||
n, x, y = wm.shape
|
||||
wm = wm.reshape(n, x * y)
|
||||
weight_histories.append(wm)
|
||||
|
||||
weight_data = np.array(weight_histories)
|
||||
n, x, y = weight_data.shape
|
||||
weight_data = weight_data.reshape(n*x, y)
|
||||
|
||||
pca.fit(weight_data)
|
||||
weight_data_pca = pca.transform(weight_data)
|
||||
|
||||
for transformed_trajectory, start_time in zip(np.split(weight_data_pca, n), start_times):
|
||||
start_log_time = int(start_time / batch_size)
|
||||
xdata = transformed_trajectory[start_log_time:, 0]
|
||||
ydata = transformed_trajectory[start_log_time:, 1]
|
||||
zdata = np.arange(start_time, len(ydata)*batch_size+start_time, batch_size).tolist()
|
||||
            ax.plot3D(xdata, ydata, zdata, label="net")
|
||||
ax.scatter(xdata, ydata, zdata, s=7)
|
||||
|
||||
else:
|
||||
loop_matrices_weights_history = tqdm(range(len(matrices_weights_history)))
|
||||
for i in loop_matrices_weights_history:
|
||||
loop_matrices_weights_history.set_description("Plotting weights 3D PCA %s" % i)
|
||||
|
||||
weight_matrix, start_time = matrices_weights_history[i]
|
||||
weight_matrix = np.array(weight_matrix)
|
||||
n, x, y = weight_matrix.shape
|
||||
weight_matrix = weight_matrix.reshape(n, x * y)
|
||||
|
||||
pca.fit(weight_matrix)
|
||||
weight_matrix_pca = pca.transform(weight_matrix)
|
||||
|
||||
xdata, ydata = [], []
|
||||
|
||||
start_log_time = int(start_time / 10)
|
||||
|
||||
for j in range(start_log_time, len(weight_matrix_pca)):
|
||||
xdata.append(weight_matrix_pca[j][0])
|
||||
ydata.append(weight_matrix_pca[j][1])
|
||||
zdata = np.arange(start_time, len(ydata)*batch_size+start_time, batch_size)
|
||||
|
||||
ax.plot3D(xdata, ydata, zdata, label=f"net {i}", c="b")
|
||||
if "parent" in nets_array[i].name:
|
||||
ax.scatter(np.asarray(xdata), np.asarray(ydata), zdata, s=3, c="b")
|
||||
else:
|
||||
ax.scatter(np.asarray(xdata), np.asarray(ydata), zdata, s=3)
|
||||
|
||||
#steps = mpatches.Patch(color="white", label=f"{z_axis_legend}: {len(matrices_weights_history)} steps")
|
||||
population_size = mpatches.Patch(color="white", label=f"Population: {population_size} networks")
|
||||
if False:
|
||||
if z_axis_legend == "Self-application":
|
||||
if is_trained == '_trained':
|
||||
trained = mpatches.Patch(color="white", label=f"Trained: true")
|
||||
else:
|
||||
trained = mpatches.Patch(color="white", label=f"Trained: false")
|
||||
ax.legend(handles=[population_size, trained])
|
||||
else:
|
||||
ax.legend(handles=[population_size])
|
||||
|
||||
ax.set_title(f"PCA Transformed Weight Trajectories")
|
||||
# ax.set_xlabel("PCA Transformed X-Axis")
|
||||
# ax.set_ylabel("PCA Transformed Y-Axis")
|
||||
ax.set_zlabel(f"Self Training Steps")
|
||||
|
||||
# FIXME: Replace this kind of operation with pathlib.Path() object interactions
|
||||
directory = Path(directory)
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
filename = f"{exp_name}{is_trained}.png"
|
||||
filepath = directory / filename
|
||||
    if filepath.exists():
        letters = string.ascii_lowercase
        random_letters = ''.join(random.choice(letters) for _ in range(5))
        # Write the deduplicated file into the target directory instead of the working directory.
        plt.savefig(str(directory / f"{filepath.stem}_{random_letters}.png"))
    else:
        plt.savefig(str(filepath))
|
||||
|
||||
plt.show()
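

# Shape sketch for the PCA projection above (illustrative numbers):
#   wm = np.array(weight_history)      # e.g. (1000, 16, 1): 1000 logged snapshots of 16 weights
#   wm = wm.reshape(wm.shape[0], -1)   # -> (1000, 16)
#   xy = PCA(n_components=2, whiten=True).fit_transform(wm)  # -> (1000, 2); z-axis = training step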
|
||||
|
||||
|
||||
def plot_3d_self_train(nets_array: List, exp_name: str, directory: Union[str, Path], batch_size: int, plot_pca_together: bool):
|
||||
""" Plotting the evolution of the weights in a 3D space when doing self training. """
|
||||
|
||||
matrices_weights_history = []
|
||||
|
||||
loop_nets_array = tqdm(range(len(nets_array)))
|
||||
for i in loop_nets_array:
|
||||
loop_nets_array.set_description("Creating ST weights history %s" % i)
|
||||
|
||||
matrices_weights_history.append((nets_array[i].s_train_weights_history, nets_array[i].start_time))
|
||||
|
||||
z_axis_legend = "epochs"
|
||||
|
||||
return plot_3d(matrices_weights_history, directory, len(nets_array), z_axis_legend, exp_name, "", batch_size,
|
||||
plot_pca_together=plot_pca_together, nets_array=nets_array)
|
||||
|
||||
|
||||
def plot_3d_self_application(nets_array: List, exp_name: str, directory_name: Union[str, Path], batch_size: int) -> None:
|
||||
""" Plotting the evolution of the weights in a 3D space when doing self application. """
|
||||
|
||||
matrices_weights_history = []
|
||||
|
||||
loop_nets_array = tqdm(range(len(nets_array)))
|
||||
for i in loop_nets_array:
|
||||
loop_nets_array.set_description("Creating SA weights history %s" % i)
|
||||
|
||||
matrices_weights_history.append( (nets_array[i].s_application_weights_history, nets_array[i].start_time) )
|
||||
|
||||
if nets_array[i].trained:
|
||||
is_trained = "_trained"
|
||||
else:
|
||||
is_trained = "_not_trained"
|
||||
|
||||
    # FIXME: Are both of the following lines at the correct indentation? -> The value of "is_trained" changes multiple times!
|
||||
z_axis_legend = "epochs"
|
||||
plot_3d(matrices_weights_history, directory_name, len(nets_array), z_axis_legend, exp_name, is_trained, batch_size)
|
||||
|
||||
|
||||
def plot_3d_soup(nets_list, exp_name, directory: Union[str, Path]):
|
||||
""" Plotting the evolution of the weights in a 3D space for the soup environment. """
|
||||
|
||||
# This batch size is not relevant for soups. To not affect the number of epochs shown in the 3D plot,
|
||||
# will send forward the number "1" for batch size with the variable <irrelevant_batch_size>.
|
||||
irrelevant_batch_size = 1
|
||||
|
||||
# plot_3d_self_train(nets_list, exp_name, directory, irrelevant_batch_size, False)
|
||||
plot_3d_self_train(nets_list, exp_name, directory, 10, True)
|
||||
|
||||
|
||||
def line_chart_fixpoints(fixpoint_counters_history: list, epochs: int, ST_steps_between_SA: int,
|
||||
SA_steps, directory: Union[str, Path], population_size: int):
|
||||
""" Plotting the percentage of fixpoints after each iteration of SA & ST steps. """
|
||||
|
||||
fig = plt.figure()
|
||||
fig.set_figheight(10)
|
||||
fig.set_figwidth(12)
|
||||
|
||||
ST_steps_per_SA = np.arange(0, ST_steps_between_SA * epochs, ST_steps_between_SA).tolist()
|
||||
|
||||
legend_population_size = mpatches.Patch(color="white", label=f"No. of nets: {str(population_size)}")
|
||||
legend_SA_steps = mpatches.Patch(color="white", label=f"SA_steps: {str(SA_steps)}")
|
||||
legend_SA_and_ST_runs = mpatches.Patch(color="white", label=f"SA_and_ST_runs: {str(epochs)}")
|
||||
legend_ST_steps_between_SA = mpatches.Patch(color="white", label=f"ST_steps_between_SA: {str(ST_steps_between_SA)}")
|
||||
|
||||
plt.legend(handles=[legend_population_size, legend_SA_and_ST_runs, legend_SA_steps, legend_ST_steps_between_SA])
|
||||
plt.xlabel("Epochs")
|
||||
plt.ylabel("Percentage")
|
||||
plt.title("Percentage of fixpoints")
|
||||
|
||||
plt.plot(ST_steps_per_SA, fixpoint_counters_history, color="green", marker="o")
|
||||
|
||||
directory = Path(directory)
|
||||
filename = f"{str(population_size)}_nets_fixpoints_linechart.png"
|
||||
filepath = directory / filename
|
||||
plt.savefig(str(filepath))
|
||||
|
||||
plt.clf()
|
||||
|
||||
|
||||
def box_plot(data, directory: Union[str, Path], population_size):
|
||||
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 7))
|
||||
|
||||
# ax = fig.add_axes([0, 0, 1, 1])
|
||||
plt.title("Fixpoint variation")
|
||||
plt.xlabel("Amount of noise")
|
||||
plt.ylabel("Steps")
|
||||
|
||||
# data = numpy.array(data)
|
||||
# ax.boxplot(data)
|
||||
axs[1].boxplot(data)
|
||||
axs[1].set_title('Box plot')
|
||||
|
||||
directory = Path(directory)
|
||||
filename = f"{str(population_size)}_nets_fixpoints_barchart.png"
|
||||
filepath = directory / filename
|
||||
|
||||
plt.savefig(str(filepath))
|
||||
plt.clf()
|
||||
|
||||
|
||||
def write_file(text, directory: Union[str, Path]):
    directory = Path(directory)
    filepath = directory / 'experiment.txt'
    with filepath.open('w+') as f:
        f.write(text)
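

# Minimal usage sketch (illustrative paths and names; assumes nets expose
# .s_train_weights_history and .start_time as used above):
#   from visualization import plot_loss, plot_3d_self_train
#   plot_loss(loss_histories, directory='output/exp_a')
#   plot_3d_self_train(nets, exp_name='exp_a', directory='output/exp_a',
#                      batch_size=10, plot_pca_together=True)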
|