sparse meta network

This commit is contained in:
Maximilian Zorn 2022-02-09 14:35:55 +01:00
parent d4c25872c6
commit 7ae3e96ec9
2 changed files with 552 additions and 115 deletions

181
sparse_net.py Normal file
View File

@ -0,0 +1,181 @@
from network import Net
from typing import List
from functionalities_test import is_identity_function
from tqdm import tqdm,trange
import numpy as np
from pathlib import Path
import torch
from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Resize
class SparseLayer():
    """A stack of block-diagonal sparse weight matrices that runs many
    small self-replicating nets side by side in one matmul per layer.

    Each of the ``nr_nets`` sub-nets has the architecture of ``Net`` with
    the given interface/width/out dimensions; their per-layer weights sit
    as blocks on the diagonal of one big sparse COO matrix.
    """

    def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):
        self.nr_nets = nr_nets
        self.interface_dim = interface
        self.depth_dim = depth
        self.hidden_dim = width
        self.out_dim = out
        # Template net: only consulted for per-layer weight shapes and the
        # positional-encoding matrix; it is never trained itself.
        self.dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)
        self.sparse_sub_layer = []
        self.weights = []
        for idx in range(depth):
            sub_layer, params = self.coo_sparse_layer(idx)
            self.sparse_sub_layer.append(sub_layer)
            self.weights.append(params)

    def coo_sparse_layer(self, layer_id):
        """Build the block-diagonal sparse matrix for one sub-net layer.

        Returns the sparse COO tensor together with its flat trainable
        value vector (a ``torch.nn.Parameter``).
        """
        # (out_cells, in_cells) per sub-net, e.g. (2,5), (2,2), (1,2)
        rows, cols = list(self.dummy_net.parameters())[layer_id].shape
        # One (rows x cols) block per net along the diagonal.
        diag_mask = np.eye(self.nr_nets).repeat(rows, axis=-2).repeat(cols, axis=-1)
        coords = np.argwhere(diag_mask == 1).T
        flat_values = torch.nn.Parameter(torch.randn((self.nr_nets * (rows * cols))))
        sparse = torch.sparse_coo_tensor(coords, flat_values, diag_mask.shape, requires_grad=True)
        print(f"L{layer_id}:", sparse.shape)
        return sparse, flat_values

    def get_self_train_inputs_and_targets(self):
        """Assemble (inputs, targets) for one combined self-train step.

        Each sub-net's weights are interleaved with the shared positional
        encoding via the mask; targets are the raw weights themselves.
        """
        encoding_matrix, mask = self.dummy_net._weight_pos_enc
        # Chunk every flat layer vector so that column i holds the weights
        # of sub-net i for that layer.
        per_layer = [w.view(-1, int(len(w) / self.nr_nets)) for w in self.weights]
        # Concatenate across layers -> full weight column per sub-net.
        per_net = [torch.cat([chunk[i] for chunk in per_layer]).view(-1, 1)
                   for i in range(self.nr_nets)]
        inputs = torch.hstack(
            [encoding_matrix * mask + per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask)
             for i in range(self.nr_nets)])
        targets = torch.hstack(per_net)
        return inputs.T, targets.T

    def __call__(self, x):
        # NOTE: hard-wired to three sub-layers, matching the default depth.
        h1 = torch.sparse.mm(self.sparse_sub_layer[0], x)
        h2 = torch.sparse.mm(self.sparse_sub_layer[1], h1)
        return torch.sparse.mm(self.sparse_sub_layer[2], h2)
def test_sparse_layer():
    """Self-train 500 small nets at once through one SparseLayer and
    report how many have become identity functions."""
    net = SparseLayer(500)  # 500 parallel nets (comment previously said 50)
    loss_fn = torch.nn.MSELoss(reduction="sum")
    # net.weights is already a list of Parameters; no copy needed.
    optimizer = torch.optim.SGD(net.weights, lr=0.004, momentum=0.9)
    #optimizer = torch.optim.SGD([layer for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)
    epsilon = pow(10, -5)  # hoisted: tolerance is loop-invariant
    for train_iteration in trange(1000):
        optimizer.zero_grad()
        X, Y = net.get_self_train_inputs_and_targets()
        out = net(X)
        loss = loss_fn(out, Y)
        # retain_graph: the sparse COO tensors are built once in __init__
        # and reused across iterations, so the graph must survive backward.
        loss.backward(retain_graph=True)
        optimizer.step()
        # is each of the networks self-replicating?
        print(f"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}")
