MetaNetworks
@@ -1,15 +1,59 @@
import pickle
import time
from pathlib import Path
import sys
import platform

import pandas as pd
import torchmetrics

if platform.node() != 'CarbonX':
    debug = False
    try:
        # noinspection PyUnboundLocalVariable
        if __package__ is None:
            DIR = Path(__file__).resolve().parent
            sys.path.insert(0, str(DIR.parent))
            __package__ = DIR.name
        else:
            DIR = None
    except NameError:
        DIR = None
        pass
else:
    debug = True

import numpy as np
import torch
from matplotlib import pyplot as plt
import seaborn as sns
from torch import nn
from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose
from tqdm import tqdm

from network import MetaNet

WORKER = 10 if not debug else 2
BATCHSIZE = 500 if not debug else 50
EPOCH = 50 if not debug else 3
VALIDATION_FRQ = 5 if not debug else 1
SELF_TRAIN_FRQ = 1 if not debug else 1
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class TaskDataset(Dataset):
if debug:
    torch.autograd.set_detect_anomaly(True)


class ToFloat:

    def __call__(self, x):
        return x.to(torch.float32)


class AddTaskDataset(Dataset):
    def __init__(self, length=int(5e5)):
        super().__init__()
        self.length = length
@@ -23,28 +67,164 @@ class TaskDataset(Dataset):
        return ab, ab.sum(axis=-1, keepdims=True)


def set_checkpoint(model, out_path, epoch_n, final_model=False):
    epoch_n = str(epoch_n)
    if final_model:
        ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
    else:
        ckpt_path = Path(out_path) / f'trained_model_ckpt.tp'
    ckpt_path.parent.mkdir(exist_ok=True, parents=True)

    torch.save(model, ckpt_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    return ckpt_path


def validate(checkpoint_path, ratio=0.1):
    checkpoint_path = Path(checkpoint_path)
    import torchmetrics

    # initialize metric
    metric = torchmetrics.Accuracy()

    try:
        dataset = MNIST(str(data_path), transform=utility_transforms, train=False)
    except RuntimeError:
        dataset = MNIST(str(data_path), transform=utility_transforms, train=False, download=True)
    d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)

    model = torch.load(checkpoint_path, map_location=DEVICE).eval()
    n_samples = int(len(d) * ratio)

    with tqdm(total=n_samples, desc='Validation Run: ') as pbar:
        for idx, (batch_x, batch_y) in enumerate(d):
            batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
            y = model(batch_x)

            # metric on current batch
            acc = metric(y.cpu(), batch_y.cpu())
            pbar.set_postfix_str(f'Acc: {acc}')
            pbar.update()
            if idx == n_samples:
                break

        # metric on all batches using custom accumulation
        acc = metric.compute()
        print(f"Accuracy on all data: {acc}")
        return acc
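validate() leans on the torchmetrics accumulate-then-compute pattern: each forward call of the metric scores the current batch and adds it to the running state, and metric.compute() returns the aggregate over every batch seen so far. A minimal, self-contained sketch of that pattern on toy tensors (the bare Accuracy() call matches the version used in this commit; newer torchmetrics releases would need an explicit task argument, which is an assumption here):

import torch
import torchmetrics

metric = torchmetrics.Accuracy()         # newer versions: Accuracy(task='multiclass', num_classes=10)
for _ in range(3):
    preds = torch.randn(8, 10)           # logits for one batch
    target = torch.randint(0, 10, (8,))  # labels for one batch
    batch_acc = metric(preds, target)    # scores this batch and updates the running state
total_acc = metric.compute()             # accuracy accumulated over all three batches
metric.reset()                           # clear the state before the next validation run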
def checkpoint_and_validate(model, out_path, epoch_n, final_model=False):
    out_path = Path(out_path)
    ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
    result = validate(ckpt_path)
    return result


def plot_training_result(path_to_dataframe):
    # load from Drive
    df = pd.read_csv(path_to_dataframe, index_col=0)

    fig, ax1 = plt.subplots()  # initializes figure and plots
    ax2 = ax1.twinx()  # applies twinx to ax2, which is the second y axis.

    # plots the first set of data, and sets it to ax1.
    data = df[df['Metric'] == 'BatchLoss']
    # plots the second set, and sets to ax2.
    sns.lineplot(data=data.groupby('Epoch').mean(), x='Epoch', y='Score', legend=True, ax=ax2)
    data = df[df['Metric'] == 'Test Accuracy']
    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='red')
    data = df[df['Metric'] == 'Train Accuracy']
    sns.lineplot(data=data, x='Epoch', y='Score', marker='o', color='green')

    ax2.set(yscale='log')
    ax1.set_title('Training Lineplot')
    plt.tight_layout()
    if debug:
        plt.show()
    else:
        plt.savefig(Path(path_to_dataframe.parent / 'training_lineplot.png'))


if __name__ == '__main__':
    metanet = MetaNet(2, 3, 4, 1)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.AdamW(metanet.parameters(), lr=0.004)

    d = DataLoader(TaskDataset(), batch_size=50, shuffle=True, drop_last=True)
    # metanet.train(True)
    losses = []
    for batch_x, batch_y in tqdm(d, total=len(d)):
        # Zero your gradients for every batch!
        optimizer.zero_grad()
    self_train = True
    soup_interaction = True
    training = True
    plotting = True

        y = metanet(batch_x)
        loss = loss_fn(y, batch_y)
        loss.backward()
    data_path = Path('data')
    data_path.mkdir(exist_ok=True, parents=True)

        # Adjust learning weights
        optimizer.step()
    run_path = Path('output') / 'intergrated_self_train'
    model_path = run_path / '0000_trained_model.zip'

        losses.append(loss.item())
    if training:
        utility_transforms = Compose([ToTensor(), ToFloat(), Flatten(start_dim=0)])

    sns.lineplot(y=np.asarray(losses), x=np.arange(len(losses)))
    plt.show()
        try:
            dataset = MNIST(str(data_path), transform=utility_transforms)
        except RuntimeError:
            dataset = MNIST(str(data_path), transform=utility_transforms, download=True)
        d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)

        interface = np.prod(dataset[0][0].shape)
        metanet = MetaNet(interface, depth=4, width=6, out=10).to(DEVICE).train()

        loss_fn = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)

        train_store = pd.DataFrame(columns=['Epoch', 'Batch', 'Metric', 'Score'])
        for epoch in tqdm(range(EPOCH), desc='MetaNet Train - Epochs'):
            is_validation_epoch = epoch % VALIDATION_FRQ == 0 if not debug else True
            is_self_train_epoch = epoch % SELF_TRAIN_FRQ == 0 if not debug else True
            if is_validation_epoch:
                metric = torchmetrics.Accuracy()
            for batch, (batch_x, batch_y) in tqdm(enumerate(d), total=len(d), desc='MetaNet Train - Batch'):
                if self_train and is_self_train_epoch:
                    # Zero your gradients for every batch!
                    optimizer.zero_grad()
                    combined_self_train_loss = metanet.combined_self_train()
                    combined_self_train_loss.backward()
                    # Adjust learning weights
                    optimizer.step()

                # Zero your gradients for every batch!
                optimizer.zero_grad()
                batch_x, batch_y = batch_x.to(DEVICE), batch_y.to(DEVICE)
                y = metanet(batch_x)
                # loss = loss_fn(y, batch_y.unsqueeze(-1).to(torch.float32))
                loss = loss_fn(y, batch_y.to(torch.long))
                loss.backward()

                # Adjust learning weights
                optimizer.step()

                step_log = dict(Epoch=epoch, Batch=batch,
                                Metric='BatchLoss', Score=loss.item())
                train_store.loc[train_store.shape[0]] = step_log
                if is_validation_epoch:
                    metric(y.cpu(), batch_y.cpu())

                if batch >= 3 and debug:
                    break

            if is_validation_epoch:
                validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                      Metric='Train Accuracy', Score=metric.compute().item())
                train_store.loc[train_store.shape[0]] = validation_log

                accuracy = checkpoint_and_validate(metanet, run_path, epoch)
                validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                      Metric='Test Accuracy', Score=accuracy.item())
                train_store.loc[train_store.shape[0]] = validation_log

        accuracy = checkpoint_and_validate(metanet, run_path, EPOCH, final_model=True)
        validation_log = dict(Epoch=EPOCH, Batch=BATCHSIZE,
                              Metric='Test Accuracy', Score=accuracy.item())
        train_store.loc[train_store.shape[0]] = validation_log

        torch.save(metanet, model_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
        train_store.to_csv(run_path / 'train_store.csv')

    if plotting:
        plot_training_result(run_path / 'train_store.csv')
@@ -282,7 +282,7 @@ if __name__ == "__main__":
        plt.savefig(f"{directory}/before_after_distance_catplot_{ST_name_hash}.png")
        plt.clf()

        # Catplot of children L1 Prediction "progress" compared to parents. Computes one round of accuracy first. If net is a parent net (not a clone), then we reset weights to timestep of cloning first (from the weight history). So 5k (end) -> 2.5k training (in this experiment, so careful with len(history)/2, this might only work here!)
        # Catplot of child_nets L1 Prediction "progress" compared to parents. Computes one round of accuracy first. If net is a parent net (not a clone), then we reset weights to timestep of cloning first (from the weight history). So 5k (end) -> 2.5k training (in this experiment, so careful with len(history)/2, this might only work here!)
        df_acc = pd.DataFrame(columns=["name", "noise", "l1_acc", "Network Type"])
        for i in range(len(exp_list)):
            noise = exp_list[i].noise
@@ -297,10 +297,10 @@ if __name__ == "__main__":
            mse_loss = F.mse_loss(target_data, predicted_values).item()
            l1_loss = F.l1_loss(target_data, predicted_values).item()

            df_acc.loc[len(df_acc)+1] = [network.name, noise, l1_loss, "parents" if is_parent else "children"]
            df_acc.loc[len(df_acc)+1] = [network.name, noise, l1_loss, "parents" if is_parent else "child_nets"]
            print("MSE:", mse_loss, "\t", "L1: ", l1_loss, "\t", network.name)

        # Note: If there are outliers then showfliers=False is necessary or it will zoom way too far out. If parent and children accuracy is too far apart this plot might not work (only shows either parents or part of the children).
        # Note: If there are outliers then showfliers=False is necessary or it will zoom way too far out. If parent and child_nets accuracy is too far apart this plot might not work (only shows either parents or part of the child_nets).
        ax = sns.catplot(data=df_acc, y="l1_acc", x="noise", hue="Network Type", kind="box", legend=False, showfliers=False, height=5.27, aspect=11.7/5.27, sharey=False)
        ax.map(plt.axhline, y=10**-6, ls='--')
        ax.map(plt.axhline, y=10**-7, ls='--')
@@ -216,7 +216,7 @@ class SoupSpawnExperiment:

        df.loc[len(df)] = [clone.name, net.name, MAE_pre, 0, MSE_pre, 0, MIM_pre, 0, self.noise, ""]

        net.children.append(clone)
        net.child_nets.append(clone)
        self.clones.append(clone)
        self.parents_with_clones.append(clone)

@@ -229,8 +229,8 @@ class SoupSpawnExperiment:
        net_input_data = net.input_weight_matrix()
        net_target_data = net.create_target_weights(net_input_data)

        for j in range(len(net.children)):
            clone = net.children[j]
        for j in range(len(net.child_nets)):
            clone = net.child_nets[j]

            # Post Training distances for comparison
            clone_post_weights = clone.create_target_weights(clone.input_weight_matrix())
network.py
@@ -1,14 +1,13 @@
# from __future__ import annotations
import copy
import inspect
import random
from math import sqrt
from typing import Union

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import optim, Tensor
@@ -22,12 +21,14 @@ class Net(nn.Module):
    def create_target_weights(input_weight_matrix: Tensor) -> Tensor:
        """ Outputting a tensor with the target weights. """

        target_weight_matrix = np.arange(len(input_weight_matrix)).reshape(len(input_weight_matrix), 1).astype("f")
        # What kind of slow shit is this?
        # target_weight_matrix = np.arange(len(input_weight_matrix)).reshape(len(input_weight_matrix), 1).astype("f")
        # for i in range(len(input_weight_matrix)):
        # target_weight_matrix[i] = input_weight_matrix[i][0]

        for i in range(len(input_weight_matrix)):
            target_weight_matrix[i] = input_weight_matrix[i][0]
        # Fast and simple
        return input_weight_matrix[:, 0].unsqueeze(-1)

        return torch.from_numpy(target_weight_matrix)

    @staticmethod
    def are_weights_diverged(network_weights):
@@ -36,9 +37,9 @@ class Net(nn.Module):
        for layer_id, layer in enumerate(network_weights):
            for cell_id, cell in enumerate(layer):
                for weight_id, weight in enumerate(cell):
                    if np.isnan(weight):
                    if torch.isnan(weight):
                        return True
                    if np.isinf(weight):
                    if torch.isinf(weight):
                        return True
        return False
@@ -58,7 +59,7 @@ class Net(nn.Module):
        self.start_time = start_time

        self.name = name
        self.children = []
        self.child_nets = []

        self.input_size = i_size
        self.hidden_size = h_size
@@ -74,19 +75,58 @@ class Net(nn.Module):
        self.number_trained = 0

        self.is_fixpoint = ""
        self.layers = nn.ModuleList(
            [nn.Linear(i_size, h_size, False),
             nn.Linear(h_size, h_size, False),
             nn.Linear(h_size, o_size, False)]
        )

        self.fc1 = nn.Linear(i_size, h_size, False)
        self.fc2 = nn.Linear(h_size, h_size, False)
        self.fc3 = nn.Linear(h_size, o_size, False)
        self._weight_pos_enc_and_mask = None

    @property
    def _weight_pos_enc(self):
        if self._weight_pos_enc_and_mask is None:
            d = next(self.parameters()).device
            weight_matrix = []
            for layer_id, layer in enumerate(self.layers):
                x = next(layer.parameters())
                weight_matrix.append(
                    torch.cat(
                        (
                            # Those are the weights
                            torch.full((x.numel(), 1), 0, device=d),
                            # Layer enumeration
                            torch.full((x.numel(), 1), layer_id, device=d),
                            # Cell Enumeration
                            torch.arange(layer.out_features, device=d).repeat_interleave(layer.in_features).view(-1, 1),
                            # Weight Enumeration within the Cells
                            torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1)
                        ), dim=1)
                )
            # Finalize
            weight_matrix = torch.cat(weight_matrix).float()

            # Normalize all along the 1 dimensions
            norm2 = weight_matrix[:, 1:].pow(2).sum(keepdim=True, dim=0).sqrt()
            weight_matrix[:, 1:] = weight_matrix[:, 1:] / norm2

            # create a mask where pos is 0 if it is to be replaced
            mask = torch.ones_like(weight_matrix)
            mask[:, 0] = 0

            self._weight_pos_enc_and_mask = weight_matrix, mask
        return self._weight_pos_enc_and_mask
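To make the four-column layout above easier to read, here is a sketch of the encoding produced for a single hypothetical Linear(2, 3) layer, before the per-column normalization and masking (toy example, not part of the commit):

import torch
from torch import nn

layer_id, layer = 0, nn.Linear(2, 3, bias=False)  # assumed toy layer: 3 cells with 2 weights each
n = layer.weight.numel()
pos = torch.cat((
    torch.full((n, 1), 0),                                                              # value slot, masked out later
    torch.full((n, 1), layer_id),                                                       # layer enumeration
    torch.arange(layer.out_features).repeat_interleave(layer.in_features).view(-1, 1),  # cell enumeration
    torch.arange(layer.in_features).view(-1, 1).repeat(layer.out_features, 1),          # weight enumeration within the cell
), dim=1)
print(pos)
# tensor([[0, 0, 0, 0],
#         [0, 0, 0, 1],
#         [0, 0, 1, 0],
#         [0, 0, 1, 1],
#         [0, 0, 2, 0],
#         [0, 0, 2, 1]])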
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        for layer in self.layers:
            x = layer(x)
        return x

    def normalize(self, value, norm):
        # FIXME, This is bullshit, the code does not do what the docstring explains
        # Obsolete now
        """ Normalizing the values >= 1 and adding pow(10, -8) to the values equal to 0 """

        if norm > 1:
@@ -96,23 +136,17 @@ class Net(nn.Module):

    def input_weight_matrix(self) -> Tensor:
        """ Calculating the input tensor formed from the weights of the net """
        weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
        pos_enc, mask = self._weight_pos_enc
        weight_matrix = pos_enc * mask + weight_matrix.expand(-1, 4) * (1 - mask)
        return weight_matrix

        # The "4" represents the weightwise coordinates used for the matrix: <value><layer_id><cell_id><positional_id>
        weight_matrix = np.arange(self.no_weights * 4).reshape(self.no_weights, 4).astype("f")

        i = 0
        max_layer_id = len(self.state_dict()) - 1
        for layer_id, layer_name in enumerate(self.state_dict()):
            max_cell_id = len(self.state_dict()[layer_name]) - 1
            for line_id, line_values in enumerate(self.state_dict()[layer_name]):
                max_weight_id = len(line_values) - 1
                for weight_id, weight_value in enumerate(self.state_dict()[layer_name][line_id]):
                    weight_matrix[i] = weight_value.item(), self.normalize(layer_id, max_layer_id), self.normalize(line_id, max_cell_id), self.normalize(weight_id, max_weight_id)
                    i += 1

        return torch.from_numpy(weight_matrix)
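The rewritten input_weight_matrix composes the cached positional encoding with the live weight vector through the mask, so column 0 always carries the current weight value while columns 1-3 stay fixed. A minimal numeric sketch of that composition (toy values, not taken from the commit):

import torch

pos_enc = torch.tensor([[0., 0., 0., 0.],
                        [0., 0., 0., 1.]])  # columns: value slot, layer id, cell id, weight id
mask = torch.ones_like(pos_enc)
mask[:, 0] = 0                              # column 0 marks the slot to be replaced
weights = torch.tensor([[0.25], [-0.50]])   # current weight values, one row per weight
combined = pos_enc * mask + weights.expand(-1, 4) * (1 - mask)
print(combined)
# tensor([[ 0.2500,  0.0000,  0.0000,  0.0000],
#         [-0.5000,  0.0000,  0.0000,  1.0000]])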
    def self_train(self, training_steps: int, log_step_size: int, learning_rate: float) -> (np.ndarray, Tensor, list):
    def self_train(self,
                   training_steps: int,
                   log_step_size: int = 0,
                   learning_rate: float = 0.0004,
                   save_history: bool = True
                   ) -> (Tensor, list):
        """ Training a network to predict its own weights in order to self-replicate. """

        optimizer = optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9)
@@ -127,27 +161,30 @@ class Net(nn.Module):
            loss.backward()
            optimizer.step()

            # Saving the history of the weights after a certain amount of steps (aka log_step_size) for research.
            # If it is a soup/mixed env. save weights only at the end of all training steps (aka a soup/mixed epoch)
            if "soup" not in self.name and "mixed" not in self.name:
                weights = self.create_target_weights(self.input_weight_matrix())
                # If self-training steps are lower than 10, then append weight history after each ST step.
                if self.number_trained < 10:
                    self.s_train_weights_history.append(weights.T.detach().numpy())
                    self.loss_history.append(loss.detach().numpy().item())
                else:
                    if self.number_trained % log_step_size == 0:
            if save_history:
                # Saving the history of the weights after a certain amount of steps (aka log_step_size) for research.
                # If it is a soup/mixed env. save weights only at the end of all training steps (aka a soup/mixed epoch)
                if "soup" not in self.name and "mixed" not in self.name:
                    weights = self.create_target_weights(self.input_weight_matrix())
                    # If self-training steps are lower than 10, then append weight history after each ST step.
                    if self.number_trained < 10:
                        self.s_train_weights_history.append(weights.T.detach().numpy())
                        self.loss_history.append(loss.detach().numpy().item())
                        self.loss_history.append(loss.item())
                    else:
                        if log_step_size != 0:
                            if self.number_trained % log_step_size == 0:
                                self.s_train_weights_history.append(weights.T.detach().numpy())
                                self.loss_history.append(loss.item())

        weights = self.create_target_weights(self.input_weight_matrix())
        # Saving weights only at the end of a soup/mixed exp. epoch.
        if "soup" in self.name or "mixed" in self.name:
            self.s_train_weights_history.append(weights.T.detach().numpy())
            self.loss_history.append(loss.detach().numpy().item())
        if save_history:
            if "soup" in self.name or "mixed" in self.name:
                self.s_train_weights_history.append(weights.T.detach().numpy())
                self.loss_history.append(loss.item())

        self.trained = True
        return weights.detach().numpy(), loss, self.loss_history
        return loss, self.loss_history
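With the new defaults, self_train can be driven without any logging arguments and now returns only the final loss tensor and the loss history (the weight-array return was dropped). A hedged usage sketch, assuming a Net built the way MetaCell builds its particles below:

# assumed construction, mirroring MetaCell: Net(weight_interface, hidden_size, output_size, name=...)
net = Net(4, 4, 1, name='particle_0')
last_loss, loss_history = net.self_train(training_steps=1000, save_history=False)
print(last_loss.item(), len(loss_history))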
    def self_application(self, SA_steps: int, log_step_size: Union[int, None] = None):
        """ Inputting the weights of a network to itself for a number of steps, without backpropagation. """
@@ -208,7 +245,7 @@ class Net(nn.Module):


class SecondaryNet(Net):

    def self_train(self, training_steps: int, log_step_size: int, learning_rate: float) -> (np.ndarray, Tensor, list):
    def self_train(self, training_steps: int, log_step_size: int, learning_rate: float) -> (pd.DataFrame, Tensor, list):
        """ Training a network to predict its own weights in order to self-replicate. """

        optimizer = optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9)
@@ -245,10 +282,6 @@ class SecondaryNet(Net):
        return df, is_diverged


class MetaWeight(Net):
    pass


class MetaCell(nn.Module):
    def __init__(self, name, interface, residual_skip=True):
        super().__init__()
@@ -258,67 +291,102 @@ class MetaCell(nn.Module):
        self.weight_interface = 4
        self.net_hidden_size = 4
        self.net_ouput_size = 1
        self.meta_weight_list = nn.ModuleList(
            [MetaWeight(self.weight_interface, self.net_hidden_size,
                        self.net_ouput_size, name=f'{self.name}_{weight_idx}'
                        ) for weight_idx in range(self.interface)])
        self.meta_weight_list = nn.ModuleList()
        self.meta_weight_list.extend(
            [Net(self.weight_interface, self.net_hidden_size,
                 self.net_ouput_size, name=f'{self.name}_{weight_idx}'
                 ) for weight_idx in range(self.interface)]
        )

    def forward(self, x):
        xs = [torch.hstack((x[:, idx].unsqueeze(-1), torch.zeros((x.shape[0], self.weight_interface - 1))))
              for idx in range(len(self.meta_weight_list))]
        xs = [torch.hstack(
            (torch.zeros((x.shape[0], self.weight_interface - 1), device=x.device), x[:, idx].unsqueeze(-1))
        )
            for idx in range(len(self.meta_weight_list))]
        tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])

        if self.residual_skip:
            tensor += x

        result = torch.sum(tensor, dim=-1, keepdim=True)
        return result

    @property
    def particles(self):
        return (net for net in self.meta_weight_list)
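After the change, MetaCell.forward zero-pads each scalar activation on the left, so the value ends up in the last of the four input columns handed to the per-weight particle nets, and the padding is created on the input's device. A small shape sketch with assumed toy values (not from the commit):

import torch

weight_interface = 4
x = torch.tensor([[0.1, 0.2, 0.3],
                  [0.4, 0.5, 0.6]])  # batch of 2, cell interface of 3
idx = 1                              # input routed to meta-weight net no. 1
padded = torch.hstack(
    (torch.zeros((x.shape[0], weight_interface - 1)), x[:, idx].unsqueeze(-1))
)
print(padded)
# tensor([[0.0000, 0.0000, 0.0000, 0.2000],
#         [0.0000, 0.0000, 0.0000, 0.5000]])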
class MetaLayer(nn.Module):
    def __init__(self, name, interface=4, out=1, width=4):
    def __init__(self, name, interface=4, width=4):
        super().__init__()
        self.name = name
        self.interface = interface
        self.width = width

        meta_cell_list = nn.ModuleList([MetaCell(name=f'{self.name}_{cell_idx}',
                                                 interface=interface
                                                 ) for cell_idx in range(self.width)])
        self.meta_cell_list = meta_cell_list
        self.meta_cell_list = nn.ModuleList()
        self.meta_cell_list.extend([MetaCell(name=f'{self.name}_{cell_idx}',
                                             interface=interface
                                             ) for cell_idx in range(self.width)]
                                   )

    def forward(self, x):
        result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
        return result

    @property
    def particles(self):
        return (weight for metacell in self.meta_cell_list for weight in metacell.particles)
class MetaNet(nn.Module):

    def __init__(self, interface=4, depth=3, width=4, out=1):
    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None):
        super().__init__()
        self.activation = activation
        self.out = out
        self.interface = interface
        self.width = width
        self.depth = depth

        meta_layer_list = nn.ModuleList([MetaLayer(name=f'Weight_{0}',
                                                   interface=self.interface,
                                                   width=self.width)])
        meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
                                          interface=self.width, width=self.width
                                          ) for layer_idx in range(self.depth - 2)])
        meta_layer_list.append(MetaLayer(name=f'Weight_{len(meta_layer_list)}',
                                         interface=self.width, width=self.out))
        self._meta_layer_list = meta_layer_list
        self._net = nn.Sequential(*self._meta_layer_list)
        self._meta_layer_list = nn.ModuleList()
        self._meta_layer_list.append(MetaLayer(name=f'Weight_{0}',
                                               interface=self.interface,
                                               width=self.width)
                                     )
        self._meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
                                                interface=self.width, width=self.width
                                                ) for layer_idx in range(self.depth - 2)]
                                     )
        self._meta_layer_list.append(MetaLayer(name=f'Weight_{len(self._meta_layer_list)}',
                                               interface=self.width, width=self.out)
                                     )

    def forward(self, x):
        result = self._net.forward(x)
        return result
        tensor = x
        for meta_layer in self._meta_layer_list:
            tensor = meta_layer(tensor)
        return tensor

    @property
    def particles(self):
        return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)

    def combined_self_train(self):
        losses = []
        for particle in self.particles:
            # Integrate optimizer and backward function
            input_data = particle.input_weight_matrix()
            target_data = particle.create_target_weights(input_data)
            output = particle(input_data)
            losses.append(F.mse_loss(output, target_data))
        return torch.hstack(losses).sum(dim=-1, keepdim=True)
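combined_self_train only assembles and returns the summed self-replication loss over all particle nets; the gradient step is taken by the caller, exactly as the training script earlier in this commit does. A condensed sketch of that calling pattern (metanet built as in that script):

optimizer = torch.optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)

optimizer.zero_grad()
self_train_loss = metanet.combined_self_train()  # summed MSE over every particle net
self_train_loss.backward()
optimizer.step()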
if __name__ == '__main__':
    metanet = MetaNet(2, 3, 4, 1)
    metanet = MetaNet(interface=2, depth=3, width=2, out=1)
    next(metanet.particles).input_weight_matrix()
    metanet(torch.ones((5, 2)))
    a = metanet.particles
    print('Test')
    print('Test')
    print('Test')