MetaNetworks Debugged
@@ -1,5 +1,4 @@
-""" ----------------------------------------- Methods for summarizing the experiments ------------------------------------------ """
+""" -------------------------------- Methods for summarizing the experiments --------------------------------- """
-import os
from pathlib import Path

from visualization import line_chart_fixpoints, bar_chart_fixpoints
@@ -52,7 +51,9 @@ def summary_fixpoint_percentage(runs, epochs, fixpoints_percentages, ST_steps, S
    line_chart_fixpoints(fixpoints_percentages, epochs, ST_steps, SA_steps, directory_name, population_size)


-""" --------------------------------------------------- Miscellaneous ---------------------------------------------------------- """
+""" -------------------------------------------- Miscellaneous --------------------------------------------------- """


def check_folder(experiment_folder: str):
    exp_path = Path('experiments') / experiment_folder
    exp_path.mkdir(parents=True, exist_ok=True)
@@ -1,5 +1,5 @@
import pickle
-import time
+from collections import defaultdict
from pathlib import Path
import sys
import platform
@@ -7,7 +7,14 @@ import platform
import pandas as pd
import torchmetrics

-if platform.node() != 'CarbonX':
+from functionalities_test import test_for_fixpoints

+if platform.node() == 'CarbonX':
+    debug = True
+    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
+    print("@ Warning, Debugging Config@!!!!!! @")
+    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
+else:
    debug = False
    try:
        # noinspection PyUnboundLocalVariable
@@ -20,8 +27,7 @@ if platform.node() != 'CarbonX':
    except NameError:
        DIR = None
        pass
-else:
-    debug = True

import numpy as np
import torch
@@ -31,7 +37,7 @@ from torch import nn
from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
-from torchvision.transforms import ToTensor, Compose
+from torchvision.transforms import ToTensor, Compose, Resize
from tqdm import tqdm

from network import MetaNet
@@ -69,7 +75,7 @@ class AddTaskDataset(Dataset):

def set_checkpoint(model, out_path, epoch_n, final_model=False):
    epoch_n = str(epoch_n)
-    if final_model:
+    if not final_model:
        ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
    else:
        ckpt_path = Path(out_path) / f'trained_model_ckpt.tp'
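[Note] The inverted condition above ("if final_model:" becoming "if not final_model:") is the central bug of this commit: intermediate checkpoints were written to the final-model path and vice versa. A minimal sketch of the corrected helper, assuming the torch.save call and directory creation the surrounding code implies:

    import pickle
    import torch
    from pathlib import Path

    def set_checkpoint(model, out_path, epoch_n, final_model=False):
        epoch_n = str(epoch_n)
        if not final_model:
            # regular epoch checkpoints live in a 'ckpt' subfolder
            ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
        else:
            # the finished model sits directly in the run folder
            ckpt_path = Path(out_path) / 'trained_model_ckpt.tp'
        ckpt_path.parent.mkdir(exist_ok=True, parents=True)  # assumption: folder may not exist yet
        torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
        return ckpt_path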
@@ -84,32 +90,32 @@ def validate(checkpoint_path, ratio=0.1):
    import torchmetrics

    # initialize metric
-    metric = torchmetrics.Accuracy()
+    validmetric = torchmetrics.Accuracy()

    try:
-        dataset = MNIST(str(data_path), transform=utility_transforms, train=False)
+        datas = MNIST(str(data_path), transform=utility_transforms, train=False)
    except RuntimeError:
-        dataset = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
-    d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
+        datas = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
+    valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)

    model = torch.load(checkpoint_path, map_location=DEVICE).eval()
-    n_samples = int(len(d) * ratio)
+    n_samples = int(len(valid_d) * ratio)

    with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
-        for idx, (batch_x, batch_y) in enumerate(d):
-            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-            y = model(batch_x)
+        for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
+            valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
+            y_valid = model(valid_batch_x)

            # metric on current batch
-            acc = metric(y.cpu(), batch_y.cpu())
+            acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
            pbar.set_postfix_str(f'Acc: {acc}')
            pbar.update()
            if idx == n_samples:
                break

    # metric on all batches using custom accumulation
-    acc = metric.compute()
-    print(f"Accuracy on all data: {acc}")
+    acc = validmetric.compute()
+    tqdm.write(f"Avg. accuracy on all data: {acc}")
    return acc
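[Note] The renamed validmetric follows torchmetrics' two-phase usage: calling the metric on a batch updates its internal state and returns the batch score, while compute() aggregates over everything seen. A self-contained sketch with random stand-in batches:

    import torch
    import torchmetrics

    metric = torchmetrics.Accuracy()
    for _ in range(4):                               # stand-in for real validation batches
        preds = torch.randn(8, 10).softmax(dim=-1)   # fake class probabilities
        target = torch.randint(10, (8,))             # fake labels
        batch_acc = metric(preds, target)            # updates state, returns batch accuracy
    overall_acc = metric.compute()                   # accuracy over all accumulated batches
    metric.reset()                                   # clear state before reuse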
@@ -125,41 +131,39 @@ def plot_training_result(path_to_dataframe):
    df = pd.read_csv(path_to_dataframe, index_col=0)

    fig, ax1 = plt.subplots()  # initializes figure and plots
-    ax2 = ax1.twinx()  # applies twinx to ax2, which is the second y axis.
+    ax2 = ax1.twinx()  # applies twinx to ax2, which is the second y-axis.

    # plots the first set of data, and sets it to ax1.
    data = df[df['Metric'] == 'BatchLoss']
    # plots the second set, and sets to ax2.
-    sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax2)
-    data = df[df['Metric'] == 'Test Accuracy']
-    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='red')
-    data = df[df['Metric'] == 'Train Accuracy']
-    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='green')
+    sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax1, color='blue')
+    data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
+    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', legend=True)

-    ax2.set(yscale='log')
+    ax1.set(yscale='log')
    ax1.set_title('Training Lineplot')
    plt.tight_layout()
    if debug:
        plt.show()
    else:
-        plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'))
+        plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'), dpi=300)


if __name__ == '__main__':

-    self_train = True
-    soup_interaction = True
-    training = True
-    plotting = True
+    self_train = False
+    training = False
+    plotting = False
+    particle_analysis = True

    data_path = Path('data')
    data_path.mkdir(exist_ok=True, parents=True)

-    run_path = Path('output') / 'intergrated_self_train'
+    run_path = Path('output') / 'mnist_test_half_size'
    model_path = run_path / '0000_trained_model.zip'

    if training:
-        utility_transforms = Compose([ToTensor(), ToFloat(), Flatten(start_dim=0)])
+        utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])

        try:
            dataset = MNIST(str(data_path), transform=utility_transforms)
@@ -179,6 +183,8 @@ if __name__ == '__main__':
            is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
            if is_validation_epoch:
                metric = torchmetrics.Accuracy()
+            else:
+                metric = None
            for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
                if self_train and is_self_train_epoch:
                    # Zero your gradients for every batch!
@@ -221,10 +227,19 @@ if __name__ == '__main__':
        accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
        validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
                              Metric='Test Accuracy', Score=accuracy.item())
-        train_store.loc[train_store.shape[0]] = validation_log

-        torch.save(metanet, model_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
+        train_store.loc[train_store.shape[0]] = validation_log
        train_store.to_csv(run_path / 'train_store.csv')

    if plotting:
        plot_training_result(run_path / 'train_store.csv')

+    if particle_analysis:
+        model_path = next(run_path.glob('*.tp'))
+        latest_model = torch.load(model_path, map_location=DEVICE).eval()
+        analysis_dict = defaultdict(dict)
+        counter_dict = defaultdict(lambda: 0)
+        for particle in latest_model.particles:
+            analysis_dict[particle.name]['is_diverged'] = particle.are_weights_diverged()
+        test_for_fixpoints(counter_dict, latest_model.particles)
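[Note] The new particle_analysis block iterates over the particles (the per-weight sub-networks) of the loaded MetaNet and tallies their fixpoint status. The counting idiom in isolation, with a hypothetical stand-in particle class since the real one lives in network.py:

    from collections import defaultdict

    class Particle:                                  # hypothetical stand-in
        def __init__(self, name, diverged):
            self.name, self._diverged = name, diverged
        def are_weights_diverged(self):
            return self._diverged

    particles = [Particle('L0_C0_W0', False), Particle('L0_C0_W1', True)]
    analysis_dict = defaultdict(dict)
    counter_dict = defaultdict(lambda: 0)            # unseen keys start counting at 0
    for particle in particles:
        analysis_dict[particle.name]['is_diverged'] = particle.are_weights_diverged()
        counter_dict['diverged' if particle.are_weights_diverged() else 'ok'] += 1
    print(dict(counter_dict))                        # {'ok': 1, 'diverged': 1}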
@@ -1,7 +1,6 @@
import copy
from typing import Dict, List
import numpy as np
-from torch import Tensor
from network import Net


@@ -9,7 +8,7 @@ def is_divergent(network: Net) -> bool:
    for i in network.input_weight_matrix():
        weight_value = i[0].item()

-        if np.isnan(weight_value).all() or np.isinf(weight_value).all():
+        if np.isnan(weight_value).any() or np.isinf(weight_value).any():
            return True
    return False
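[Note] The switch from .all() to .any() changes the semantics: a network now counts as divergent as soon as a single weight is NaN or infinite, rather than only when every checked value is. A tiny illustration:

    import numpy as np

    weights = np.array([0.5, np.nan, 1.2])
    print(np.isnan(weights).all())   # False - not every entry is NaN
    print(np.isnan(weights).any())   # True  - one bad weight already signals divergence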
@@ -25,7 +24,7 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:


def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
-    target_data = network.create_target_weights(network.input_weight_matrix().detach().numpy())
+    target_data = network.create_target_weights(network.input_weight_matrix().detach())
    result = np.allclose(target_data, np.zeros_like(target_data), rtol=0, atol=epsilon)
    # result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
    return result
@@ -95,4 +94,4 @@ def test_status(net: Net) -> Net:
    else:
        net.is_fixpoint = "other_func"

    return net
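[Note] Dropping .numpy() in favour of a bare .detach() works because np.allclose converts its arguments itself; the detach is still required, since numpy refuses to convert a tensor that tracks gradients. Quick demonstration:

    import numpy as np
    import torch

    t = torch.zeros(3, requires_grad=True)
    # np.allclose(t, np.zeros(3))  # would raise: can't call numpy() on a grad-tracking tensor
    print(np.allclose(t.detach(), np.zeros(3), rtol=0, atol=1e-5))  # True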
@@ -6,7 +6,6 @@ import pickle
import pandas as pd
import numpy as np
import torch
-from sklearn import preprocessing

from functionalities_test import is_identity_function, test_status
from journal_basins import SpawnExperiment, mean_invariate_manhattan_distance
@@ -22,8 +21,11 @@ class SpawnLinspaceExperiment(SpawnExperiment):
        number_clones = number_clones or self.nr_clones

        df = pd.DataFrame(
-            columns=['clone', 'parent', 'parent2', 'MAE_pre', 'MAE_post', 'MSE_pre', 'MSE_post', 'MIM_pre', 'MIM_post', 'noise',
-                     'status_pst'])
+            columns=['clone', 'parent', 'parent2',
+                     'MAE_pre', 'MAE_post',
+                     'MSE_pre', 'MSE_post',
+                     'MIM_pre', 'MIM_post',
+                     'noise', 'status_pst'])

        # For every initial net {i} after populating (that is fixpoint after first epoch);
        # parent = self.parents[0]
@@ -38,15 +40,15 @@ class SpawnLinspaceExperiment(SpawnExperiment):
        # to see full trajectory (but the clones will be very hard to see).
        # Make one target to compare distances to clones later when they have trained.
        net1.start_time = self.ST_steps - 150
-        net1_input_data = net1.input_weight_matrix()
-        net1_target_data = net1.create_target_weights(net1_input_data)
+        net1_input_data = net1.input_weight_matrix().detach()
+        net1_target_data = net1.create_target_weights(net1_input_data).detach()

        net2.start_time = self.ST_steps - 150
-        net2_input_data = net2.input_weight_matrix()
-        net2_target_data = net2.create_target_weights(net2_input_data)
+        net2_input_data = net2.input_weight_matrix().detach()
+        net2_target_data = net2.create_target_weights(net2_input_data).detach()

        if is_identity_function(net1) and is_identity_function(net2):
        # if True:
            # Clone the fixpoint x times and add (+-)self.noise to weight-sets randomly;
            # To plot clones starting after first epoch (z=ST_steps), set that as start_time!
            # To make sure PCA will plot the same trajectory up until this point, we clone the
@@ -64,7 +66,7 @@ class SpawnLinspaceExperiment(SpawnExperiment):
                clone.number_trained = copy.deepcopy(net1.number_trained)

                # Pre Training distances (after noise application of course)
-                clone_pre_weights = clone.create_target_weights(clone.input_weight_matrix())
+                clone_pre_weights = clone.create_target_weights(clone.input_weight_matrix()).detach()
                MAE_pre = MAE(net1_target_data, clone_pre_weights)
                MSE_pre = MSE(net1_target_data, clone_pre_weights)
                MIM_pre = mean_invariate_manhattan_distance(net1_target_data, clone_pre_weights)
@@ -78,10 +80,15 @@ class SpawnLinspaceExperiment(SpawnExperiment):
                        raise ValueError
                except ValueError:
                    print("Ran into nan in 'in between weights' array.")
+                    df.loc[len(df)] = [j, net1.name, net2.name,
+                                       MAE_pre, 0,
+                                       MSE_pre, 0,
+                                       MIM_pre, 0,
+                                       self.noise, clone.is_fixpoint]
                    continue

                # Post Training distances for comparison
-                clone_post_weights = clone.create_target_weights(clone.input_weight_matrix())
+                clone_post_weights = clone.create_target_weights(clone.input_weight_matrix()).detach()
                MAE_post = MAE(net1_target_data, clone_post_weights)
                MSE_post = MSE(net1_target_data, clone_post_weights)
                MIM_post = mean_invariate_manhattan_distance(net1_target_data, clone_post_weights)
@@ -95,16 +102,23 @@ class SpawnLinspaceExperiment(SpawnExperiment):
                             f"\nMIM({net1.name},{j}): {MIM_post}\n")
                self.nets.append(clone)

-                df.loc[len(df)] = [j, net1.name, net2.name, MAE_pre, MAE_post, MSE_pre, MSE_post, MIM_pre, MIM_post,
-                                   self.noise, clone.is_fixpoint]
+                df.loc[len(df)] = [j, net1.name, net2.name,
+                                   MAE_pre, MAE_post,
+                                   MSE_pre, MSE_post,
+                                   MIM_pre, MIM_post,
+                                   self.noise, clone.is_fixpoint]

        for net1, net2 in pairwise_net_list:
-            value = 'MAE'
-            c_selector = [f'{value}_pre', f'{value}_post']
-            values = df.loc[(df['parent'] == net1.name) & (df['parent2'] == net2.name)][c_selector]
-            this_min, this_max = values.values.min(), values.values.max()
-            df.loc[(df['parent'] == net1.name) &
-                   (df['parent2'] == net2.name), c_selector] = (values - this_min) / (this_max - this_min)
+            try:
+                value = 'MAE'
+                c_selector = [f'{value}_pre', f'{value}_post']
+                values = df.loc[(df['parent'] == net1.name) & (df['parent2'] == net2.name)][c_selector]
+                this_min, this_max = values.values.min(), values.values.max()
+                df.loc[(df['parent'] == net1.name) &
+                       (df['parent2'] == net2.name), c_selector] = (values - this_min) / (this_max - this_min)
+            except ValueError:
+                pass

        for parent in self.parents:
            for _ in range(self.epochs - 1):
                for _ in range(self.ST_steps):
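[Note] The added try/except around the per-parent-pair min-max rescaling guards against pairs that contributed no rows: min()/max() on a zero-size selection raises ValueError. The same normalization on a toy frame:

    import pandas as pd

    df = pd.DataFrame({'parent': ['a', 'a', 'b'],
                       'MAE_pre': [2.0, 4.0, 1.0],
                       'MAE_post': [0.0, 8.0, 3.0]})
    cols = ['MAE_pre', 'MAE_post']
    for parent in ['a', 'b', 'c']:                   # 'c' selects no rows
        try:
            values = df.loc[df['parent'] == parent, cols]
            lo, hi = values.values.min(), values.values.max()
            df.loc[df['parent'] == parent, cols] = (values - lo) / (hi - lo)
        except ValueError:                           # zero-size array has no min/max
            pass
    print(df)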
@@ -148,7 +162,8 @@ if __name__ == '__main__':
    df = exp.df

    directory = Path('output') / 'spawn_basin' / f'{ST_name_hash}' / 'linage'
-    pickle.dump(exp, open(f"{directory}/experiment_pickle_{ST_name_hash}.p", "wb"))
+    with (directory / f"experiment_pickle_{ST_name_hash}.p").open('wb') as f:
+        pickle.dump(exp, f)
    print(f"\nSaved experiment to {directory}.")

    # Boxplot with counts of nr_fixpoints, nr_other, nr_etc. on y-axis
@@ -183,6 +198,6 @@ if __name__ == '__main__':
    # else:
    #     label.set_visible(False)

-    filepath = exp.directory / 'mim_dist_plot.png'
+    filepath = exp.directory / 'mim_dist_plot.pdf'
    plt.tight_layout()
-    plt.savefig(filepath)
+    plt.savefig(filepath, dpi=600, format='pdf', bbox_inches='tight')
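[Note] Replacing the bare open() inside pickle.dump with a Path.open context manager guarantees the file handle is closed even when pickling fails. A minimal sketch with a hypothetical path and payload:

    import pickle
    from pathlib import Path

    directory = Path('output') / 'demo'              # hypothetical run directory
    directory.mkdir(parents=True, exist_ok=True)
    with (directory / 'experiment.p').open('wb') as f:
        pickle.dump({'noise': 0.01}, f)              # toy payload instead of the experiment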
@@ -26,11 +26,14 @@ def l1(tup):


def mean_invariate_manhattan_distance(x, y):
    # One of these one-liners that might be smart or really dumb. Goal is to find pairwise
    # distances of ascending values, i.e. sum(abs(min1_X - min1_Y), abs(min2_X - min2_Y), ...) / mean.
    # Idea was to find weight sets that have the same values but just in different positions, which would
    # make this distance 0.
-    return np.mean(list(map(l1, zip(sorted(x.numpy()), sorted(y.numpy())))))
+    try:
+        return np.mean(list(map(l1, zip(sorted(x.detach().numpy()), sorted(y.detach().numpy())))))
+    except AttributeError:
+        return np.mean(list(map(l1, zip(sorted(x.numpy()), sorted(y.numpy())))))


def distance_matrix(nets, distance="MIM", print_it=True):
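[Note] mean_invariate_manhattan_distance sorts both weight vectors before taking element-wise absolute differences, so weight sets containing the same values in different positions have distance 0; the new try/except merely tolerates inputs that have no detach method. A toy check:

    import numpy as np
    import torch

    def l1(tup):
        a, b = tup
        return abs(a - b)

    def mim(x, y):                                   # same one-liner, standalone
        return np.mean(list(map(l1, zip(sorted(x.detach().numpy()), sorted(y.detach().numpy())))))

    a = torch.tensor([0.1, 0.5, 0.9])
    b = torch.tensor([0.9, 0.1, 0.5])                # same values, permuted
    print(mim(a, b))                                 # 0.0 - permutation is invisible here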
@@ -8,6 +8,7 @@ from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
+import torch
from matplotlib import pyplot as plt
from sklearn.metrics import mean_absolute_error as MAE
from sklearn.metrics import mean_squared_error as MSE
@@ -16,6 +17,7 @@ from tqdm import tqdm

from functionalities_test import is_identity_function, test_status, is_zero_fixpoint, is_divergent, \
    is_secondary_fixpoint
+from journal_basins import mean_invariate_manhattan_distance
from network import Net
from visualization import plot_loss, plot_3d_soup

@@ -25,14 +27,6 @@ def l1(tup):
    return abs(a - b)


-def mean_invariate_manhattan_distance(x, y):
-    # One of these one-liners that might be smart or really dumb. Goal is to find pairwise
-    # distances of ascending values, ie. sum (abs(min1_X-min1_Y), abs(min2_X-min2Y) ...) / mean.
-    # Idea was to find weight sets that have same values but just in different positions, that would
-    # make this distance 0.
-    return np.mean(list(map(l1, zip(sorted(x.numpy()), sorted(y.numpy())))))


def distance_matrix(nets, distance="MIM", print_it=True):
    matrix = [[0 for _ in range(len(nets))] for _ in range(len(nets))]
    for net in range(len(nets)):
network.py
@@ -34,17 +34,21 @@ class Net(nn.Module):
    def are_weights_diverged(network_weights):
        """ Testing if the weights are either converging to infinity or -infinity. """

-        for layer_id, layer in enumerate(network_weights):
-            for cell_id, cell in enumerate(layer):
-                for weight_id, weight in enumerate(cell):
-                    if torch.isnan(weight):
-                        return True
-                    if torch.isinf(weight):
-                        return True
-        return False
+        # Slow and shitty:
+        # for layer_id, layer in enumerate(network_weights):
+        #     for cell_id, cell in enumerate(layer):
+        #         for weight_id, weight in enumerate(cell):
+        #             if torch.isnan(weight):
+        #                 return True
+        #             if torch.isinf(weight):
+        #                 return True
+        # return False
+        # Fast and modern:
+        return any(x.isnan().any() or x.isinf().any() for x in network_weights.parameters())

    def apply_weights(self, new_weights: Tensor):
        """ Changing the weights of a network to new given values. """
+        # TODO: Change this to 'parameters' version
        i = 0
        for layer_id, layer_name in enumerate(self.state_dict()):
            for line_id, line_values in enumerate(self.state_dict()[layer_name]):
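[Note] The vectorized check walks the parameter tensors directly instead of three nested Python loops. Equivalent logic against a throwaway module:

    import torch
    from torch import nn

    def are_weights_diverged(module: nn.Module) -> bool:
        # True as soon as any parameter tensor holds a NaN or +/-Inf entry
        return any(p.isnan().any() or p.isinf().any() for p in module.parameters())

    net = nn.Linear(3, 2)
    print(are_weights_diverged(net))                 # False for fresh weights
    with torch.no_grad():
        net.weight[0, 0] = float('inf')
    print(are_weights_diverged(net))                 # True after poisoning one weight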
@@ -101,15 +105,17 @@ class Net(nn.Module):
                    # Cell Enumeration
                    torch.arange(layer.out_features, device=d).repeat_interleave(layer.in_features).view(-1, 1),
                    # Weight Enumeration within the Cells
-                    torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1)
+                    torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1),
+                    *(torch.full((x.numel(), 1), 0, device=d) for _ in range(self.input_size - 4))
                ), dim=1)
            )
        # Finalize
        weight_matrix = torch.cat(weight_matrix).float()

-        # Normalize all along the 1 dimensions
-        norm2 = weight_matrix[:, 1:].pow(2).sum(keepdim=True, dim=0).sqrt()
-        weight_matrix[:, 1:] = weight_matrix[:, 1:] / norm2
+        # Normalize columns 1, 2, 3 of dim 1
+        last_pos_idx = self.input_size - 4
+        norm2 = weight_matrix[:, 1:-last_pos_idx].pow(2).sum(keepdim=True, dim=0).sqrt()
+        weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / norm2) + 1e-8

        # computations
        # create a mask where pos is 0 if it is to be replaced
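[Note] With input_size - 4 zero padding columns appended, only the three enumeration columns (1 to 3) get L2-normalized; the slice 1:-last_pos_idx skips column 0 (the weight slot) and the padding. Column-wise on a toy matrix:

    import torch

    input_size = 5
    last_pos_idx = input_size - 4                    # number of trailing padding columns
    m = torch.tensor([[9., 1., 2., 3., 0.],
                      [9., 2., 4., 6., 0.]])
    norm2 = m[:, 1:-last_pos_idx].pow(2).sum(keepdim=True, dim=0).sqrt()
    m[:, 1:-last_pos_idx] = (m[:, 1:-last_pos_idx] / norm2) + 1e-8
    print(m)                                         # columns 0 and 4 are untouched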
@@ -117,7 +123,7 @@ class Net(nn.Module):
            mask[:, 0] = 0

            self._weight_pos_enc_and_mask = weight_matrix, mask
-        return self._weight_pos_enc_and_mask
+        return tuple(x.clone() for x in self._weight_pos_enc_and_mask)

    def forward(self, x):
        for layer in self.layers:
@@ -125,6 +131,7 @@ class Net(nn.Module):
        return x

    def normalize(self, value, norm):
+        raise NotImplementedError
        # FIXME, This is bullshit, the code does not do what the docstring explains
        # Obsolete now
        """ Normalizing the values >= 1 and adding pow(10, -8) to the values equal to 0 """
@@ -138,7 +145,7 @@ class Net(nn.Module):
        """ Calculating the input tensor formed from the weights of the net """
        weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
        pos_enc, mask = self._weight_pos_enc
-        weight_matrix = pos_enc * mask + weight_matrix.expand(-1, 4) * (1 - mask)
+        weight_matrix = pos_enc * mask + weight_matrix.expand(-1, pos_enc.shape[-1]) * (1 - mask)
        return weight_matrix

    def self_train(self,
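[Note] input_weight_matrix now broadcasts each flat weight across the full encoding width instead of the hard-coded 4, splicing it in wherever the mask is zero. The splice in isolation:

    import torch

    pos_enc = torch.tensor([[1., 2., 3., 0., 0.]])   # toy encoding row, width 5
    mask = torch.ones_like(pos_enc)
    mask[:, 0] = 0                                   # column 0 receives the weight value
    weight = torch.tensor([[0.7]])
    row = pos_enc * mask + weight.expand(-1, pos_enc.shape[-1]) * (1 - mask)
    print(row)                                       # tensor([[0.7, 2., 3., 0., 0.]])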
@@ -283,33 +290,50 @@ class SecondaryNet(Net):


class MetaCell(nn.Module):
-    def __init__(self, name, interface, residual_skip=True):
+    def __init__(self, name, interface):
        super().__init__()
-        self.residual_skip = residual_skip
        self.name = name
        self.interface = interface
-        self.weight_interface = 4
+        self.weight_interface = 5
        self.net_hidden_size = 4
        self.net_ouput_size = 1
        self.meta_weight_list = nn.ModuleList()
        self.meta_weight_list.extend(
            [Net(self.weight_interface, self.net_hidden_size,
-                 self.net_ouput_size, name=f'{self.name}_{weight_idx}'
+                 self.net_ouput_size, name=f'{self.name}_W{weight_idx}'
                 ) for weight_idx in range(self.interface)]
        )
+        self.__bed_mask = None
+
+    @property
+    def _bed_mask(self):
+        if self.__bed_mask is None:
+            d = next(self.parameters()).device
+            embedding = torch.zeros(1, self.weight_interface, device=d)
+
+            # computations
+            # create a mask where pos is 0 if it is to be replaced
+            mask = torch.ones_like(embedding)
+            mask[:, -1] = 0
+
+            self.__bed_mask = embedding, mask
+        return tuple(x.clone() for x in self.__bed_mask)

    def forward(self, x):
-        xs = [torch.hstack(
-            (torch.zeros((x.shape[0], self.weight_interface - 1), device=x.device), x[:, idx].unsqueeze(-1))
-        )
-            for idx in range(len(self.meta_weight_list))]
-        tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])
-
-        if self.residual_skip:
-            tensor += x
-
-        result = torch.sum(tensor, dim=-1, keepdim=True)
-        return result
+        embedding, mask = self._bed_mask
+        expanded_mask = mask.expand(*x.shape, embedding.shape[-1])
+        embedding = embedding.repeat(*x.shape, 1)
+
+        # Row-wise
+        # xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1]).swapdims(0, 1)
+        # Column-wise
+        xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1])
+        xs = embedding * expanded_mask + xs * (1 - expanded_mask)
+        # ToDo Speed this up!
+        tensor = torch.hstack([meta_weight(xs[:, idx, :]) for idx, meta_weight in enumerate(self.meta_weight_list)])
+
+        tensor = torch.sum(tensor, dim=-1, keepdim=True)
+        return tensor

    @property
    def particles(self):
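[Note] MetaCell.forward now embeds every scalar input into a weight_interface-wide vector by writing it into the masked last slot of a zero embedding, yielding one vector per incoming connection for the matching particle net. The tensor plumbing with plain numbers (x assumed to be batch x interface):

    import torch

    batch, interface, width = 2, 3, 5                # width == weight_interface
    embedding = torch.zeros(1, width)
    mask = torch.ones_like(embedding)
    mask[:, -1] = 0                                  # last slot takes the input value

    x = torch.arange(batch * interface, dtype=torch.float).view(batch, interface)
    expanded_mask = mask.expand(batch, interface, width)
    emb = embedding.repeat(batch, interface, 1)
    xs = x.unsqueeze(-1).expand(-1, -1, width)
    xs = emb * expanded_mask + xs * (1 - expanded_mask)
    print(xs[0, 1])                                  # tensor([0., 0., 0., 0., 1.])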
@@ -317,21 +341,27 @@ class MetaCell(nn.Module):


class MetaLayer(nn.Module):
-    def __init__(self, name, interface=4, width=4):
+    def __init__(self, name, interface=4, width=4, residual_skip=True):
        super().__init__()
+        self.residual_skip = residual_skip
        self.name = name
        self.interface = interface
        self.width = width

        self.meta_cell_list = nn.ModuleList()
-        self.meta_cell_list.extend([MetaCell(name=f'{self.name}_{cell_idx}',
+        self.meta_cell_list.extend([MetaCell(name=f'{self.name}_C{cell_idx}',
                                             interface=interface
                                             ) for cell_idx in range(self.width)]
                                   )

    def forward(self, x):
-        result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
-        return result
+        cell_results = []
+        for metacell in self.meta_cell_list:
+            cell_results.append(metacell(x))
+        tensor = torch.hstack(cell_results)
+        if self.residual_skip and x.shape == tensor.shape:
+            tensor += x
+        return tensor

    @property
    def particles(self):
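[Note] The residual connection moved from MetaCell up to MetaLayer and gained a shape guard, so skips only apply where input and output widths match (hidden layers), never on the widening first or narrowing last layer. The guard in isolation:

    import torch

    def layer_forward(x, cell_outputs, residual_skip=True):
        # cell_outputs: one (batch, 1) tensor per MetaCell
        tensor = torch.hstack(cell_outputs)
        if residual_skip and x.shape == tensor.shape:   # only when width is preserved
            tensor += x
        return tensor

    x = torch.ones(4, 3)
    print(layer_forward(x, [torch.zeros(4, 1)] * 3)[0])  # tensor([1., 1., 1.]) - skip on
    print(layer_forward(x, [torch.zeros(4, 1)]).shape)   # torch.Size([4, 1]) - skip off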
@@ -349,15 +379,15 @@ class MetaNet(nn.Module):
        self.depth = depth

        self._meta_layer_list = nn.ModuleList()
-        self._meta_layer_list.append(MetaLayer(name=f'Weight_{0}',
+        self._meta_layer_list.append(MetaLayer(name=f'L{0}',
                                               interface=self.interface,
                                               width=self.width)
                                     )
-        self._meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
+        self._meta_layer_list.extend([MetaLayer(name=f'L{layer_idx + 1}',
                                                interface=self.width, width=self.width
                                                ) for layer_idx in range(self.depth - 2)]
                                    )
-        self._meta_layer_list.append(MetaLayer(name=f'Weight_{len(self._meta_layer_list)}',
+        self._meta_layer_list.append(MetaLayer(name=f'L{len(self._meta_layer_list)}',
                                               interface=self.width, width=self.out)
                                     )

@@ -383,9 +413,9 @@ class MetaNet(nn.Module):


if __name__ == '__main__':
-    metanet = MetaNet(interface=2, depth=3, width=2, out=1)
+    metanet = MetaNet(interface=3, depth=5, width=3, out=1)
    next(metanet.particles).input_weight_matrix()
-    metanet(torch.ones((5, 2)))
+    metanet(torch.hstack([torch.full((2, 1), float(x)) for x in range(metanet.interface)]))
    a = metanet.particles
    print('Test')
    print('Test')
@@ -2,10 +2,13 @@ torch~=1.8.1+cpu
tqdm~=4.60.0
numpy~=1.20.3
matplotlib~=3.4.2
-sklearn
+sklearn~=0.0
scipy
tabulate~=0.8.9

scikit-learn~=0.24.2
pandas~=1.2.4
seaborn~=0.11.1
+future~=0.18.2
+torchmetrics~=0.7.0
+torchvision~=0.9.1+cpu