def embed_batch(x, repeat_dim):
    """Embed a batch of flat images into the sub-nets' 5-wide input encoding.

    x: (batchsize, flat_img_dim). Returns a tensor of shape
    (batchsize, flat_img_dim, 5*repeat_dim): each 5-slot group is four
    zero padding channels followed by the pixel value, tiled repeat_dim times.
    """
    pixels = x.unsqueeze(-1)                           # (B, D, 1)
    pad = torch.zeros(x.shape[0], x.shape[1], 4)       # (B, D, 4)
    block = torch.cat((pad, pixels), dim=2)            # (B, D, 5)
    return block.repeat(1, 1, repeat_dim)              # (B, D, 5*repeat_dim)
def embed_vector(x, repeat_dim):
    """Embed a single flat image vector; same layout as embed_batch.

    x: (flat_img_dim,). Returns (flat_img_dim, 5*repeat_dim) where each
    5-slot group is four zeros followed by the pixel value.
    """
    pixels = x.unsqueeze(-1)                 # (D, 1)
    pad = torch.zeros(x.shape[0], 4)         # (D, 4)
    block = torch.cat((pad, pixels), dim=1)  # (D, 5)
    return block.repeat(1, repeat_dim)       # (D, 5*repeat_dim)
class SparseNetwork():
    """Meta network built from SparseLayers: every weight (connection) of
    the meta net is represented by one small self-replicating sub-net.

    Layer i therefore owns fan_in * fan_out sub-nets; a forward pass
    embeds the activations, runs all sub-nets at once via the sparse
    layer, and folds their outputs back into the next activation vector.
    """

    def __init__(self, input_dim, depth, width, out):
        self.input_dim = input_dim
        self.depth_dim = depth
        self.hidden_dim = width
        self.out_dim = out
        # nr_nets of each SparseLayer = number of connections of that
        # meta-layer (fan_in * fan_out).
        self.sparse_layers = []
        self.sparse_layers.append(SparseLayer(self.input_dim * self.hidden_dim))
        self.sparse_layers.extend([SparseLayer(self.hidden_dim * self.hidden_dim)
                                   for layer_idx in range(self.depth_dim - 2)])
        self.sparse_layers.append(SparseLayer(self.hidden_dim * self.out_dim))

    def __call__(self, x):
        for sparse_layer in self.sparse_layers[:-1]:
            # batch pass (one by one, sparse bmm doesn't support grad)
            if len(x.shape) > 1:
                embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
                x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1)
                                 for inpt in embedded_inpt])  # [batchsize, hidden_dim]
            # vector
            else:
                embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
                # BUG FIX: a 1-D vector has no dim 1 -- the feature count is
                # x.shape[0] (previously x.shape[1], which raised IndexError).
                x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[0]).sum(dim=1)
            print("out", x.shape)
        # output layer
        sparse_layer = self.sparse_layers[-1]
        if len(x.shape) > 1:
            embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
            x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1)
                             for inpt in embedded_inpt])  # [batchsize, out_dim]
        else:
            embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
            # BUG FIX: same vector-branch indexing fix as above.
            x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.out_dim, x.shape[0]).sum(dim=1)
        print("out", x.shape)
        return x
def test_sparse_net():
    """Push one MNIST batch through a SparseNetwork meta net.

    Requires the MNIST data to be present under ./data.
    """
    utility_transforms = Compose([Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])
    data_path = Path('data')
    WORKER = 8
    BATCHSIZE = 10
    # NOTE(review): DEVICE is currently unused -- everything stays on CPU.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = MNIST(str(data_path), transform=utility_transforms)
    d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)
    data_dim = np.prod(dataset[0][0].shape)
    metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)
    batchx, batchy = next(iter(d))
    # Removed dead statement `batchx.shape, batchy.shape` (leftover
    # notebook cell expression with no effect in a script).
    metanet(batchx)
