MetaNetworks

Steffen Illium
2022-01-26 16:56:05 +01:00
parent 5f1f5833d8
commit 49c0d8a621
4 changed files with 349 additions and 101 deletions

View File

@@ -1,15 +1,59 @@
import pickle
import time
from pathlib import Path

import platform
import pandas as pd
import torchmetrics

if platform.node() != 'CarbonX':
    debug = False
else:
    debug = True

import numpy as np
import torch
from matplotlib import pyplot as plt
import seaborn as sns

from torch import nn
from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose
from tqdm import tqdm

from network import MetaNet

WORKER = 10 if not debug else 2
BATCHSIZE = 500 if not debug else 50
EPOCH = 50 if not debug else 3
VALIDATION_FRQ = 5 if not debug else 1
SELF_TRAIN_FRQ = 1 if not debug else 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if debug:
    torch.autograd.set_detect_anomaly(True)


class ToFloat:

    def __call__(self, x):
        return x.to(torch.float32)


class AddTaskDataset(Dataset):
    def __init__(self, length=int(5e5)):
        super().__init__()
        self.length = length

@@ -23,28 +67,164 @@ class TaskDataset(Dataset):
        return ab, ab.sum(axis=-1, keepdims=True)

def set_checkpoint(model, out_path, epoch_n, final_model=False):
    epoch_n = str(epoch_n)
    # Intermediate checkpoints get zero-padded, numbered files under 'ckpt/';
    # only the final model is written directly into the run folder.
    if not final_model:
        ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
    else:
        ckpt_path = Path(out_path) / 'trained_model_ckpt.tp'
    ckpt_path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    return ckpt_path
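
A minimal usage sketch for orientation (hypothetical run directory; any nn.Module stands in for the model):

# Hypothetical usage sketch for set_checkpoint (demo paths, not part of the experiment):
demo_model = nn.Linear(2, 2)
print(set_checkpoint(demo_model, 'output/demo_run', 7))                     # output/demo_run/ckpt/0007_model_ckpt.tp
print(set_checkpoint(demo_model, 'output/demo_run', 50, final_model=True))  # output/demo_run/trained_model_ckpt.tp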

def validate(checkpoint_path, ratio=0.1):
    checkpoint_path = Path(checkpoint_path)

    # initialize metric
    metric = torchmetrics.Accuracy()

    # Relies on the module-level data_path / utility_transforms defined in __main__.
    try:
        dataset = MNIST(str(data_path), transform=utility_transforms, train=False)
    except RuntimeError:  # Raised when the MNIST data has not been downloaded yet
        dataset = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
    d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)

    model = torch.load(checkpoint_path, map_location=DEVICE).eval()
    n_samples = int(len(d) * ratio)

    with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
        for idx, (batch_x, batch_y) in enumerate(d):
            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
            with torch.no_grad():  # evaluation does not need gradients
                y = model(batch_x)

            # metric on current batch
            acc = metric(y.cpu(), batch_y.cpu())
            pbar.set_postfix_str(f'Acc: {acc}')
            pbar.update()
            if idx == n_samples:
                break

    # metric on all batches using custom accumulation
    acc = metric.compute()
    print(f"Accuracy on all data: {acc}")
    return acc
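
The metric object is stateful: each call returns the accuracy of the current batch while also updating running counts, and compute() then reduces over everything seen. A self-contained toy sketch (note that torchmetrics releases newer than the one used here require a task argument for Accuracy):

# Toy torchmetrics accumulation sketch (not MNIST data):
m = torchmetrics.Accuracy()
print(m(torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 0])))  # batch accuracy: 0.5
print(m(torch.tensor([[0.9, 0.1]]), torch.tensor([0])))                 # batch accuracy: 1.0
print(m.compute())                                                      # accumulated: 2 of 3 correct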

def checkpoint_and_validate(model, out_path, epoch_n, final_model=False):
    out_path = Path(out_path)
    ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
    result = validate(ckpt_path)
    return result

def plot_training_result(path_to_dataframe):
    # load from Drive
    df = pd.read_csv(path_to_dataframe, index_col=0)

    fig, ax1 = plt.subplots()  # initializes figure and primary axis
    ax2 = ax1.twinx()  # second y axis sharing the same x axis

    # plot the per-epoch mean batch loss on the secondary (log-scaled) axis
    data = df[df['Metric'] == 'BatchLoss']
    sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax2)

    # plot both accuracies on the primary axis
    data = df[df['Metric'] == 'Test Accuracy']
    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='red', ax=ax1)

    data = df[df['Metric'] == 'Train Accuracy']
    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='green', ax=ax1)

    ax2.set(yscale='log')
    ax1.set_title('Training Lineplot')
    plt.tight_layout()
    if debug:
        plt.show()
    else:
        plt.savefig(Path(path_to_dataframe).parent / 'training_lineplot.png')
if __name__ == '__main__':
    self_train = True
    soup_interaction = True
    training = True
    plotting = True

    data_path = Path('data')
    data_path.mkdir(exist_ok=True, parents=True)

    run_path = Path('output') / 'intergrated_self_train'
    model_path = run_path / '0000_trained_model.zip'

    if training:
        utility_transforms = Compose([ToTensor(), ToFloat(), Flatten(start_dim=0)])
        try:
            dataset = MNIST(str(data_path), transform=utility_transforms)
        except RuntimeError:  # Raised when the MNIST data has not been downloaded yet
            dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
        d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
        interface = np.prod(dataset[0][0].shape)
        metanet = MetaNet(interface, depth=4, width=6, out=10).to(DEVICE).train()

        loss_fn = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)

        train_store = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])

        for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
            is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
            is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
            if is_validation_epoch:
                metric = torchmetrics.Accuracy()
            for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
                if self_train and is_self_train_epoch:
                    # Zero your gradients for every batch!
                    optimizer.zero_grad()
                    combined_self_train_loss = metanet.combined_self_train()
                    combined_self_train_loss.backward()
                    # Adjust learning weights
                    optimizer.step()

                # Zero your gradients for every batch!
                optimizer.zero_grad()
                batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
                y = metanet(batch_x)
                # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
                loss = loss_fn(y, batch_y.to(torch.long))
                loss.backward()

                # Adjust learning weights
                optimizer.step()

                step_log = dict(Epoch=epoch, Batch=batch,
                                Metric='BatchLoss', Score=loss.item())
                train_store.loc[train_store.shape[0]] = step_log
                if is_validation_epoch:
                    metric(y.cpu(), batch_y.cpu())
                if batch >= 3 and debug:
                    break

            if is_validation_epoch:
                validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                      Metric='Train Accuracy', Score=metric.compute().item())
                train_store.loc[train_store.shape[0]] = validation_log

                accuracy = checkpoint_and_validate(metanet, run_path, epoch)
                validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                      Metric='Test Accuracy', Score=accuracy.item())
                train_store.loc[train_store.shape[0]] = validation_log

        accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
        validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
                              Metric='Test Accuracy', Score=accuracy.item())
        train_store.loc[train_store.shape[0]] = validation_log

        torch.save(metanet, model_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
        train_store.to_csv(run_path / 'train_store.csv')

    if plotting:
        plot_training_result(run_path / 'train_store.csv')

View File

@@ -282,7 +282,7 @@ if __name__ == "__main__":
    plt.savefig(f"{directory}/before_after_distance_catplot_{ST_name_hash}.png")
    plt.clf()

    # Catplot of child_nets L1 Prediction "progress" compared to parents. Computes one round of accuracy first.
    # If net is a parent net (not a clone), then we reset weights to the timestep of cloning first (from the
    # weight history). So 5k (end) -> 2.5k training (in this experiment, so be careful with len(history)/2,
    # this might only work here!)
    df_acc = pd.DataFrame(columns=["name", "noise", "l1_acc", "Network Type"])
    for i in range(len(exp_list)):
        noise = exp_list[i].noise
@@ -297,10 +297,10 @@ if __name__ == "__main__":
            mse_loss = F.mse_loss(target_data, predicted_values).item()
            l1_loss = F.l1_loss(target_data, predicted_values).item()

            df_acc.loc[len(df_acc)+1] = [network.name, noise, l1_loss, "parents" if is_parent else "child_nets"]
            print("MSE:", mse_loss, "\t", "L1: ", l1_loss, "\t", network.name)

    # Note: If there are outliers then showfliers=False is necessary or it will zoom way too far out. If parent
    # and child_nets accuracy is too far apart this plot might not work (it only shows either the parents or
    # part of the child_nets).
    ax = sns.catplot(data=df_acc, y="l1_acc", x="noise", hue="Network Type", kind="box", legend=False,
                     showfliers=False, height=5.27, aspect=11.7/5.27, sharey=False)
    ax.map(plt.axhline, y=10**-6, ls='--')
    ax.map(plt.axhline, y=10**-7, ls='--')

View File

@@ -216,7 +216,7 @@ class SoupSpawnExperiment:
            df.loc[len(df)] = [clone.name, net.name, MAE_pre, 0, MSE_pre, 0, MIM_pre, 0, self.noise, ""]

            net.child_nets.append(clone)
            self.clones.append(clone)
            self.parents_with_clones.append(clone)
@@ -229,8 +229,8 @@ class SoupSpawnExperiment:
        net_input_data = net.input_weight_matrix()
        net_target_data = net.create_target_weights(net_input_data)

        for j in range(len(net.child_nets)):
            clone = net.child_nets[j]

            # Post Training distances for comparison
            clone_post_weights = clone.create_target_weights(clone.input_weight_matrix())

View File

@@ -1,14 +1,13 @@
# from __future__ import annotations
import copy
import inspect
import random
from math import sqrt

from typing import Union

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch import optim, Tensor
@@ -22,12 +21,14 @@ class Net(nn.Module):
    def create_target_weights(input_weight_matrix: Tensor) -> Tensor:
        """ Outputting a tensor with the target weights. """
        # What kind of slow shit is this?
        # target_weight_matrix = np.arange(len(input_weight_matrix)).reshape(len(input_weight_matrix), 1).astype("f")
        # for i in range(len(input_weight_matrix)):
        #     target_weight_matrix[i] = input_weight_matrix[i][0]
        # Fast and simple: the first column holds the current weight values.
        return input_weight_matrix[:, 0].unsqueeze(-1)
    @staticmethod
    def are_weights_diverged(network_weights):

@@ -36,9 +37,9 @@ class Net(nn.Module):
        for layer_id, layer in enumerate(network_weights):
            for cell_id, cell in enumerate(layer):
                for weight_id, weight in enumerate(cell):
                    if torch.isnan(weight):
                        return True
                    if torch.isinf(weight):
                        return True
        return False
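
A quick sketch of what the check accepts: a nested layer/cell/weight structure, returning True on the first NaN or infinity (hypothetical toy weights):

# Toy divergence check (hypothetical weight structure):
weights_ok = [[torch.tensor([0.1, -0.2])]]
weights_bad = [[torch.tensor([0.1, float('inf')])]]
print(Net.are_weights_diverged(weights_ok))   # False
print(Net.are_weights_diverged(weights_bad))  # True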
@@ -58,7 +59,7 @@ class Net(nn.Module):
        self.start_time = start_time
        self.name = name
        self.child_nets = []

        self.input_size = i_size
        self.hidden_size = h_size
@@ -74,19 +75,58 @@ class Net(nn.Module):
        self.number_trained = 0

        self.is_fixpoint = ""
        self.layers = nn.ModuleList(
            [nn.Linear(i_size, h_size, False),
             nn.Linear(h_size, h_size, False),
             nn.Linear(h_size, o_size, False)]
        )

        self._weight_pos_enc_and_mask = None

    @property
    def _weight_pos_enc(self):
        if self._weight_pos_enc_and_mask is None:
            d = next(self.parameters()).device
            weight_matrix = []
            for layer_id, layer in enumerate(self.layers):
                x = next(layer.parameters())
                weight_matrix.append(
                    torch.cat(
                        (
                            # Placeholder column for the weight values themselves
                            torch.full((x.numel(), 1), 0, device=d),
                            # Layer enumeration
                            torch.full((x.numel(), 1), layer_id, device=d),
                            # Cell enumeration
                            torch.arange(layer.out_features, device=d).repeat_interleave(layer.in_features).view(-1, 1),
                            # Weight enumeration within the cells
                            torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1)
                        ), dim=1)
                )
            # Finalize
            weight_matrix = torch.cat(weight_matrix).float()

            # Normalize the positional columns (1:) to unit L2 norm along dim 0
            norm2 = weight_matrix[:, 1:].pow(2).sum(keepdim=True, dim=0).sqrt()
            weight_matrix[:, 1:] = weight_matrix[:, 1:] / norm2

            # Create a mask that is 0 where the entry is to be replaced by a live weight value
            mask = torch.ones_like(weight_matrix)
            mask[:, 0] = 0

            self._weight_pos_enc_and_mask = weight_matrix, mask
        return self._weight_pos_enc_and_mask
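
To make the encoding concrete: each row corresponds to one weight and carries [value slot, layer id, cell id, weight id], with the three positional columns normalized. A hedged inspection sketch on hypothetical sizes (2-3-1 gives 2*3 + 3*3 + 3*1 = 18 rows):

# Sketch: inspect the cached positional encoding of a small Net (hypothetical sizes/name).
demo_net = Net(2, 3, 1, name='pos_enc_demo')
enc, mask = demo_net._weight_pos_enc
print(enc.shape, mask.shape)  # torch.Size([18, 4]) for both
print(mask[0])                # tensor([0., 1., 1., 1.]) -> column 0 is replaced by live weights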

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def normalize(self, value, norm):
        # FIXME: This is bullshit, the code does not do what the docstring explains
        # Obsolete now
        """ Normalizing the values >= 1 and adding pow(10, -8) to the values equal to 0 """
        if norm > 1:
@@ -96,23 +136,17 @@ class Net(nn.Module):
    def input_weight_matrix(self) -> Tensor:
        """ Calculating the input tensor formed from the weights of the net """
        weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
        pos_enc, mask = self._weight_pos_enc
        # Keep the positional columns, inject the live weight values into column 0
        weight_matrix = pos_enc * mask + weight_matrix.expand(-1, 4) * (1 - mask)
        return weight_matrix

    def self_train(self,
                   training_steps: int,
                   log_step_size: int = 0,
                   learning_rate: float = 0.0004,
                   save_history: bool = True
                   ) -> (Tensor, list):
        """ Training a network to predict its own weights in order to self-replicate. """

        optimizer = optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9)
@@ -127,27 +161,30 @@ class Net(nn.Module):
            loss.backward()
            optimizer.step()

            if save_history:
                # Saving the history of the weights after a certain amount of steps (aka log_step_size) for research.
                # If it is a soup/mixed env. save weights only at the end of all training steps (aka a soup/mixed epoch)
                if "soup" not in self.name and "mixed" not in self.name:
                    weights = self.create_target_weights(self.input_weight_matrix())
                    # If self-training steps are lower than 10, then append weight history after each ST step.
                    if self.number_trained < 10:
                        self.s_train_weights_history.append(weights.T.detach().numpy())
                        self.loss_history.append(loss.item())
                    else:
                        if log_step_size != 0:
                            if self.number_trained % log_step_size == 0:
                                self.s_train_weights_history.append(weights.T.detach().numpy())
                                self.loss_history.append(loss.item())

        weights = self.create_target_weights(self.input_weight_matrix())
        if save_history:
            # Saving weights only at the end of a soup/mixed exp. epoch.
            if "soup" in self.name or "mixed" in self.name:
                self.s_train_weights_history.append(weights.T.detach().numpy())
                self.loss_history.append(loss.item())

        self.trained = True
        return loss, self.loss_history
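
With the new defaults only the step count is required; a hedged usage sketch (hypothetical sizes and name, and assuming the unchanged training loop between these hunks; history is kept since the name contains neither 'soup' nor 'mixed'):

# Usage sketch for the reworked self_train signature (hypothetical demo values):
st_net = Net(4, 4, 1, name='st_demo')
loss, loss_history = st_net.self_train(100, log_step_size=10)
print(loss.item(), len(loss_history))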

    def self_application(self, SA_steps: int, log_step_size: Union[int, None] = None):
        """ Inputting the weights of a network to itself for a number of steps, without backpropagation. """

@@ -208,7 +245,7 @@ class Net(nn.Module):

class SecondaryNet(Net):

    def self_train(self, training_steps: int, log_step_size: int, learning_rate: float) -> (pd.DataFrame, Tensor, list):
        """ Training a network to predict its own weights in order to self-replicate. """

        optimizer = optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9)
@@ -245,10 +282,6 @@ class SecondaryNet(Net):
        return df, is_diverged


class MetaCell(nn.Module):
    def __init__(self, name, interface, residual_skip=True):
        super().__init__()
@@ -258,67 +291,102 @@ class MetaCell(nn.Module):
        self.weight_interface = 4
        self.net_hidden_size = 4
        self.net_ouput_size = 1
        self.meta_weight_list = nn.ModuleList()
        self.meta_weight_list.extend(
            [Net(self.weight_interface, self.net_hidden_size,
                 self.net_ouput_size, name=f'{self.name}_{weight_idx}'
                 ) for weight_idx in range(self.interface)]
        )

    def forward(self, x):
        # Route each input scalar into its own particle Net, zero-padded to the 4-wide weight interface
        xs = [torch.hstack((torch.zeros((x.shape[0], self.weight_interface - 1), device=x.device),
                            x[:, idx].unsqueeze(-1)))
              for idx in range(len(self.meta_weight_list))]
        tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])

        if self.residual_skip:
            tensor += x

        result = torch.sum(tensor, dim=-1, keepdim=True)
        return result
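
Shape-wise, every incoming scalar is mapped through its own particle Net to one scalar, then the interface-many outputs are (optionally residual-added and) summed to a single value per sample. A hypothetical check:

# Shape sketch (hypothetical batch of 5 and interface of 3):
demo_cell = MetaCell(name='demo_cell', interface=3)
print(demo_cell(torch.randn(5, 3)).shape)  # torch.Size([5, 1])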

    @property
    def particles(self):
        return (net for net in self.meta_weight_list)


class MetaLayer(nn.Module):
    def __init__(self, name, interface=4, width=4):
        super().__init__()
        self.name = name
        self.interface = interface
        self.width = width

        self.meta_cell_list = nn.ModuleList()
        self.meta_cell_list.extend([MetaCell(name=f'{self.name}_{cell_idx}',
                                             interface=interface
                                             ) for cell_idx in range(self.width)]
                                   )

    def forward(self, x):
        result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
        return result

    @property
    def particles(self):
        return (weight for metacell in self.meta_cell_list for weight in metacell.particles)


class MetaNet(nn.Module):

    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None):
        super().__init__()
        self.activation = activation
        self.out = out
        self.interface = interface
        self.width = width
        self.depth = depth

        self._meta_layer_list = nn.ModuleList()
        self._meta_layer_list.append(MetaLayer(name=f'Weight_{0}',
                                               interface=self.interface,
                                               width=self.width)
                                     )
        self._meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
                                                interface=self.width, width=self.width
                                                ) for layer_idx in range(self.depth - 2)]
                                     )
        self._meta_layer_list.append(MetaLayer(name=f'Weight_{len(self._meta_layer_list)}',
                                               interface=self.width, width=self.out)
                                     )

    def forward(self, x):
        tensor = x
        for meta_layer in self._meta_layer_list:
            tensor = meta_layer(tensor)
        return tensor

    @property
    def particles(self):
        return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)

    def combined_self_train(self):
        losses = []
        for particle in self.particles:
            # Integrate optimizer and backward function
            input_data = particle.input_weight_matrix()
            target_data = particle.create_target_weights(input_data)
            output = particle(input_data)
            losses.append(F.mse_loss(output, target_data))
        return torch.hstack(losses).sum(dim=-1, keepdim=True)
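
The summed loss is a single-element tensor, so one backward pass trains every particle at once; this is how the MNIST script above uses it each self-train epoch. A standalone sketch (hypothetical sizes):

# Sketch: one combined self-train step on a small MetaNet (hypothetical sizes).
demo_meta = MetaNet(interface=2, depth=3, width=2, out=1)
demo_opt = torch.optim.SGD(demo_meta.parameters(), lr=0.004, momentum=0.9)
demo_opt.zero_grad()
st_loss = demo_meta.combined_self_train()
st_loss.backward()
demo_opt.step()
print(st_loss.item())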


if __name__ == '__main__':
    metanet = MetaNet(interface=2, depth=3, width=2, out=1)
    next(metanet.particles).input_weight_matrix()
    metanet(torch.ones((5, 2)))
    a = metanet.particles
    print('Test')