MetaNetworks Debugged
commit 246d825bb4 (parent 49c0d8a621)
@@ -1,5 +1,4 @@
-""" ----------------------------------------- Methods for summarizing the experiments ------------------------------------------ """
-import os
+""" -------------------------------- Methods for summarizing the experiments --------------------------------- """
 from pathlib import Path

 from visualization import line_chart_fixpoints, bar_chart_fixpoints
@@ -52,7 +51,9 @@ def summary_fixpoint_percentage(runs, epochs, fixpoints_percentages, ST_steps, S
     line_chart_fixpoints(fixpoints_percentages, epochs, ST_steps, SA_steps, directory_name, population_size)


-""" --------------------------------------------------- Miscellaneous ---------------------------------------------------------- """
+""" -------------------------------------------- Miscellaneous --------------------------------------------------- """


 def check_folder(experiment_folder: str):
+    exp_path = Path('experiments') / experiment_folder
+    exp_path.mkdir(parents=True, exist_ok=True)
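The new helper leans entirely on pathlib: mkdir with parents and exist_ok is idempotent, so callers no longer need an existence check. A minimal sketch of that behaviour (the folder name 'demo_run' is illustrative only, not taken from the experiments):

    from pathlib import Path

    exp_path = Path('experiments') / 'demo_run'   # hypothetical experiment folder
    exp_path.mkdir(parents=True, exist_ok=True)   # creates 'experiments/' too, if missing
    exp_path.mkdir(parents=True, exist_ok=True)   # second call is a no-op, no exception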
@@ -1,5 +1,5 @@
 import pickle
-import time
+from collections import defaultdict
 from pathlib import Path
 import sys
 import platform
@@ -7,7 +7,14 @@ import platform

 import pandas as pd
+import torchmetrics

 if platform.node() != 'CarbonX':
     from functionalities_test import test_for_fixpoints

+if platform.node() == 'CarbonX':
+    debug = True
+    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
+    print("@ Warning, Debugging Config@!!!!!! @")
+    print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@")
+else:
+    debug = False
 try:
     # noinspection PyUnboundLocalVariable
@@ -20,8 +27,7 @@ if platform.node() != 'CarbonX':
 except NameError:
-    DIR = None
-else:
-    debug = True
+    pass


 import numpy as np
 import torch
@@ -31,7 +37,7 @@ from torch import nn
 from torch.nn import Flatten
 from torch.utils.data import Dataset, DataLoader
 from torchvision.datasets import MNIST
-from torchvision.transforms import ToTensor, Compose
+from torchvision.transforms import ToTensor, Compose, Resize
 from tqdm import tqdm

 from network import MetaNet
@@ -69,7 +75,7 @@ class AddTaskDataset(Dataset):

 def set_checkpoint(model, out_path, epoch_n, final_model=False):
     epoch_n = str(epoch_n)
-    if final_model:
+    if not final_model:
         ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
     else:
         ckpt_path = Path(out_path) / f'trained_model_ckpt.tp'
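This one-word fix inverts a checkpoint-routing bug: intermediate epochs were being written to the final-model path and vice versa. A minimal sketch of the corrected routing; the save tail is an assumption, since the rest of the function lies outside this hunk:

    from pathlib import Path
    import pickle
    import torch

    def set_checkpoint(model, out_path, epoch_n, final_model=False):
        epoch_n = str(epoch_n)
        if not final_model:
            # numbered checkpoints for intermediate epochs, e.g. ckpt/0007_model_ckpt.tp
            ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
        else:
            # one fixed name for the finished model
            ckpt_path = Path(out_path) / 'trained_model_ckpt.tp'
        ckpt_path.parent.mkdir(exist_ok=True, parents=True)                     # assumed tail, not shown in the diff
        torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)   # assumed tail, not shown in the diff
        return ckpt_path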
@@ -84,32 +90,32 @@ def validate(checkpoint_path, ratio=0.1):
     import torchmetrics

     # initialize metric
-    metric = torchmetrics.Accuracy()
+    validmetric = torchmetrics.Accuracy()

     try:
-        dataset = MNIST(str(data_path), transform=utility_transforms, train=False)
+        datas = MNIST(str(data_path), transform=utility_transforms, train=False)
     except RuntimeError:
-        dataset = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
-    d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
+        datas = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
+    valid_d = DataLoader(datas, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)

     model = torch.load(checkpoint_path, map_location=DEVICE).eval()
     n_samples = int(len(d) * ratio)

     with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
-        for idx, (batch_x, batch_y) in enumerate(d):
-            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
-            y = model(batch_x)
+        for idx, (valid_batch_x, valid_batch_y) in enumerate(valid_d):
+            valid_batch_x, valid_batch_y = valid_batch_x.to(DEVICE), valid_batch_y.to(DEVICE)
+            y_valid = model(valid_batch_x)

             # metric on current batch
-            acc = metric(y.cpu(), batch_y.cpu())
+            acc = validmetric(y_valid.cpu(), valid_batch_y.cpu())
             pbar.set_postfix_str(f'Acc: {acc}')
             pbar.update()
             if idx == n_samples:
                 break

     # metric on all batches using custom accumulation
-    acc = metric.compute()
-    print(f"Accuracy on all data: {acc}")
+    acc = validmetric.compute()
+    tqdm.write(f"Avg. accuracy on all data: {acc}")
     return acc
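The rename to validmetric also keeps the training-loop metric and the validation metric from shadowing each other, which preserves the torchmetrics update/compute contract: calling the metric per batch both returns the batch score and accumulates state, and compute() yields the aggregate over everything seen. A self-contained sketch of that pattern, with random tensors standing in for real batches:

    import torch
    import torchmetrics

    metric = torchmetrics.Accuracy()
    for _ in range(3):
        preds = torch.randn(8, 10).softmax(dim=-1)   # fake class probabilities
        target = torch.randint(0, 10, (8,))          # fake labels
        batch_acc = metric(preds, target)            # updates state, returns batch accuracy
    print(metric.compute())                          # accuracy over all seen batches
    metric.reset()                                   # clear state before the next validation run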
@@ -125,41 +131,39 @@ def plot_training_result(path_to_dataframe):
     df = pd.read_csv(path_to_dataframe, index_col=0)

     fig, ax1 = plt.subplots()  # initializes figure and plots
-    ax2 = ax1.twinx()  # applies twinx to ax2, which is the second y axis.
+    ax2 = ax1.twinx()  # applies twinx to ax2, which is the second y-axis.

     # plots the first set of data, and sets it to ax1.
     data = df[df['Metric'] == 'BatchLoss']
-    # plots the second set, and sets to ax2.
-    sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax2)
-    data = df[df['Metric'] == 'Test Accuracy']
-    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='red')
-    data = df[df['Metric'] == 'Train Accuracy']
-    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='green')
+    sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax1, color='blue')
+    data = df[(df['Metric'] == 'Test Accuracy') | (df['Metric'] == 'Train Accuracy')]
+    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', hue='Metric', legend=True)

-    ax2.set(yscale='log')
+    ax1.set(yscale='log')
     ax1.set_title('Training Lineplot')
     plt.tight_layout()
     if debug:
         plt.show()
     else:
-        plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'))
+        plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'), dpi=300)


 if __name__ == '__main__':

-    self_train = True
-    soup_interaction = True
-    training = True
-    plotting = True
+    self_train = False
+    training = False
+    plotting = False
+    particle_analysis = True

     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)

-    run_path = Path('output') / 'intergrated_self_train'
+    run_path = Path('output') / 'mnist_test_half_size'
     model_path = run_path / '0000_trained_model.zip'

     if training:
-        utility_transforms = Compose([ToTensor(), ToFloat(), Flatten(start_dim=0)])
+        utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])

         try:
             dataset = MNIST(str(data_path), transform=utility_transforms)
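Folding the two accuracy curves into one lineplot with hue='Metric' gives a shared, automatically labelled legend instead of two hand-colored lines. A tiny sketch of that idiom on synthetic data:

    import pandas as pd
    import seaborn as sns
    from matplotlib import pyplot as plt

    df = pd.DataFrame([
        dict(Epoch=e, Metric=m, Score=0.5 + 0.04 * e + (0.05 if m == 'Train Accuracy' else 0))
        for e in range(10) for m in ('Train Accuracy', 'Test Accuracy')
    ])
    sns.lineplot(data=df, x='Epoch', y='Score', hue='Metric', marker='o')  # one call, legend per metric
    plt.show()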
@@ -179,6 +183,8 @@ if __name__ == '__main__':
             is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
             if is_validation_epoch:
                 metric = torchmetrics.Accuracy()
+            else:
+                metric = None
             for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
                 if self_train and is_self_train_epoch:
                     # Zero your gradients for every batch!
@@ -221,10 +227,19 @@ if __name__ == '__main__':
             accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
             validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
                                   Metric='Test Accuracy', Score=accuracy.item())
-            train_store.loc[train_store.shape[0]] = validation_log

         torch.save(metanet, model_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
+        train_store.loc[train_store.shape[0]] = validation_log
         train_store.to_csv(run_path / 'train_store.csv')

     if plotting:
         plot_training_result(run_path / 'train_store.csv')

+    if particle_analysis:
+        model_path = next(run_path.glob('*.tp'))
+        latest_model = torch.load(model_path, map_location=DEVICE).eval()
+        analysis_dict = defaultdict(dict)
+        counter_dict = defaultdict(lambda: 0)
+        for particle in latest_model.particles:
+            analysis_dict[particle.name]['is_diverged'] = particle.are_weights_diverged()
+        test_for_fixpoints(counter_dict, latest_model.particles)
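The new analysis pass leans on defaultdict so neither the per-particle dicts nor the fixpoint counters need explicit initialization. A minimal sketch of the access pattern, with hypothetical particle names and counter labels:

    from collections import defaultdict

    analysis_dict = defaultdict(dict)
    counter_dict = defaultdict(lambda: 0)

    for name, diverged in [('L0_C0_W0', False), ('L0_C0_W1', True)]:   # hypothetical particles
        analysis_dict[name]['is_diverged'] = diverged                  # inner dict created on first access
        counter_dict['divergent' if diverged else 'other'] += 1        # counter starts at 0 automatically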
@@ -1,7 +1,6 @@
 import copy
-from typing import Dict, List
 import numpy as np
 from torch import Tensor
 from network import Net

@@ -9,7 +8,7 @@ def is_divergent(network: Net) -> bool:
     for i in network.input_weight_matrix():
         weight_value = i[0].item()

-        if np.isnan(weight_value).all() or np.isinf(weight_value).all():
+        if np.isnan(weight_value).any() or np.isinf(weight_value).any():
             return True
     return False
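On the scalar this line actually inspects, .all() and .any() agree; on any array input, though, .all() only fires when every entry is bad, so a single NaN slipped through. The .any() variant flags the first bad value:

    import numpy as np

    values = np.array([0.5, np.nan, 1.2])
    print(np.isnan(values).all())  # False -> the old check misses the NaN
    print(np.isnan(values).any())  # True  -> the new check flags it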
@@ -25,7 +24,7 @@ def is_identity_function(network: Net, epsilon=pow(10, -5)) -> bool:


 def is_zero_fixpoint(network: Net, epsilon=pow(10, -5)) -> bool:
-    target_data = network.create_target_weights(network.input_weight_matrix().detach().numpy())
+    target_data = network.create_target_weights(network.input_weight_matrix().detach())
     result = np.allclose(target_data, np.zeros_like(target_data), rtol=0, atol=epsilon)
     # result = bool(len(np.nonzero(network.create_target_weights(network.input_weight_matrix()))))
     return result
@@ -95,4 +94,4 @@ def test_status(net: Net) -> Net:
     else:
         net.is_fixpoint = "other_func"

-    return net
\ No newline at end of file
+    return net
@@ -6,7 +6,6 @@ import pickle
 import pandas as pd
 import numpy as np
 import torch
-from sklearn import preprocessing

 from functionalities_test import is_identity_function, test_status
 from journal_basins import SpawnExperiment, mean_invariate_manhattan_distance
@@ -22,8 +21,11 @@ class SpawnLinspaceExperiment(SpawnExperiment):
         number_clones = number_clones or self.nr_clones

         df = pd.DataFrame(
-            columns=['clone', 'parent', 'parent2', 'MAE_pre', 'MAE_post', 'MSE_pre', 'MSE_post', 'MIM_pre', 'MIM_post', 'noise',
-                     'status_pst'])
+            columns=['clone', 'parent', 'parent2',
+                     'MAE_pre', 'MAE_post',
+                     'MSE_pre', 'MSE_post',
+                     'MIM_pre', 'MIM_post',
+                     'noise', 'status_pst'])

         # For every initial net {i} after populating (that is fixpoint after first epoch);
         # parent = self.parents[0]
@@ -38,15 +40,15 @@ class SpawnLinspaceExperiment(SpawnExperiment):
             # to see full trajectory (but the clones will be very hard to see).
             # Make one target to compare distances to clones later when they have trained.
             net1.start_time = self.ST_steps - 150
-            net1_input_data = net1.input_weight_matrix()
-            net1_target_data = net1.create_target_weights(net1_input_data)
+            net1_input_data = net1.input_weight_matrix().detach()
+            net1_target_data = net1.create_target_weights(net1_input_data).detach()

             net2.start_time = self.ST_steps - 150
-            net2_input_data = net2.input_weight_matrix()
-            net2_target_data = net2.create_target_weights(net2_input_data)
+            net2_input_data = net2.input_weight_matrix().detach()
+            net2_target_data = net2.create_target_weights(net2_input_data).detach()

             if is_identity_function(net1) and is_identity_function(net2):
-            # if True:
+                # if True:
                 # Clone the fixpoint x times and add (+-)self.noise to weight-sets randomly;
                 # To plot clones starting after first epoch (z=ST_steps), set that as start_time!
                 # To make sure PCA will plot the same trajectory up until this point, we clone the
@@ -64,7 +66,7 @@ class SpawnLinspaceExperiment(SpawnExperiment):
                 clone.number_trained = copy.deepcopy(net1.number_trained)

                 # Pre-training distances (after noise application, of course)
-                clone_pre_weights = clone.create_target_weights(clone.input_weight_matrix())
+                clone_pre_weights = clone.create_target_weights(clone.input_weight_matrix()).detach()
                 MAE_pre = MAE(net1_target_data, clone_pre_weights)
                 MSE_pre = MSE(net1_target_data, clone_pre_weights)
                 MIM_pre = mean_invariate_manhattan_distance(net1_target_data, clone_pre_weights)
@@ -78,10 +80,15 @@ class SpawnLinspaceExperiment(SpawnExperiment):
                         raise ValueError
                     except ValueError:
                         print("Ran into NaN in 'in-between weights' array.")
+                        df.loc[len(df)] = [j, net1.name, net2.name,
+                                           MAE_pre, 0,
+                                           MSE_pre, 0,
+                                           MIM_pre, 0,
+                                           self.noise, clone.is_fixpoint]
                         continue

                 # Post-training distances for comparison
-                clone_post_weights = clone.create_target_weights(clone.input_weight_matrix())
+                clone_post_weights = clone.create_target_weights(clone.input_weight_matrix()).detach()
                 MAE_post = MAE(net1_target_data, clone_post_weights)
                 MSE_post = MSE(net1_target_data, clone_post_weights)
                 MIM_post = mean_invariate_manhattan_distance(net1_target_data, clone_post_weights)
@@ -95,16 +102,23 @@ class SpawnLinspaceExperiment(SpawnExperiment):
                                   f"\nMIM({net1.name},{j}): {MIM_post}\n")
                 self.nets.append(clone)

-            df.loc[len(df)] = [j, net1.name, net2.name, MAE_pre, MAE_post, MSE_pre, MSE_post, MIM_pre, MIM_post,
-                               self.noise, clone.is_fixpoint]
+            df.loc[len(df)] = [j, net1.name, net2.name,
+                               MAE_pre, MAE_post,
+                               MSE_pre, MSE_post,
+                               MIM_pre, MIM_post,
+                               self.noise, clone.is_fixpoint]

         for net1, net2 in pairwise_net_list:
-            value = 'MAE'
-            c_selector = [f'{value}_pre', f'{value}_post']
-            values = df.loc[(df['parent'] == net1.name) & (df['parent2'] == net2.name)][c_selector]
-            this_min, this_max = values.values.min(), values.values.max()
-            df.loc[(df['parent'] == net1.name) &
-                   (df['parent2'] == net2.name), c_selector] = (values - this_min) / (this_max - this_min)
+            try:
+                value = 'MAE'
+                c_selector = [f'{value}_pre', f'{value}_post']
+                values = df.loc[(df['parent'] == net1.name) & (df['parent2'] == net2.name)][c_selector]
+                this_min, this_max = values.values.min(), values.values.max()
+                df.loc[(df['parent'] == net1.name) &
+                       (df['parent2'] == net2.name), c_selector] = (values - this_min) / (this_max - this_min)
+            except ValueError:
+                pass

         for parent in self.parents:
             for _ in range(self.epochs - 1):
                 for _ in range(self.ST_steps):
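Wrapping the per-pair min-max normalization in try/except covers parent pairs that contributed no rows: .min()/.max() on a zero-size selection raise ValueError. The normalization itself, isolated on toy data:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'MAE_pre': [0.2, 0.4, 0.8], 'MAE_post': [0.1, 0.3, 0.5]})
    values = df[['MAE_pre', 'MAE_post']]
    this_min, this_max = values.values.min(), values.values.max()
    df[['MAE_pre', 'MAE_post']] = (values - this_min) / (this_max - this_min)  # rescale both columns to [0, 1]

    try:
        np.array([]).min()    # zero-size selection, as for a pair with no clones
    except ValueError:
        pass                  # exactly the case the diff now guards against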
@@ -148,7 +162,8 @@ if __name__ == '__main__':
     df = exp.df

     directory = Path('output') / 'spawn_basin' / f'{ST_name_hash}' / 'linage'
-    pickle.dump(exp, open(f"{directory}/experiment_pickle_{ST_name_hash}.p", "wb"))
+    with (directory / f"experiment_pickle_{ST_name_hash}.p").open('wb') as f:
+        pickle.dump(exp, f)
     print(f"\nSaved experiment to {directory}.")

     # Boxplot with counts of nr_fixpoints, nr_other, nr_etc. on y-axis
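The rewritten dump closes the file handle deterministically; the old one-liner left an open handle around until garbage collection. A small sketch of the same pattern (the path and payload are illustrative only):

    import pickle
    from pathlib import Path

    directory = Path('output') / 'demo'          # hypothetical output directory
    directory.mkdir(parents=True, exist_ok=True)
    payload = {'st_steps': 1000}                 # stand-in for the experiment object

    with (directory / 'experiment_pickle_demo.p').open('wb') as f:
        pickle.dump(payload, f)                  # handle is closed when the block exits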
@@ -183,6 +198,6 @@ if __name__ == '__main__':
     # else:
     #     label.set_visible(False)

-    filepath = exp.directory / 'mim_dist_plot.png'
+    filepath = exp.directory / 'mim_dist_plot.pdf'
     plt.tight_layout()
-    plt.savefig(filepath)
+    plt.savefig(filepath, dpi=600, format='pdf', bbox_inches='tight')
@@ -26,11 +26,14 @@ def l1(tup):


 def mean_invariate_manhattan_distance(x, y):
     # One of these one-liners that might be smart or really dumb. Goal is to find pairwise
     # distances of ascending values, ie. sum (abs(min1_X-min1_Y), abs(min2_X-min2Y) ...) / mean.
     # Idea was to find weight sets that have same values but just in different positions, that would
     # make this distance 0.
-    return np.mean(list(map(l1, zip(sorted(x.numpy()), sorted(y.numpy())))))
+    try:
+        return np.mean(list(map(l1, zip(sorted(x.detach().numpy()), sorted(y.detach().numpy())))))
+    except AttributeError:
+        return np.mean(list(map(l1, zip(sorted(x.numpy()), sorted(y.numpy())))))


 def distance_matrix(nets, distance="MIM", print_it=True):
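The detaching branch matters because .numpy() refuses to run on a tensor that still carries a grad graph; the try/except falls back to the plain conversion for inputs without the extra bookkeeping. A small illustration of why detach is needed:

    import torch

    with_grad = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
    plain = torch.tensor([3.0, 1.0, 2.0])

    # with_grad.numpy() raises here; detaching first makes the conversion legal
    print(sorted(with_grad.detach().numpy()))  # [1.0, 2.0, 3.0]
    print(sorted(plain.numpy()))               # fine without detach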
@@ -8,6 +8,7 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
+import seaborn as sns
 import torch
 from matplotlib import pyplot as plt
 from sklearn.metrics import mean_absolute_error as MAE
 from sklearn.metrics import mean_squared_error as MSE
@@ -16,6 +17,7 @@ from tqdm import tqdm

 from functionalities_test import is_identity_function, test_status, is_zero_fixpoint, is_divergent, \
     is_secondary_fixpoint
+from journal_basins import mean_invariate_manhattan_distance
 from network import Net
 from visualization import plot_loss, plot_3d_soup
@@ -25,14 +27,6 @@ def l1(tup):
     return abs(a - b)


-def mean_invariate_manhattan_distance(x, y):
-    # One of these one-liners that might be smart or really dumb. Goal is to find pairwise
-    # distances of ascending values, ie. sum (abs(min1_X-min1_Y), abs(min2_X-min2Y) ...) / mean.
-    # Idea was to find weight sets that have same values but just in different positions, that would
-    # make this distance 0.
-    return np.mean(list(map(l1, zip(sorted(x.numpy()), sorted(y.numpy())))))
-
-
 def distance_matrix(nets, distance="MIM", print_it=True):
     matrix = [[0 for _ in range(len(nets))] for _ in range(len(nets))]
     for net in range(len(nets)):
network.py (102 changed lines)
@@ -34,17 +34,21 @@ class Net(nn.Module):
     def are_weights_diverged(network_weights):
         """ Testing if the weights are either converging to infinity or -infinity. """

-        for layer_id, layer in enumerate(network_weights):
-            for cell_id, cell in enumerate(layer):
-                for weight_id, weight in enumerate(cell):
-                    if torch.isnan(weight):
-                        return True
-                    if torch.isinf(weight):
-                        return True
-        return False
+        # Slow and shitty:
+        # for layer_id, layer in enumerate(network_weights):
+        #     for cell_id, cell in enumerate(layer):
+        #         for weight_id, weight in enumerate(cell):
+        #             if torch.isnan(weight):
+        #                 return True
+        #             if torch.isinf(weight):
+        #                 return True
+        # return False
+        # Fast and modern:
+        return any(x.isnan().any() or x.isinf().any() for x in network_weights.parameters())

     def apply_weights(self, new_weights: Tensor):
         """ Changing the weights of a network to new given values. """
+        # TODO: Change this to 'parameters' version
         i = 0
         for layer_id, layer_name in enumerate(self.state_dict()):
             for line_id, line_values in enumerate(self.state_dict()[layer_name]):
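The vectorized rewrite walks whole parameter tensors with isnan/isinf reductions instead of a triple Python loop, and the generator short-circuits at the first offending tensor. A stand-alone version of the same test:

    import torch
    from torch import nn

    def weights_diverged(module: nn.Module) -> bool:
        # True as soon as any parameter tensor contains a NaN or +/-inf entry
        return any(p.isnan().any() or p.isinf().any() for p in module.parameters())

    net = nn.Linear(3, 3)
    print(weights_diverged(net))        # False for freshly initialized weights
    with torch.no_grad():
        net.weight[0, 0] = float('nan')
    print(weights_diverged(net))        # True once a single weight goes bad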
@@ -101,15 +105,17 @@ class Net(nn.Module):
                 # Cell Enumeration
                 torch.arange(layer.out_features, device=d).repeat_interleave(layer.in_features).view(-1, 1),
                 # Weight Enumeration within the Cells
-                torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1)
-                ), dim=1)
+                torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1),
+                *(torch.full((x.numel(), 1), 0, device=d) for _ in range(self.input_size-4))
+                )
             # Finalize
             weight_matrix = torch.cat(weight_matrix).float()

-            # Normalize all along the 1 dimensions
-            norm2 = weight_matrix[:, 1:].pow(2).sum(keepdim=True, dim=0).sqrt()
-            weight_matrix[:, 1:] = weight_matrix[:, 1:] / norm2
+            # Normalize 1,2,3 column of dim 1
+            last_pos_idx = self.input_size - 4
+            norm2 = weight_matrix[:, 1:-last_pos_idx].pow(2).sum(keepdim=True, dim=0).sqrt()
+            weight_matrix[:, 1:-last_pos_idx] = (weight_matrix[:, 1:-last_pos_idx] / norm2) + 1e-8

             # computations
             # create a mask where pos is 0 if it is to be replaced
@@ -117,7 +123,7 @@ class Net(nn.Module):
             mask[:, 0] = 0

         self._weight_pos_enc_and_mask = weight_matrix, mask
-        return self._weight_pos_enc_and_mask
+        return tuple(x.clone() for x in self._weight_pos_enc_and_mask)

     def forward(self, x):
         for layer in self.layers:
@@ -125,6 +131,7 @@ class Net(nn.Module):
         return x

     def normalize(self, value, norm):
+        raise NotImplementedError
-        # FIXME, This is bullshit, the code does not do what the docstring explains
+        # Obsolete now
         """ Normalizing the values >= 1 and adding pow(10, -8) to the values equal to 0 """
@@ -138,7 +145,7 @@ class Net(nn.Module):
         """ Calculating the input tensor formed from the weights of the net """
         weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
         pos_enc, mask = self._weight_pos_enc
-        weight_matrix = pos_enc * mask + weight_matrix.expand(-1, 4) * (1 - mask)
+        weight_matrix = pos_enc * mask + weight_matrix.expand(-1, pos_enc.shape[-1]) * (1 - mask)
         return weight_matrix

     def self_train(self,
@@ -283,33 +290,50 @@ class SecondaryNet(Net):


 class MetaCell(nn.Module):
-    def __init__(self, name, interface, residual_skip=True):
+    def __init__(self, name, interface):
         super().__init__()
-        self.residual_skip = residual_skip
         self.name = name
         self.interface = interface
-        self.weight_interface = 4
+        self.weight_interface = 5
         self.net_hidden_size = 4
         self.net_ouput_size = 1
         self.meta_weight_list = nn.ModuleList()
         self.meta_weight_list.extend(
             [Net(self.weight_interface, self.net_hidden_size,
-                 self.net_ouput_size, name=f'{self.name}_{weight_idx}'
+                 self.net_ouput_size, name=f'{self.name}_W{weight_idx}'
                  ) for weight_idx in range(self.interface)]
         )
+        self.__bed_mask = None
+
+    @property
+    def _bed_mask(self):
+        if self.__bed_mask is None:
+            d = next(self.parameters()).device
+            embedding = torch.zeros(1, self.weight_interface, device=d)
+
+            # computations
+            # create a mask where pos is 0 if it is to be replaced
+            mask = torch.ones_like(embedding)
+            mask[:, -1] = 0
+
+            self.__bed_mask = embedding, mask
+        return tuple(x.clone() for x in self.__bed_mask)

     def forward(self, x):
-        xs = [torch.hstack(
-            (torch.zeros((x.shape[0], self.weight_interface - 1), device=x.device), x[:, idx].unsqueeze(-1))
-        )
-            for idx in range(len(self.meta_weight_list))]
-        tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])
+        embedding, mask = self._bed_mask
+        expanded_mask = mask.expand(*x.shape, embedding.shape[-1])
+        embedding = embedding.repeat(*x.shape, 1)

-        if self.residual_skip:
-            tensor += x
+        # Row-wise
+        # xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1]).swapdims(0, 1)
+        # Column-wise
+        xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1])
+        xs = embedding * expanded_mask + xs * (1 - expanded_mask)
+        # ToDo Speed this up!
+        tensor = torch.hstack([meta_weight(xs[:, idx, :]) for idx, meta_weight in enumerate(self.meta_weight_list)])

-        result = torch.sum(tensor, dim=-1, keepdim=True)
-        return result
+        tensor = torch.sum(tensor, dim=-1, keepdim=True)
+        return tensor

     @property
     def particles(self):
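The rewritten MetaCell.forward replaces the per-weight hstack loop with one masked merge: a zero "bed" embedding supplies the 5-slot input template, the mask preserves the template slots, and the live activation is written into the last slot for every weight-net at once. A toy version of that merge, outside the class and with made-up sizes:

    import torch

    embedding = torch.zeros(1, 5)        # per-weight input template (weight_interface = 5)
    mask = torch.ones_like(embedding)
    mask[:, -1] = 0                      # last slot is where the live value goes

    x = torch.tensor([[0.3, 0.7]])       # batch of 1, two incoming activations
    xs = x.unsqueeze(-1).expand(-1, -1, embedding.shape[-1])            # shape (1, 2, 5)
    merged = embedding * mask.expand_as(xs) + xs * (1 - mask.expand_as(xs))
    print(merged[0, 0])                  # tensor([0., 0., 0., 0., 0.3]) -> value lands in the open slot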
@@ -317,21 +341,27 @@ class MetaCell(nn.Module):


 class MetaLayer(nn.Module):
-    def __init__(self, name, interface=4, width=4):
+    def __init__(self, name, interface=4, width=4, residual_skip=True):
         super().__init__()
+        self.residual_skip = residual_skip
         self.name = name
         self.interface = interface
         self.width = width

         self.meta_cell_list = nn.ModuleList()
-        self.meta_cell_list.extend([MetaCell(name=f'{self.name}_{cell_idx}',
+        self.meta_cell_list.extend([MetaCell(name=f'{self.name}_C{cell_idx}',
                                              interface=interface
                                              ) for cell_idx in range(self.width)]
                                    )

     def forward(self, x):
-        result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
-        return result
+        cell_results = []
+        for metacell in self.meta_cell_list:
+            cell_results.append(metacell(x))
+        tensor = torch.hstack(cell_results)
+        if self.residual_skip and x.shape == tensor.shape:
+            tensor += x
+        return tensor

     @property
     def particles(self):
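Moving the skip connection up into MetaLayer comes with a shape guard, so the residual is only added where input and output widths agree (the first and last layers change width). The same guard, reduced to a generic block rather than the MetaNet machinery:

    import torch
    from torch import nn

    class ResidualBlock(nn.Module):
        """Adds the input back onto the output, but only when the shapes line up."""
        def __init__(self, width, residual_skip=True):
            super().__init__()
            self.residual_skip = residual_skip
            self.layer = nn.Linear(width, width)

        def forward(self, x):
            tensor = self.layer(x)
            if self.residual_skip and x.shape == tensor.shape:
                tensor += x          # skip connection, silently off for mismatched widths
            return tensor

    print(ResidualBlock(4)(torch.ones(2, 4)).shape)  # torch.Size([2, 4])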
@@ -349,15 +379,15 @@ class MetaNet(nn.Module):
         self.depth = depth

         self._meta_layer_list = nn.ModuleList()
-        self._meta_layer_list.append(MetaLayer(name=f'Weight_{0}',
+        self._meta_layer_list.append(MetaLayer(name=f'L{0}',
                                                interface=self.interface,
                                                width=self.width)
                                      )
-        self._meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
+        self._meta_layer_list.extend([MetaLayer(name=f'L{layer_idx + 1}',
                                                 interface=self.width, width=self.width
                                                 ) for layer_idx in range(self.depth - 2)]
                                      )
-        self._meta_layer_list.append(MetaLayer(name=f'Weight_{len(self._meta_layer_list)}',
+        self._meta_layer_list.append(MetaLayer(name=f'L{len(self._meta_layer_list)}',
                                                interface=self.width, width=self.out)
                                      )
@@ -383,9 +413,9 @@ class MetaNet(nn.Module):


 if __name__ == '__main__':
-    metanet = MetaNet(interface=2, depth=3, width=2, out=1)
+    metanet = MetaNet(interface=3, depth=5, width=3, out=1)
     next(metanet.particles).input_weight_matrix()
-    metanet(torch.ones((5, 2)))
+    metanet(torch.hstack([torch.full((2, 1), x) for x in range(metanet.interface)]))
     a = metanet.particles
-    print('Test')
\ No newline at end of file
+    print('Test')
@@ -2,10 +2,13 @@ torch~=1.8.1+cpu
 tqdm~=4.60.0
 numpy~=1.20.3
 matplotlib~=3.4.2
-sklearn
+sklearn~=0.0
+scipy
+tabulate~=0.8.9

 scikit-learn~=0.24.2
 pandas~=1.2.4
 seaborn~=0.11.1
 future~=0.18.2
+torchmetrics~=0.7.0
 torchvision~=0.9.1+cpu