def test_manual_for_loop():
    """Baseline: self-train 500 nets one by one, each with its own optimizer,
    and report how many ended up as identity functions."""
    nr_nets = 500
    nets = [Net(5, 2, 1) for _ in range(nr_nets)]
    loss_fn = torch.nn.MSELoss(reduction="sum")
    rounds = 1000
    for net in tqdm(nets):
        optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
        for i in range(rounds):
            optimizer.zero_grad()
            input_data = net.input_weight_matrix()
            target_data = net.create_target_weights(input_data)
            output = net(input_data)
            loss = loss_fn(output, target_data)
            loss.backward()
            optimizer.step()
    # BUG FIX: the identity count was computed but discarded (leftover
    # notebook cell output) -- report it explicitly.
    print(f"{sum(is_identity_function(net) for net in nets)}/{nr_nets} identity_fns")
if __name__ == '__main__':
    # Vectorised sparse self-train first, then the meta-net forward pass,
    # and finally the per-net loop as a comparison baseline.
    test_sparse_layer()
    test_sparse_net()
    test_manual_for_loop()

View File

@ -2,116 +2,25 @@
"cells": [
{
"cell_type": "code",
"execution_count": 222,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from network import Net\n",
"import torch\n",
"from typing import List\n",
"from functionalities_test import is_identity_function"
"from functionalities_test import is_identity_function\n",
"from tqdm import tqdm,trange\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 255,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nr_nets = 5\n",
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
"\n",
"loss_fn = torch.nn.MSELoss()\n",
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)"
]
},
{
"cell_type": "code",
"execution_count": 256,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 256,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum([is_identity_function(net) for net in nets])"
]
},
{
"cell_type": "code",
"execution_count": 247,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([14, 20]), torch.Size([14, 5]))"
]
},
"execution_count": 247,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.hstack( [net.input_weight_matrix() for net in nets] ) #(nr_nets*nr_weights, nr_weights)\n",
"Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ) #(nr_nets*nr_weights,1)\n",
"X.shape, Y.shape"
]
},
{
"cell_type": "code",
"execution_count": 270,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor(indices=tensor([[ 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3,\n",
" 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6,\n",
" 7, 7, 7, 7, 8, 8, 8, 8, 9, 9, 9, 9],\n",
" [ 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 4, 5,\n",
" 6, 7, 8, 9, 10, 11, 8, 9, 10, 11, 12, 13, 14, 15,\n",
" 12, 13, 14, 15, 16, 17, 18, 19, 16, 17, 18, 19]]),\n",
" values=tensor([-0.0282, 0.1612, 0.0717, -0.1370, -0.0789, 0.0990,\n",
" -0.0642, 0.1385, -0.1046, 0.1522, -0.0691, 0.0848,\n",
" -0.1419, -0.0465, -0.0385, 0.1453, -0.0263, 0.1401,\n",
" 0.0758, 0.1022, 0.1218, -0.1423, 0.0556, 0.0150,\n",
" 0.0598, -0.0347, -0.0717, 0.1173, 0.0126, -0.0164,\n",
" -0.0359, 0.0895, 0.1545, -0.1091, 0.0925, 0.0687,\n",
" 0.1330, 0.1297, 0.0305, 0.1811]),\n",
" size=(10, 20), nnz=40, layout=torch.sparse_coo, requires_grad=True),\n",
" tensor(indices=tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9,\n",
" 9],\n",
" [0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9, 8,\n",
" 9]]),\n",
" values=tensor([-0.1608, 0.0952, 0.0369, 0.0105, -0.0277, 0.0216,\n",
" 0.0991, 0.1250, 0.0618, 0.2241, 0.0602, 0.1144,\n",
" -0.0330, -0.1240, 0.0456, -0.1208, -0.1859, 0.1333,\n",
" 0.1235, -0.1774]),\n",
" size=(10, 10), nnz=20, layout=torch.sparse_coo, requires_grad=True),\n",
" tensor(indices=tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4],\n",
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),\n",
" values=tensor([ 0.0585, 0.1856, -0.0987, 0.2342, -0.0376, 0.0765,\n",
" -0.1395, 0.1574, -0.0103, -0.0933]),\n",
" size=(5, 10), nnz=10, layout=torch.sparse_coo, requires_grad=True)]"
]
},
"execution_count": 270,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def construct_sparse_tensor_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
"def construct_sparse_COO_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
" assert layer_idx <= len(list(nets[0].parameters()))\n",
" values = []\n",
" indices = []\n",
@ -120,16 +29,20 @@
" \n",
" for cell_idx,cell in enumerate(layer):\n",
" # E.g., position of cell weights (with 2 cells per hidden layer) in first sparse layer of N nets: \n",
" \n",
" # [4x2 weights_net0] [4x2x(n-1) 0s]\n",
" # [4x2 weights] [4x2 weights_net0] [4x2x(n-2) 0s]\n",
" # ... etc\n",
" # [4x2x(n-1) 0s] [4x2 weights_netN]\n",
" \n",
" \n",
" # -> 4x2 weights on the diagonal = [shifted Nr_cellss*B down for AxB cells, and Nr_nets(*A weights)to the right] \n",
" for i in range(len(cell)):\n",
" indices.append([len(layer)*net_idx + cell_idx, net_idx*len(cell) + i ])\n",
" #indices.append([2*net_idx + cell_idx, net_idx*len(cell) + i ])\n",
"\n",
" [values.append(weight) for weight in cell]\n",
"\n",
" # for i in range(4):\n",
" # indices.append([idx+idx+1, i+(idx*4)])\n",
" #for l in next(net.parameters()):\n",
@ -141,7 +54,7 @@
" # layer 1: (2x4) -> (2*N, 4*N)\n",
" # layer 2: (2x2) -> (2*N, 2*N)\n",
" # layer 3: (1x2) -> (2*N, 1*N)\n",
" s = torch.sparse_coo_tensor(list(zip(*indices)), values, (len(layer)*nr_nets, len(cell)*nr_nets),requires_grad=True)\n",
" s = torch.sparse_coo_tensor(list(zip(*indices)), values, (len(layer)*nr_nets, len(cell)*nr_nets))\n",
" #print(s.to_dense())\n",
" #print(s.to_dense().shape)\n",
" return s\n",
@ -152,42 +65,37 @@
"# - [4x2] weights in the first (input) layer\n",
"# - [2x2] weights in the second (hidden) layer\n",
"# - [2x1] weights in the third (output) layer\n",
"modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
"modules\n",
"#modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
"#modules\n",
"#for layer_idx in range(len(list(nets[0].parameters()))):\n",
"# sparse_tensor = construct_sparse_tensor_layer(nets, layer_idx)"
]
},
{
"cell_type": "code",
"execution_count": 295,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"before: 0/5000 identity_fns\n",
"after 1 iterations of combined self_train: 0/5000 identity_fns\n"
]
}
],
"outputs": [],
"source": [
"nr_nets = 5000\n",
"nr_nets = 50\n",
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
"print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
"\n",
"modules = [ construct_sparse_COO_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
"print( id(list(nets[0].parameters())[0][0,0]) == id(modules[0][0,0]))\n",
"\n",
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
"#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
"\n",
"\n",
"for train_iteration in range(1):\n",
"for train_iteration in range(1000):\n",
" optimizer.zero_grad() \n",
" X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
" Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
" #print(\"X \", X.shape, \"Y\", Y.shape)\n",
"\n",
" modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
" modules = [ construct_sparse_COO_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
"\n",
" X1 = torch.sparse.mm(modules[0], X)\n",
" #print(\"X1\", X1.shape, X1)\n",
@ -211,7 +119,355 @@
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"nr_nets = 500\n",
"nets = [Net(5,2,1) for _ in range(nr_nets)]\n",
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
"rounds = 1000\n",
"\n",
"for net in tqdm(nets):\n",
" optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)\n",
" for i in range(rounds):\n",
" optimizer.zero_grad()\n",
" input_data = net.input_weight_matrix()\n",
" target_data = net.create_target_weights(input_data)\n",
" output = net(input_data)\n",
" loss = loss_fn(output, target_data)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
"sum([is_identity_function(net) for net in nets])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def construct_sparse_CRS_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
" assert layer_idx <= len(list(nets[0].parameters()))\n",
" \n",
" s = torch.cat( [\n",
" torch.cat(\n",
" (\n",
" torch.zeros(( len(list(net.parameters())[layer_idx]) ,len(list(net.parameters())[layer_idx][0])*net_idx)), \n",
" list(net.parameters())[layer_idx], \n",
" torch.zeros((len(list(net.parameters())[layer_idx]), len(list(net.parameters())[layer_idx][0])*(len(nets)-(net_idx+1))))\n",
" )\n",
" , dim=1) for net_idx, net in enumerate(nets)\n",
" ]).to_sparse_csr()\n",
"\n",
" print(s.shape)\n",
" return s"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nr_nets = 5\n",
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
"print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
"\n",
"#modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
"print( id(list(nets[0].parameters())[0][0,0]) == id(modules[0][0,0]))\n",
"\n",
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
"#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
"\n",
"\n",
"for train_iteration in range(1):\n",
" optimizer.zero_grad() \n",
" X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
" Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
" #print(\"X \", X.shape, \"Y\", Y.shape)\n",
"\n",
" num_layers = len(list(nets[0].parameters()))\n",
" modules = [ construct_sparse_CRS_layer(nets, layer_idx) for layer_idx in range(num_layers)]\n",
"\n",
" X1 = modules[0].matmul(X)\n",
" print(\"X1\", X1.shape, X1.is_sparse)\n",
"\n",
" X2 = modules[1].matmul(X1)\n",
" print(\"X2\", X2.shape, X2.is_sparse)\n",
"\n",
" X3 = modules[2].matmul(X2)\n",
" print(\"X3\", X3.shape, X3.is_sparse)\n",
"\n",
" loss = loss_fn(X3, Y)\n",
" #print(loss)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
"print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nr_nets = 2\n",
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
"\n",
"def cat_COO_layer(nets, layer_idx):\n",
" i = [[0,i] for i in range(nr_nets*len(list(net.parameters())[layer_idx]))]\n",
" v = torch.cat( [\n",
" torch.cat(\n",
" (\n",
" torch.zeros(( len(list(net.parameters())[layer_idx]) ,len(list(net.parameters())[layer_idx][0])*net_idx)), \n",
" list(net.parameters())[layer_idx], \n",
" torch.zeros((len(list(net.parameters())[layer_idx]), len(list(net.parameters())[layer_idx][0])*(len(nets)-(net_idx+1))))\n",
" )\n",
" , dim=1) for net_idx, net in enumerate(nets)\n",
" ])\n",
" #print(i,v)\n",
" s = torch.sparse_coo_tensor(list(zip(*i)), v)\n",
" print(s[0].to_dense().shape, s[0].is_sparse)\n",
" return s[0]\n",
"\n",
"cat_COO_layer(nets, 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nr_nets = 5\n",
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
"print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
"\n",
"nr_layers = len(list(nets[0].parameters()))\n",
"modules = [ cat_COO_layer(nets, layer_idx) for layer_idx in range(nr_layers) ]\n",
"\n",
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
"#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
"\n",
"\n",
"for train_iteration in range(1):\n",
" optimizer.zero_grad() \n",
" X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
" Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
" print(\"X \", X.shape, \"Y\", Y.shape)\n",
"\n",
" X1 = torch.sparse.mm(modules[0], X)\n",
" print(\"X1\", X1.shape)\n",
"\n",
" X2 = torch.sparse.mm(modules[1], X1)\n",
" print(\"X2\", X2.shape)\n",
"\n",
" X3 = torch.sparse.mm(modules[2], X2)\n",
" print(\"X3\", X3.shape)\n",
"\n",
" loss = loss_fn(X3, Y)\n",
" #print(loss)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
"print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SparseLayer():\n",
" def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):\n",
" self.nr_nets = nr_nets\n",
" self.interface_dim = interface\n",
" self.depth_dim = depth\n",
" self.hidden_dim = width\n",
" self.out_dim = out\n",
" self.dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)\n",
" \n",
" self.sparse_sub_layer = []\n",
" self.weights = []\n",
" for layer_id in range(depth):\n",
" layer, weights = self.coo_sparse_layer(layer_id)\n",
" self.sparse_sub_layer.append(layer)\n",
" self.weights.append(weights)\n",
" \n",
" def coo_sparse_layer(self, layer_id):\n",
" layer_shape = list(self.dummy_net.parameters())[layer_id].shape\n",
" #print(layer_shape) #(out_cells, in_cells) -> (2,5), (2,2), (1,2)\n",
"\n",
" sparse_diagonal = np.eye(self.nr_nets).repeat(layer_shape[0], axis=-2).repeat(layer_shape[1], axis=-1)\n",
" indices = np.argwhere(sparse_diagonal == 1).T\n",
" values = torch.nn.Parameter(torch.randn((self.nr_nets * (layer_shape[0]*layer_shape[1]) )))\n",
" #values = torch.randn((self.nr_nets * layer_shape[0]*layer_shape[1] ))\n",
" s = torch.sparse_coo_tensor(indices, values, sparse_diagonal.shape, requires_grad=True)\n",
" print(f\"L{layer_id}:\", s.shape)\n",
" return s, values\n",
"\n",
" def get_self_train_inputs_and_targets(self):\n",
" encoding_matrix, mask = self.dummy_net._weight_pos_enc\n",
"\n",
" # view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN\n",
" # i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2 and first hidden*out weights of layer3 = first net\n",
" weights = [layer.view(-1, int(len(layer)/self.nr_nets)) for layer in self.weights] #[nr_layers*[nr_net*nr_weights_layer_i]]\n",
" weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1,1) for i in range(self.nr_nets)] #[nr_net*[nr_weights]]\n",
" inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(self.nr_nets)]) #(16, 25)\n",
" targets = torch.hstack(weights_per_net)\n",
" return inputs.T, targets.T\n",
"\n",
" def __call__(self, x):\n",
" X1 = torch.sparse.mm(self.sparse_sub_layer[0], x)\n",
" #print(\"X1\", X1.shape)\n",
"\n",
" X2 = torch.sparse.mm(self.sparse_sub_layer[1], X1)\n",
" #print(\"X2\", X2.shape)\n",
"\n",
" X3 = torch.sparse.mm(self.sparse_sub_layer[2], X2)\n",
" #print(\"X3\", X3.shape)\n",
" \n",
" return X3\n",
"\n",
"net = SparseLayer(5)\n",
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
"optimizer = torch.optim.SGD([weight for weight in net.weights], lr=0.004, momentum=0.9)\n",
"#optimizer = torch.optim.SGD([layer for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)\n",
"\n",
"for train_iteration in trange(10):\n",
" optimizer.zero_grad() \n",
" X,Y = net.get_self_train_inputs_and_targets()\n",
" out = net(X)\n",
" \n",
" loss = loss_fn(out, Y)\n",
"\n",
" # print(\"X:\", X.shape, \"Y:\", Y.shape)\n",
" # print(\"OUT\", out.shape)\n",
" # print(\"LOSS\", loss.item())\n",
" \n",
" loss.backward(retain_graph=True)\n",
" optimizer.step()\n",
"\n",
" \n",
"\n",
"epsilon=pow(10, -5)\n",
"# is the (the whole layer) self-replicating? -> wrong\n",
"#print(torch.allclose(out, Y,rtol=0, atol=epsilon))\n",
"\n",
"# is each of the networks self-replicating?\n",
"print(f\"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# for layer in net.weights:\n",
"# n=int(len(layer)/net.nr_nets)\n",
"# print( [layer[i:i+n] for i in range(0, len(layer), n)])\n",
"\n",
"encoding_matrix, mask = Net(5,2,1)._weight_pos_enc\n",
"print(encoding_matrix, mask)\n",
"# view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN\n",
"# i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2 and first hidden*out weights of layer3 = first net\n",
"weights = [layer.view(-1, int(len(layer)/net.nr_nets)) for layer in net.weights]\n",
"weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1,1) for i in range(net.nr_nets)]\n",
"\n",
"inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(net.nr_nets)]) #16, 25\n",
"\n",
"targets = torch.hstack(weights_per_net)\n",
"targets.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from pathlib import Path\n",
"import torch\n",
"from torch.nn import Flatten\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torchvision.datasets import MNIST\n",
"from torchvision.transforms import ToTensor, Compose, Resize\n",
"from tqdm import tqdm, trange\n",
"\n",
"\n",
"utility_transforms = Compose([ Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])\n",
"data_path = Path('data')\n",
"WORKER = 8\n",
"BATCHSIZE = 10\n",
"EPOCH = 1\n",
"DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"dataset = MNIST(str(data_path), transform=utility_transforms)\n",
"d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def embed_batch(x, repeat_dim):\n",
" # x of shape (batchsize, flat_img_dim)\n",
" x = x.unsqueeze(-1) #(batchsize, flat_img_dim, 1)\n",
" return torch.cat( (torch.zeros( x.shape[0], x.shape[1], 4), x), dim=2).repeat(1,1,repeat_dim) #(batchsize, flat_img_dim, encoding_dim*repeat_dim)\n",
"\n",
"def embed_vector(x, repeat_dim):\n",
" # x of shape [flat_img_dim]\n",
" x = x.unsqueeze(-1) #(flat_img_dim, 1)\n",
" return torch.cat( (torch.zeros( x.shape[0], 4), x), dim=1).repeat(1,repeat_dim) #(flat_img_dim, encoding_dim*repeat_dim)\n",
"\n",
"class SparseNetwork():\n",
" def __init__(self, input_dim, depth, width, out):\n",
" self.input_dim = input_dim\n",
" self.depth_dim = depth\n",
" self.hidden_dim = width\n",
" self.out_dim = out\n",
" self.sparse_layers = []\n",
" self.sparse_layers.append( SparseLayer( self.input_dim * self.hidden_dim ))\n",
" self.sparse_layers.extend([ SparseLayer( self.hidden_dim * self.hidden_dim ) for layer_idx in range(self.depth_dim - 2)])\n",
" self.sparse_layers.append( SparseLayer( self.hidden_dim * self.out_dim ))\n",
"\n",
" def __call__(self, x):\n",
" \n",
" for sparse_layer in self.sparse_layers[:-1]:\n",
" # batch pass (one by one, sparse bmm doesn't support grad)\n",
" if len(x.shape) > 1:\n",
" embedded_inpt = embed_batch(x, sparse_layer.nr_nets)\n",
" x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt]) #[batchsize, hidden*inpt_dim, feature_dim]\n",
" # vector\n",
" else:\n",
" embedded_inpt = embed_vector(x, sparse_layer.nr_nets)\n",
" x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1)\n",
" print(\"out\", x.shape)\n",
" \n",
" # output layer\n",
" sparse_layer = self.sparse_layers[-1]\n",
" if len(x.shape) > 1:\n",
" embedded_inpt = embed_batch(x, sparse_layer.nr_nets)\n",
" x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt]) #[batchsize, hidden*inpt_dim, feature_dim]\n",
" else:\n",
" embedded_inpt = embed_vector(x, sparse_layer.nr_nets)\n",
" x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1)\n",
" print(\"out\", x.shape)\n",
" return x\n",
"\n",
"data_dim = np.prod(dataset[0][0].shape)\n",
"metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)\n",
"batchx, batchy = next(iter(d))\n",
"batchx.shape, batchy.shape\n",
"metanet(batchx)"
]
}
],
"metadata": {