MetaNetworks
network.py
@@ -1,14 +1,13 @@
 # from __future__ import annotations
 import copy
 import inspect
 import random
 from math import sqrt
 from typing import Union

 import pandas as pd
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np
 from torch import optim, Tensor

@@ -22,12 +21,14 @@ class Net(nn.Module):
     def create_target_weights(input_weight_matrix: Tensor) -> Tensor:
         """ Outputting a tensor with the target weights. """

-        target_weight_matrix = np.arange(len(input_weight_matrix)).reshape(len(input_weight_matrix), 1).astype("f")
+        # The previous elementwise copy was needlessly slow:
+        # target_weight_matrix = np.arange(len(input_weight_matrix)).reshape(len(input_weight_matrix), 1).astype("f")
+        # for i in range(len(input_weight_matrix)):
+        #     target_weight_matrix[i] = input_weight_matrix[i][0]

-        for i in range(len(input_weight_matrix)):
-            target_weight_matrix[i] = input_weight_matrix[i][0]
+        # Fast and simple: take the first column directly
+        return input_weight_matrix[:, 0].unsqueeze(-1)

-        return torch.from_numpy(target_weight_matrix)

     @staticmethod
     def are_weights_diverged(network_weights):
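As a quick sanity check (a sketch, not part of the commit), the vectorized slice reproduces the removed per-row loop:

    import torch

    iwm = torch.rand(10, 4)  # stand-in for an input_weight_matrix: one row per weight, 4 coordinates
    fast = iwm[:, 0].unsqueeze(-1)                            # new path: first column as an (N, 1) tensor
    slow = torch.stack([row[0] for row in iwm]).view(-1, 1)   # the removed loop, reproduced
    assert torch.equal(fast, slow)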
@@ -36,9 +37,9 @@ class Net(nn.Module):
         for layer_id, layer in enumerate(network_weights):
             for cell_id, cell in enumerate(layer):
                 for weight_id, weight in enumerate(cell):
-                    if np.isnan(weight):
+                    if torch.isnan(weight):
                         return True
-                    if np.isinf(weight):
+                    if torch.isinf(weight):
                         return True
         return False

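A plausible motivation for the torch.isnan/torch.isinf switch (assumed, the commit does not state it): np.isnan fails on tensors that still require grad, while the torch variants work directly:

    import torch

    w = torch.tensor(float('nan'), requires_grad=True)
    print(torch.isnan(w))   # tensor(True)
    # np.isnan(w) would raise: NumPy cannot convert a tensor that requires grad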
@@ -58,7 +59,7 @@ class Net(nn.Module):
         self.start_time = start_time

         self.name = name
-        self.children = []
+        self.child_nets = []

         self.input_size = i_size
         self.hidden_size = h_size
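The rename avoids shadowing nn.Module's built-in children() method, which assigning a list to self.children would silently do. A small illustration (not part of the commit):

    import torch.nn as nn

    m = nn.Sequential(nn.Linear(2, 2))
    print(list(m.children()))   # [Linear(in_features=2, out_features=2, bias=True)]
    m.children = []             # legal, but the inherited method is now shadowed by the list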
@@ -74,19 +75,58 @@ class Net(nn.Module):
         self.number_trained = 0

         self.is_fixpoint = ""
+        self.layers = nn.ModuleList(
+            [nn.Linear(i_size, h_size, False),
+             nn.Linear(h_size, h_size, False),
+             nn.Linear(h_size, o_size, False)]
+        )

-        self.fc1 = nn.Linear(i_size, h_size, False)
-        self.fc2 = nn.Linear(h_size, h_size, False)
-        self.fc3 = nn.Linear(h_size, o_size, False)
+        self._weight_pos_enc_and_mask = None

+    @property
+    def _weight_pos_enc(self):
+        if self._weight_pos_enc_and_mask is None:
+            d = next(self.parameters()).device
+            weight_matrix = []
+            for layer_id, layer in enumerate(self.layers):
+                x = next(layer.parameters())
+                weight_matrix.append(
+                    torch.cat(
+                        (
+                            # Column 0: slots for the weight values themselves
+                            torch.full((x.numel(), 1), 0, device=d),
+                            # Column 1: layer enumeration
+                            torch.full((x.numel(), 1), layer_id, device=d),
+                            # Column 2: cell enumeration
+                            torch.arange(layer.out_features, device=d).repeat_interleave(layer.in_features).view(-1, 1),
+                            # Column 3: weight enumeration within the cells
+                            torch.arange(layer.in_features, device=d).view(-1, 1).repeat(layer.out_features, 1)
+                        ), dim=1)
+                )
+            # Finalize
+            weight_matrix = torch.cat(weight_matrix).float()
+
+            # Normalize the positional columns (1..3) to unit L2 norm, column-wise
+            norm2 = weight_matrix[:, 1:].pow(2).sum(keepdim=True, dim=0).sqrt()
+            weight_matrix[:, 1:] = weight_matrix[:, 1:] / norm2
+
+            # Create a mask that is 0 where a position is to be replaced by a weight value
+            mask = torch.ones_like(weight_matrix)
+            mask[:, 0] = 0
+
+            self._weight_pos_enc_and_mask = weight_matrix, mask
+        return self._weight_pos_enc_and_mask

     def forward(self, x):
-        x = self.fc1(x)
-        x = self.fc2(x)
-        x = self.fc3(x)
+        for layer in self.layers:
+            x = layer(x)
         return x

     def normalize(self, value, norm):
-        # FIXME: this does not do what the docstring says
+        # Obsolete now
         """ Normalizing the values >= 1 and adding pow(10, -8) to the values equal to 0 """

         if norm > 1:
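For a single bias-free 3-to-2 Linear layer, the property above yields one row per weight with four coordinate columns. A standalone sketch of the same construction, with layer_id fixed to 0:

    import torch
    import torch.nn as nn

    layer = nn.Linear(3, 2, bias=False)   # weight shape: (out_features, in_features) = (2, 3)
    x = next(layer.parameters())
    enc = torch.cat((
        torch.full((x.numel(), 1), 0),                     # column 0: slot for the weight value
        torch.full((x.numel(), 1), 0),                     # column 1: layer id (0 here)
        torch.arange(2).repeat_interleave(3).view(-1, 1),  # column 2: cell (row) id -> 0,0,0,1,1,1
        torch.arange(3).view(-1, 1).repeat(2, 1),          # column 3: weight id within the cell -> 0,1,2,0,1,2
    ), dim=1)
    print(enc.shape)   # torch.Size([6, 4]), one row per weight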
@@ -96,23 +136,17 @@ class Net(nn.Module):

     def input_weight_matrix(self) -> Tensor:
         """ Calculating the input tensor formed from the weights of the net """
+        weight_matrix = torch.cat([x.view(-1, 1) for x in self.parameters()])
+        pos_enc, mask = self._weight_pos_enc
+        weight_matrix = pos_enc * mask + weight_matrix.expand(-1, 4) * (1 - mask)
+        return weight_matrix

-        # The "4" represents the weightwise coordinates used for the matrix: <value><layer_id><cell_id><positional_id>
-        weight_matrix = np.arange(self.no_weights * 4).reshape(self.no_weights, 4).astype("f")
-
-        i = 0
-        max_layer_id = len(self.state_dict()) - 1
-        for layer_id, layer_name in enumerate(self.state_dict()):
-            max_cell_id = len(self.state_dict()[layer_name]) - 1
-            for line_id, line_values in enumerate(self.state_dict()[layer_name]):
-                max_weight_id = len(line_values) - 1
-                for weight_id, weight_value in enumerate(self.state_dict()[layer_name][line_id]):
-                    weight_matrix[i] = weight_value.item(), self.normalize(layer_id, max_layer_id), self.normalize(line_id, max_cell_id), self.normalize(weight_id, max_weight_id)
-                    i += 1
-
-        return torch.from_numpy(weight_matrix)

-    def self_train(self, training_steps: int, log_step_size: int, learning_rate: float) -> (np.ndarray, Tensor, list):
+    def self_train(self,
+                   training_steps: int,
+                   log_step_size: int = 0,
+                   learning_rate: float = 0.0004,
+                   save_history: bool = True
+                   ) -> (Tensor, list):
         """ Training a network to predict its own weights in order to self-replicate. """

         optimizer = optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9)
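The new body overlays the live weight values (column 0) onto the cached positional encoding via the mask. A toy illustration with two made-up weights:

    import torch

    pos_enc = torch.tensor([[0., 0.1, 0.2, 0.3],
                            [0., 0.1, 0.2, 0.6]])
    mask = torch.ones_like(pos_enc)
    mask[:, 0] = 0
    weights = torch.tensor([[0.5], [-0.5]])
    combined = pos_enc * mask + weights.expand(-1, 4) * (1 - mask)
    print(combined)
    # tensor([[ 0.5000, 0.1000, 0.2000, 0.3000],
    #         [-0.5000, 0.1000, 0.2000, 0.6000]])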
@@ -127,27 +161,30 @@ class Net(nn.Module):
             loss.backward()
             optimizer.step()

-            # Saving the history of the weights after a certain amount of steps (aka log_step_size) for research.
-            # If it is a soup/mixed env. save weights only at the end of all training steps (aka a soup/mixed epoch)
-            if "soup" not in self.name and "mixed" not in self.name:
-                weights = self.create_target_weights(self.input_weight_matrix())
-                # If self-training steps are lower than 10, then append weight history after each ST step.
-                if self.number_trained < 10:
-                    self.s_train_weights_history.append(weights.T.detach().numpy())
-                    self.loss_history.append(loss.detach().numpy().item())
-                else:
-                    if self.number_trained % log_step_size == 0:
-                        self.s_train_weights_history.append(weights.T.detach().numpy())
-                        self.loss_history.append(loss.detach().numpy().item())
+            if save_history:
+                # Save the weight history every log_step_size steps for research.
+                # In a soup/mixed env., save weights only at the end of all training steps (i.e. one soup/mixed epoch).
+                if "soup" not in self.name and "mixed" not in self.name:
+                    weights = self.create_target_weights(self.input_weight_matrix())
+                    # For fewer than 10 self-training steps, append the weight history after every ST step.
+                    if self.number_trained < 10:
+                        self.s_train_weights_history.append(weights.T.detach().numpy())
+                        self.loss_history.append(loss.item())
+                    else:
+                        if log_step_size != 0:
+                            if self.number_trained % log_step_size == 0:
+                                self.s_train_weights_history.append(weights.T.detach().numpy())
+                                self.loss_history.append(loss.item())

         weights = self.create_target_weights(self.input_weight_matrix())
-        # Saving weights only at the end of a soup/mixed exp. epoch.
-        if "soup" in self.name or "mixed" in self.name:
-            self.s_train_weights_history.append(weights.T.detach().numpy())
-            self.loss_history.append(loss.detach().numpy().item())
+        if save_history:
+            # Save weights only at the end of a soup/mixed experiment epoch.
+            if "soup" in self.name or "mixed" in self.name:
+                self.s_train_weights_history.append(weights.T.detach().numpy())
+                self.loss_history.append(loss.item())

         self.trained = True
-        return weights.detach().numpy(), loss, self.loss_history
+        return loss, self.loss_history

     def self_application(self, SA_steps: int, log_step_size: Union[int, None] = None):
         """ Inputting the weights of a network to itself for a number of steps, without backpropagation. """
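A hypothetical call using the new defaults; the Net constructor arguments (i_size, h_size, o_size, name) are assumed from the surrounding code:

    net = Net(4, 4, 1, name='st_net')
    loss, loss_history = net.self_train(training_steps=100, log_step_size=10)
    print(loss.item(), len(loss_history))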
@@ -208,7 +245,7 @@ class Net(nn.Module):

 class SecondaryNet(Net):

-    def self_train(self, training_steps: int, log_step_size: int, learning_rate: float) -> (np.ndarray, Tensor, list):
+    def self_train(self, training_steps: int, log_step_size: int, learning_rate: float) -> (pd.DataFrame, Tensor, list):
         """ Training a network to predict its own weights in order to self-replicate. """

         optimizer = optim.SGD(self.parameters(), lr=learning_rate, momentum=0.9)
@@ -245,10 +282,6 @@ class SecondaryNet(Net):
         return df, is_diverged


-class MetaWeight(Net):
-    pass
-
-
 class MetaCell(nn.Module):
     def __init__(self, name, interface, residual_skip=True):
         super().__init__()
@@ -258,67 +291,102 @@ class MetaCell(nn.Module):
         self.weight_interface = 4
         self.net_hidden_size = 4
         self.net_ouput_size = 1
-        self.meta_weight_list = nn.ModuleList(
-            [MetaWeight(self.weight_interface, self.net_hidden_size,
-                        self.net_ouput_size, name=f'{self.name}_{weight_idx}'
-                        ) for weight_idx in range(self.interface)])
+        self.meta_weight_list = nn.ModuleList()
+        self.meta_weight_list.extend(
+            [Net(self.weight_interface, self.net_hidden_size,
+                 self.net_ouput_size, name=f'{self.name}_{weight_idx}'
+                 ) for weight_idx in range(self.interface)]
+        )

     def forward(self, x):
-        xs = [torch.hstack((x[:, idx].unsqueeze(-1), torch.zeros((x.shape[0], self.weight_interface - 1))))
-              for idx in range(len(self.meta_weight_list))]
+        xs = [torch.hstack(
+            (torch.zeros((x.shape[0], self.weight_interface - 1), device=x.device), x[:, idx].unsqueeze(-1))
+        )
+            for idx in range(len(self.meta_weight_list))]
         tensor = torch.hstack([meta_weight(xs[idx]) for idx, meta_weight in enumerate(self.meta_weight_list)])

         if self.residual_skip:
             tensor += x

         result = torch.sum(tensor, dim=-1, keepdim=True)
         return result

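In the reworked forward, each scalar input is routed to one particle Net, padded with zeros for the remaining interface slots and placed in the last column. A standalone sketch, assuming weight_interface = 4:

    import torch

    x = torch.tensor([[0.7, -0.2]])   # batch of 1, interface of 2
    idx = 0
    padded = torch.hstack((torch.zeros((x.shape[0], 3)), x[:, idx].unsqueeze(-1)))
    print(padded)   # tensor([[0.0000, 0.0000, 0.0000, 0.7000]])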
+    @property
+    def particles(self):
+        return (net for net in self.meta_weight_list)


 class MetaLayer(nn.Module):
-    def __init__(self, name, interface=4, out=1, width=4):
+    def __init__(self, name, interface=4, width=4):
         super().__init__()
         self.name = name
         self.interface = interface
         self.width = width

-        meta_cell_list = nn.ModuleList([MetaCell(name=f'{self.name}_{cell_idx}',
-                                                 interface=interface
-                                                 ) for cell_idx in range(self.width)])
-        self.meta_cell_list = meta_cell_list
+        self.meta_cell_list = nn.ModuleList()
+        self.meta_cell_list.extend([MetaCell(name=f'{self.name}_{cell_idx}',
+                                             interface=interface
+                                             ) for cell_idx in range(self.width)]
+                                   )

     def forward(self, x):
         result = torch.hstack([metacell(x) for metacell in self.meta_cell_list])
         return result

+    @property
+    def particles(self):
+        return (weight for metacell in self.meta_cell_list for weight in metacell.particles)


 class MetaNet(nn.Module):

-    def __init__(self, interface=4, depth=3, width=4, out=1):
+    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None):
         super().__init__()
+        self.activation = activation
         self.out = out
         self.interface = interface
         self.width = width
         self.depth = depth

-        meta_layer_list = nn.ModuleList([MetaLayer(name=f'Weight_{0}',
-                                                   interface=self.interface,
-                                                   width=self.width)])
-        meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
-                                          interface=self.width, width=self.width
-                                          ) for layer_idx in range(self.depth - 2)])
-        meta_layer_list.append(MetaLayer(name=f'Weight_{len(meta_layer_list)}',
-                                         interface=self.width, width=self.out))
-        self._meta_layer_list = meta_layer_list
-        self._net = nn.Sequential(*self._meta_layer_list)
+        self._meta_layer_list = nn.ModuleList()
+        self._meta_layer_list.append(MetaLayer(name=f'Weight_{0}',
+                                               interface=self.interface,
+                                               width=self.width)
+                                     )
+        self._meta_layer_list.extend([MetaLayer(name=f'Weight_{layer_idx + 1}',
+                                                interface=self.width, width=self.width
+                                                ) for layer_idx in range(self.depth - 2)]
+                                     )
+        self._meta_layer_list.append(MetaLayer(name=f'Weight_{len(self._meta_layer_list)}',
+                                               interface=self.width, width=self.out)
+                                     )

     def forward(self, x):
-        result = self._net.forward(x)
-        return result
+        tensor = x
+        for meta_layer in self._meta_layer_list:
+            tensor = meta_layer(tensor)
+        return tensor

+    @property
+    def particles(self):
+        return (cell for metalayer in self._meta_layer_list for cell in metalayer.particles)

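The particles generators flatten the hierarchy: each MetaCell holds `interface` Nets and each MetaLayer holds `width` MetaCells. A quick count for the configuration used in __main__ below:

    metanet = MetaNet(interface=2, depth=3, width=2, out=1)
    # layer 0: 2 cells x 2 Nets, layer 1: 2 cells x 2 Nets, layer 2: 1 cell x 2 Nets
    print(sum(1 for _ in metanet.particles))   # 10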
+    def combined_self_train(self):
+        losses = []
+        for particle in self.particles:
+            # TODO: integrate optimizer and backward function
+            input_data = particle.input_weight_matrix()
+            target_data = particle.create_target_weights(input_data)
+            output = particle(input_data)
+            losses.append(F.mse_loss(output, target_data))
+        return torch.hstack(losses).sum(dim=-1, keepdim=True)


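combined_self_train only aggregates the per-particle MSE losses; the optimizer wiring is still the TODO noted in the comment. A minimal sketch of how it could be driven (assumed usage, not part of the commit):

    metanet = MetaNet(interface=2, depth=3, width=2, out=1)
    optimizer = optim.SGD(metanet.parameters(), lr=0.004, momentum=0.9)
    for _ in range(10):
        optimizer.zero_grad()
        loss = metanet.combined_self_train()   # single-element tensor, so backward() works
        loss.backward()
        optimizer.step()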
 if __name__ == '__main__':
-    metanet = MetaNet(2, 3, 4, 1)
+    metanet = MetaNet(interface=2, depth=3, width=2, out=1)
     next(metanet.particles).input_weight_matrix()
     metanet(torch.ones((5, 2)))
     a = metanet.particles
     print('Test')
     print('Test')
     print('Test')