sparse meta network
parent d4c25872c6
commit 7ae3e96ec9
sparse_net.py | 181 | Normal file
@@ -0,0 +1,181 @@
from network import Net
from typing import List
from functionalities_test import is_identity_function
from tqdm import tqdm, trange
import numpy as np
from pathlib import Path
import torch
from torch.nn import Flatten
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Compose, Resize


class SparseLayer():
    def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):
        self.nr_nets = nr_nets
        self.interface_dim = interface
        self.depth_dim = depth
        self.hidden_dim = width
        self.out_dim = out
        self.dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)

        self.sparse_sub_layer = []
        self.weights = []
        for layer_id in range(depth):
            layer, weights = self.coo_sparse_layer(layer_id)
            self.sparse_sub_layer.append(layer)
            self.weights.append(weights)

    def coo_sparse_layer(self, layer_id):
        layer_shape = list(self.dummy_net.parameters())[layer_id].shape
        # print(layer_shape)  # (out_cells, in_cells) -> (2,5), (2,2), (1,2)

        sparse_diagonal = np.eye(self.nr_nets).repeat(layer_shape[0], axis=-2).repeat(layer_shape[1], axis=-1)
        indices = np.argwhere(sparse_diagonal == 1).T
        values = torch.nn.Parameter(torch.randn(self.nr_nets * (layer_shape[0] * layer_shape[1])))
        # values = torch.randn(self.nr_nets * layer_shape[0] * layer_shape[1])
        s = torch.sparse_coo_tensor(indices, values, sparse_diagonal.shape, requires_grad=True)
        print(f"L{layer_id}:", s.shape)
        return s, values
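
    # Illustration (added note, not part of the original commit): for 2 nets and
    # a (2, 3) layer shape, the eye/repeat construction above yields a
    # block-diagonal mask with one (2, 3) block of ones per net:
    #
    #   >>> np.eye(2).repeat(2, axis=-2).repeat(3, axis=-1)
    #   array([[1., 1., 1., 0., 0., 0.],
    #          [1., 1., 1., 0., 0., 0.],
    #          [0., 0., 0., 1., 1., 1.],
    #          [0., 0., 0., 1., 1., 1.]])
    #
    # np.argwhere(... == 1).T then gives the 2 x 12 COO index array for the
    # 12 nonzero slots, which receive the trainable values.
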
    def get_self_train_inputs_and_targets(self):
        encoding_matrix, mask = self.dummy_net._weight_pos_enc

        # View the weights of each sublayer in equal chunks, one row per self-replicating net,
        # i.e., the first interface*hidden weights of layer 1, the first hidden*hidden weights of layer 2
        # and the first hidden*out weights of layer 3 together form the first net.
        weights = [layer.view(-1, int(len(layer) / self.nr_nets)) for layer in self.weights]  # [nr_layers * (nr_nets, nr_weights_layer_i)]
        weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1, 1) for i in range(self.nr_nets)]  # [nr_nets * (nr_weights, 1)]
        inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(self.nr_nets)])  # (16, 25)
        targets = torch.hstack(weights_per_net)
        return inputs.T, targets.T
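
    # Shape sketch (added for illustration; assumes the default Net(5, 2, 1),
    # i.e. 2*5 + 2*2 + 1*2 = 16 weights per net, and nr_nets = 5):
    #   encoding_matrix, mask : (16, 5)
    #   weights_per_net[i]    : (16, 1)
    #   inputs  : hstack of 5 x (16, 5) -> (16, 25), returned transposed as (25, 16)
    #   targets : hstack of 5 x (16, 1) -> (16, 5),  returned transposed as (5, 16)
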
    def __call__(self, x):
        X1 = torch.sparse.mm(self.sparse_sub_layer[0], x)
        # print("X1", X1.shape)

        X2 = torch.sparse.mm(self.sparse_sub_layer[1], X1)
        # print("X2", X2.shape)

        X3 = torch.sparse.mm(self.sparse_sub_layer[2], X2)
        # print("X3", X3.shape)

        return X3

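
# Why three sparse matmuls are enough (added sketch): with the block-diagonal
# layout from coo_sparse_layer, a single torch.sparse.mm over the stacked
# inputs runs every net at once. For two nets with first-layer blocks W0, W1
# and per-net inputs x0, x1:
#
#   [W0  0 ] [x0]   [W0 @ x0]
#   [ 0  W1] [x1] = [W1 @ x1]
#
# so SparseLayer.__call__ is equivalent to a separate three-layer forward pass
# through each of the nr_nets small networks, batched into one autograd graph.
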
def test_sparse_layer():
    net = SparseLayer(500)  # 500 parallel nets
    loss_fn = torch.nn.MSELoss(reduction="sum")
    optimizer = torch.optim.SGD([weight for weight in net.weights], lr=0.004, momentum=0.9)
    # optimizer = torch.optim.SGD([layer for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)

    for train_iteration in trange(1000):
        optimizer.zero_grad()
        X, Y = net.get_self_train_inputs_and_targets()
        out = net(X)

        loss = loss_fn(out, Y)

        # print("X:", X.shape, "Y:", Y.shape)
        # print("OUT", out.shape)
        # print("LOSS", loss.item())

        loss.backward(retain_graph=True)
        optimizer.step()

    epsilon = 1e-5
    # Is each of the networks self-replicating?
    print(f"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}")


def embed_batch(x, repeat_dim):
    # x of shape (batchsize, flat_img_dim)
    x = x.unsqueeze(-1)  # (batchsize, flat_img_dim, 1)
    return torch.cat((torch.zeros(x.shape[0], x.shape[1], 4), x), dim=2).repeat(1, 1, repeat_dim)  # (batchsize, flat_img_dim, encoding_dim*repeat_dim)


def embed_vector(x, repeat_dim):
    # x of shape (flat_img_dim,)
    x = x.unsqueeze(-1)  # (flat_img_dim, 1)
    return torch.cat((torch.zeros(x.shape[0], 4), x), dim=1).repeat(1, repeat_dim)  # (flat_img_dim, encoding_dim*repeat_dim)

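
# Quick shape check for the embedding helpers (added illustration; assumes a
# flat 100-pixel input, the 4-slot position encoding above, and 3 nets):
#
#   >>> embed_vector(torch.rand(100), 3).shape
#   torch.Size([100, 15])      # per pixel: 4 zeros + 1 value, repeated 3x
#   >>> embed_batch(torch.rand(8, 100), 3).shape
#   torch.Size([8, 100, 15])
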
class SparseNetwork():
    def __init__(self, input_dim, depth, width, out):
        self.input_dim = input_dim
        self.depth_dim = depth
        self.hidden_dim = width
        self.out_dim = out
        self.sparse_layers = []
        self.sparse_layers.append(SparseLayer(self.input_dim * self.hidden_dim))
        self.sparse_layers.extend([SparseLayer(self.hidden_dim * self.hidden_dim) for layer_idx in range(self.depth_dim - 2)])
        self.sparse_layers.append(SparseLayer(self.hidden_dim * self.out_dim))

    def __call__(self, x):

        for sparse_layer in self.sparse_layers[:-1]:
            # batch pass (one sample at a time, sparse bmm doesn't support grad)
            if len(x.shape) > 1:
                embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
                x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt])  # (batchsize, hidden_dim)
            # vector pass
            else:
                embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
                x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[0]).sum(dim=1)  # x.shape[0]: a 1-D input has no dim 1
            print("out", x.shape)

        # output layer
        sparse_layer = self.sparse_layers[-1]
        if len(x.shape) > 1:
            embedded_inpt = embed_batch(x, sparse_layer.nr_nets)
            x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt])  # (batchsize, out_dim)
        else:
            embedded_inpt = embed_vector(x, sparse_layer.nr_nets)
            x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.out_dim, x.shape[0]).sum(dim=1)
        print("out", x.shape)
        return x


def test_sparse_net():
    utility_transforms = Compose([Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])
    data_path = Path('data')
    WORKER = 8
    BATCHSIZE = 10
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = MNIST(str(data_path), transform=utility_transforms)
    d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)

    data_dim = np.prod(dataset[0][0].shape)
    metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)
    batchx, batchy = next(iter(d))
    print(batchx.shape, batchy.shape)
    metanet(batchx)


def test_manual_for_loop():
    nr_nets = 500
    nets = [Net(5, 2, 1) for _ in range(nr_nets)]
    loss_fn = torch.nn.MSELoss(reduction="sum")
    rounds = 1000

    for net in tqdm(nets):
        optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)
        for i in range(rounds):
            optimizer.zero_grad()
            input_data = net.input_weight_matrix()
            target_data = net.create_target_weights(input_data)
            output = net(input_data)
            loss = loss_fn(output, target_data)
            loss.backward()
            optimizer.step()

    print(sum([is_identity_function(net) for net in nets]))


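# Note (added for context): test_manual_for_loop() self-trains each of the 500
# nets with its own optimizer and forward pass, while test_sparse_layer()
# trains all of them jointly through block-diagonal sparse matmuls; both should
# drive the nets toward the same self-replicating (identity-function) fixed
# points, the sparse variant just batches the work into one graph.
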
if __name__ == '__main__':
    test_sparse_layer()
    test_sparse_net()
    # for comparison
    test_manual_for_loop()

@@ -2,116 +2,25 @@
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 222,
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from network import Net\n",
    "import torch\n",
    "from typing import List\n",
    "from functionalities_test import is_identity_function"
    "from functionalities_test import is_identity_function\n",
    "from tqdm import tqdm,trange\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 255,
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nr_nets = 5\n",
    "nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
    "\n",
    "loss_fn = torch.nn.MSELoss()\n",
    "optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 256,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 256,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum([is_identity_function(net) for net in nets])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 247,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([14, 20]), torch.Size([14, 5]))"
      ]
     },
     "execution_count": 247,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.hstack( [net.input_weight_matrix() for net in nets] ) #(nr_nets*nr_weights, nr_weights)\n",
    "Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ) #(nr_nets*nr_weights,1)\n",
    "X.shape, Y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 270,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor(indices=tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,\n",
       "                          3,  3,  4,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,\n",
       "                          7,  7,  7,  7,  8,  8,  8,  8,  9,  9,  9,  9],\n",
       "                        [ 0,  1,  2,  3,  0,  1,  2,  3,  4,  5,  6,  7,  4,  5,\n",
       "                          6,  7,  8,  9, 10, 11,  8,  9, 10, 11, 12, 13, 14, 15,\n",
       "                         12, 13, 14, 15, 16, 17, 18, 19, 16, 17, 18, 19]]),\n",
       "        values=tensor([-0.0282,  0.1612,  0.0717, -0.1370, -0.0789,  0.0990,\n",
       "                       -0.0642,  0.1385, -0.1046,  0.1522, -0.0691,  0.0848,\n",
       "                       -0.1419, -0.0465, -0.0385,  0.1453, -0.0263,  0.1401,\n",
       "                        0.0758,  0.1022,  0.1218, -0.1423,  0.0556,  0.0150,\n",
       "                        0.0598, -0.0347, -0.0717,  0.1173,  0.0126, -0.0164,\n",
       "                       -0.0359,  0.0895,  0.1545, -0.1091,  0.0925,  0.0687,\n",
       "                        0.1330,  0.1297,  0.0305,  0.1811]),\n",
       "        size=(10, 20), nnz=40, layout=torch.sparse_coo, requires_grad=True),\n",
       " tensor(indices=tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9,\n",
       "                         9],\n",
       "                        [0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9, 8,\n",
       "                         9]]),\n",
       "        values=tensor([-0.1608,  0.0952,  0.0369,  0.0105, -0.0277,  0.0216,\n",
       "                        0.0991,  0.1250,  0.0618,  0.2241,  0.0602,  0.1144,\n",
       "                       -0.0330, -0.1240,  0.0456, -0.1208, -0.1859,  0.1333,\n",
       "                        0.1235, -0.1774]),\n",
       "        size=(10, 10), nnz=20, layout=torch.sparse_coo, requires_grad=True),\n",
       " tensor(indices=tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4],\n",
       "                        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),\n",
       "        values=tensor([ 0.0585,  0.1856, -0.0987,  0.2342, -0.0376,  0.0765,\n",
       "                       -0.1395,  0.1574, -0.0103, -0.0933]),\n",
       "        size=(5, 10), nnz=10, layout=torch.sparse_coo, requires_grad=True)]"
      ]
     },
     "execution_count": 270,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def construct_sparse_tensor_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
    "def construct_sparse_COO_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
    "    assert layer_idx <= len(list(nets[0].parameters()))\n",
    "    values = []\n",
    "    indices = []\n",
@@ -120,16 +29,20 @@
    " \n",
    "    for cell_idx,cell in enumerate(layer):\n",
    "        # E.g., position of cell weights (with 2 cells per hidden layer) in first sparse layer of N nets: \n",
    "        \n",
    "        # [4x2 weights_net0] [4x2x(n-1) 0s]\n",
    "        # [4x2x1 0s] [4x2 weights_net1] [4x2x(n-2) 0s]\n",
    "        # ... etc\n",
    "        # [4x2x(n-1) 0s] [4x2 weights_netN]\n",
    "        \n",
    "        \n",
    "        # -> 4x2 weights on the diagonal = [shifted Nr_cells*B down for AxB cells, and Nr_nets(*A weights) to the right] \n",
    "        for i in range(len(cell)):\n",
    "            indices.append([len(layer)*net_idx + cell_idx, net_idx*len(cell) + i ])\n",
    "            #indices.append([2*net_idx + cell_idx, net_idx*len(cell) + i ])\n",
    "\n",
    "        [values.append(weight) for weight in cell]\n",
    "\n",
    "        # for i in range(4):\n",
    "        #     indices.append([idx+idx+1, i+(idx*4)])\n",
    "        #for l in next(net.parameters()):\n",
@@ -141,7 +54,7 @@
    "    # layer 1: (2x4) -> (2*N, 4*N)\n",
    "    # layer 2: (2x2) -> (2*N, 2*N)\n",
    "    # layer 3: (1x2) -> (2*N, 1*N)\n",
    "    s = torch.sparse_coo_tensor(list(zip(*indices)), values, (len(layer)*nr_nets, len(cell)*nr_nets),requires_grad=True)\n",
    "    s = torch.sparse_coo_tensor(list(zip(*indices)), values, (len(layer)*nr_nets, len(cell)*nr_nets))\n",
    "    #print(s.to_dense())\n",
    "    #print(s.to_dense().shape)\n",
    "    return s\n",
@@ -152,42 +65,37 @@
    "# - [4x2] weights in the first (input) layer\n",
    "# - [2x2] weights in the second (hidden) layer\n",
    "# - [2x1] weights in the third (output) layer\n",
    "modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
    "modules\n",
    "#modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
    "#modules\n",
    "#for layer_idx in range(len(list(nets[0].parameters()))):\n",
    "#    sparse_tensor = construct_sparse_tensor_layer(nets, layer_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 295,
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "before: 0/5000 identity_fns\n",
      "after 1 iterations of combined self_train: 0/5000 identity_fns\n"
     ]
    }
   ],
   "outputs": [],
   "source": [
    "nr_nets = 5000\n",
    "nr_nets = 50\n",
    "nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
    "print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
    "\n",
    "modules = [ construct_sparse_COO_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
    "print( id(list(nets[0].parameters())[0][0,0]) == id(modules[0][0,0]))\n",
    "\n",
    "loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
    "optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
    "#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
    "\n",
    "\n",
    "for train_iteration in range(1):\n",
    "for train_iteration in range(1000):\n",
    "    optimizer.zero_grad() \n",
    "    X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
    "    Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
    "    #print(\"X \", X.shape, \"Y\", Y.shape)\n",
    "\n",
    "    modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
    "    modules = [ construct_sparse_COO_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
    "\n",
    "    X1 = torch.sparse.mm(modules[0], X)\n",
    "    #print(\"X1\", X1.shape, X1)\n",
@@ -211,7 +119,355 @@
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
   "source": [
    "nr_nets = 500\n",
    "nets = [Net(5,2,1) for _ in range(nr_nets)]\n",
    "loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
    "rounds = 1000\n",
    "\n",
    "for net in tqdm(nets):\n",
    "    optimizer = torch.optim.SGD(net.parameters(), lr=0.004, momentum=0.9)\n",
    "    for i in range(rounds):\n",
    "        optimizer.zero_grad()\n",
    "        input_data = net.input_weight_matrix()\n",
    "        target_data = net.create_target_weights(input_data)\n",
    "        output = net(input_data)\n",
    "        loss = loss_fn(output, target_data)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "sum([is_identity_function(net) for net in nets])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_sparse_CRS_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
    "    assert layer_idx <= len(list(nets[0].parameters()))\n",
    "    \n",
    "    s = torch.cat( [\n",
    "        torch.cat(\n",
    "            (\n",
    "                torch.zeros(( len(list(net.parameters())[layer_idx]) ,len(list(net.parameters())[layer_idx][0])*net_idx)), \n",
    "                list(net.parameters())[layer_idx], \n",
    "                torch.zeros((len(list(net.parameters())[layer_idx]), len(list(net.parameters())[layer_idx][0])*(len(nets)-(net_idx+1))))\n",
    "            )\n",
    "        , dim=1) for net_idx, net in enumerate(nets)\n",
    "    ]).to_sparse_csr()\n",
    "\n",
    "    print(s.shape)\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nr_nets = 5\n",
    "nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
    "print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
    "\n",
    "#modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
    "print( id(list(nets[0].parameters())[0][0,0]) == id(modules[0][0,0]))\n",
    "\n",
    "loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
    "optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
    "#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
    "\n",
    "\n",
    "for train_iteration in range(1):\n",
    "    optimizer.zero_grad() \n",
    "    X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
    "    Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
    "    #print(\"X \", X.shape, \"Y\", Y.shape)\n",
    "\n",
    "    num_layers = len(list(nets[0].parameters()))\n",
    "    modules = [ construct_sparse_CRS_layer(nets, layer_idx) for layer_idx in range(num_layers)]\n",
    "\n",
    "    X1 = modules[0].matmul(X)\n",
    "    print(\"X1\", X1.shape, X1.is_sparse)\n",
    "\n",
    "    X2 = modules[1].matmul(X1)\n",
    "    print(\"X2\", X2.shape, X2.is_sparse)\n",
    "\n",
    "    X3 = modules[2].matmul(X2)\n",
    "    print(\"X3\", X3.shape, X3.is_sparse)\n",
    "\n",
    "    loss = loss_fn(X3, Y)\n",
    "    #print(loss)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nr_nets = 2\n",
    "nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
    "\n",
    "def cat_COO_layer(nets, layer_idx):\n",
    "    i = [[0,i] for i in range(nr_nets*len(list(net.parameters())[layer_idx]))]\n",
    "    v = torch.cat( [\n",
    "        torch.cat(\n",
    "            (\n",
    "                torch.zeros(( len(list(net.parameters())[layer_idx]) ,len(list(net.parameters())[layer_idx][0])*net_idx)), \n",
    "                list(net.parameters())[layer_idx], \n",
    "                torch.zeros((len(list(net.parameters())[layer_idx]), len(list(net.parameters())[layer_idx][0])*(len(nets)-(net_idx+1))))\n",
    "            )\n",
    "        , dim=1) for net_idx, net in enumerate(nets)\n",
    "    ])\n",
    "    #print(i,v)\n",
    "    s = torch.sparse_coo_tensor(list(zip(*i)), v)\n",
    "    print(s[0].to_dense().shape, s[0].is_sparse)\n",
    "    return s[0]\n",
    "\n",
    "cat_COO_layer(nets, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nr_nets = 5\n",
    "nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
    "print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
    "\n",
    "nr_layers = len(list(nets[0].parameters()))\n",
    "modules = [ cat_COO_layer(nets, layer_idx) for layer_idx in range(nr_layers) ]\n",
    "\n",
    "loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
    "optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
    "#optimizer = torch.optim.SGD([module for module in modules], lr=0.004, momentum=0.9)\n",
    "\n",
    "\n",
    "for train_iteration in range(1):\n",
    "    optimizer.zero_grad() \n",
    "    X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
    "    Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
    "    print(\"X \", X.shape, \"Y\", Y.shape)\n",
    "\n",
    "    X1 = torch.sparse.mm(modules[0], X)\n",
    "    print(\"X1\", X1.shape)\n",
    "\n",
    "    X2 = torch.sparse.mm(modules[1], X1)\n",
    "    print(\"X2\", X2.shape)\n",
    "\n",
    "    X3 = torch.sparse.mm(modules[2], X2)\n",
    "    print(\"X3\", X3.shape)\n",
    "\n",
    "    loss = loss_fn(X3, Y)\n",
    "    #print(loss)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SparseLayer():\n",
    "    def __init__(self, nr_nets, interface=5, depth=3, width=2, out=1):\n",
    "        self.nr_nets = nr_nets\n",
    "        self.interface_dim = interface\n",
    "        self.depth_dim = depth\n",
    "        self.hidden_dim = width\n",
    "        self.out_dim = out\n",
    "        self.dummy_net = Net(self.interface_dim, self.hidden_dim, self.out_dim)\n",
    "        \n",
    "        self.sparse_sub_layer = []\n",
    "        self.weights = []\n",
    "        for layer_id in range(depth):\n",
    "            layer, weights = self.coo_sparse_layer(layer_id)\n",
    "            self.sparse_sub_layer.append(layer)\n",
    "            self.weights.append(weights)\n",
    "    \n",
    "    def coo_sparse_layer(self, layer_id):\n",
    "        layer_shape = list(self.dummy_net.parameters())[layer_id].shape\n",
    "        #print(layer_shape) #(out_cells, in_cells) -> (2,5), (2,2), (1,2)\n",
    "\n",
    "        sparse_diagonal = np.eye(self.nr_nets).repeat(layer_shape[0], axis=-2).repeat(layer_shape[1], axis=-1)\n",
    "        indices = np.argwhere(sparse_diagonal == 1).T\n",
    "        values = torch.nn.Parameter(torch.randn((self.nr_nets * (layer_shape[0]*layer_shape[1]) )))\n",
    "        #values = torch.randn((self.nr_nets * layer_shape[0]*layer_shape[1] ))\n",
    "        s = torch.sparse_coo_tensor(indices, values, sparse_diagonal.shape, requires_grad=True)\n",
    "        print(f\"L{layer_id}:\", s.shape)\n",
    "        return s, values\n",
    "\n",
    "    def get_self_train_inputs_and_targets(self):\n",
    "        encoding_matrix, mask = self.dummy_net._weight_pos_enc\n",
    "\n",
    "        # view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN\n",
    "        # i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2 and first hidden*out weights of layer3 = first net\n",
    "        weights = [layer.view(-1, int(len(layer)/self.nr_nets)) for layer in self.weights] #[nr_layers*[nr_net*nr_weights_layer_i]]\n",
    "        weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1,1) for i in range(self.nr_nets)] #[nr_net*[nr_weights]]\n",
    "        inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(self.nr_nets)]) #(16, 25)\n",
    "        targets = torch.hstack(weights_per_net)\n",
    "        return inputs.T, targets.T\n",
    "\n",
    "    def __call__(self, x):\n",
    "        X1 = torch.sparse.mm(self.sparse_sub_layer[0], x)\n",
    "        #print(\"X1\", X1.shape)\n",
    "\n",
    "        X2 = torch.sparse.mm(self.sparse_sub_layer[1], X1)\n",
    "        #print(\"X2\", X2.shape)\n",
    "\n",
    "        X3 = torch.sparse.mm(self.sparse_sub_layer[2], X2)\n",
    "        #print(\"X3\", X3.shape)\n",
    "        \n",
    "        return X3\n",
    "\n",
    "net = SparseLayer(5)\n",
    "loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
    "optimizer = torch.optim.SGD([weight for weight in net.weights], lr=0.004, momentum=0.9)\n",
    "#optimizer = torch.optim.SGD([layer for layer in net.sparse_sub_layer], lr=0.004, momentum=0.9)\n",
    "\n",
    "for train_iteration in trange(10):\n",
    "    optimizer.zero_grad() \n",
    "    X,Y = net.get_self_train_inputs_and_targets()\n",
    "    out = net(X)\n",
    "    \n",
    "    loss = loss_fn(out, Y)\n",
    "\n",
    "    # print(\"X:\", X.shape, \"Y:\", Y.shape)\n",
    "    # print(\"OUT\", out.shape)\n",
    "    # print(\"LOSS\", loss.item())\n",
    "    \n",
    "    loss.backward(retain_graph=True)\n",
    "    optimizer.step()\n",
    "\n",
    "    \n",
    "\n",
    "epsilon=pow(10, -5)\n",
    "# is the (the whole layer) self-replicating? -> wrong\n",
    "#print(torch.allclose(out, Y,rtol=0, atol=epsilon))\n",
    "\n",
    "# is each of the networks self-replicating?\n",
    "print(f\"identity_fn after {train_iteration+1} self-train iterations: {sum([torch.allclose(out[i], Y[i], rtol=0, atol=epsilon) for i in range(net.nr_nets)])}/{net.nr_nets}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for layer in net.weights:\n",
    "#     n=int(len(layer)/net.nr_nets)\n",
    "#     print( [layer[i:i+n] for i in range(0, len(layer), n)])\n",
    "\n",
    "encoding_matrix, mask = Net(5,2,1)._weight_pos_enc\n",
    "print(encoding_matrix, mask)\n",
    "# view weights of each sublayer in equal chunks, each column representing weights of one selfrepNN\n",
    "# i.e., first interface*hidden weights of layer1, first hidden*hidden weights of layer2 and first hidden*out weights of layer3 = first net\n",
    "weights = [layer.view(-1, int(len(layer)/net.nr_nets)) for layer in net.weights]\n",
    "weights_per_net = [torch.cat([layer[i] for layer in weights]).view(-1,1) for i in range(net.nr_nets)]\n",
    "\n",
    "inputs = torch.hstack([encoding_matrix * mask + weights_per_net[i].expand(-1, encoding_matrix.shape[-1]) * (1 - mask) for i in range(net.nr_nets)]) #16, 25\n",
    "\n",
    "targets = torch.hstack(weights_per_net)\n",
    "targets.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import torch\n",
    "from torch.nn import Flatten\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms import ToTensor, Compose, Resize\n",
    "from tqdm import tqdm, trange\n",
    "\n",
    "\n",
    "utility_transforms = Compose([ Resize((10, 10)), ToTensor(), Flatten(start_dim=0)])\n",
    "data_path = Path('data')\n",
    "WORKER = 8\n",
    "BATCHSIZE = 10\n",
    "EPOCH = 1\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "dataset = MNIST(str(data_path), transform=utility_transforms)\n",
    "d = DataLoader(dataset, batch_size=BATCHSIZE, shuffle=True, drop_last=True, num_workers=WORKER)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed_batch(x, repeat_dim):\n",
    "    # x of shape (batchsize, flat_img_dim)\n",
    "    x = x.unsqueeze(-1) #(batchsize, flat_img_dim, 1)\n",
    "    return torch.cat( (torch.zeros( x.shape[0], x.shape[1], 4), x), dim=2).repeat(1,1,repeat_dim) #(batchsize, flat_img_dim, encoding_dim*repeat_dim)\n",
    "\n",
    "def embed_vector(x, repeat_dim):\n",
    "    # x of shape [flat_img_dim]\n",
    "    x = x.unsqueeze(-1) #(flat_img_dim, 1)\n",
    "    return torch.cat( (torch.zeros( x.shape[0], 4), x), dim=1).repeat(1,repeat_dim) #(flat_img_dim, encoding_dim*repeat_dim)\n",
    "\n",
    "class SparseNetwork():\n",
    "    def __init__(self, input_dim, depth, width, out):\n",
    "        self.input_dim = input_dim\n",
    "        self.depth_dim = depth\n",
    "        self.hidden_dim = width\n",
    "        self.out_dim = out\n",
    "        self.sparse_layers = []\n",
    "        self.sparse_layers.append( SparseLayer( self.input_dim * self.hidden_dim ))\n",
    "        self.sparse_layers.extend([ SparseLayer( self.hidden_dim * self.hidden_dim ) for layer_idx in range(self.depth_dim - 2)])\n",
    "        self.sparse_layers.append( SparseLayer( self.hidden_dim * self.out_dim ))\n",
    "\n",
    "    def __call__(self, x):\n",
    "        \n",
    "        for sparse_layer in self.sparse_layers[:-1]:\n",
    "            # batch pass (one by one, sparse bmm doesn't support grad)\n",
    "            if len(x.shape) > 1:\n",
    "                embedded_inpt = embed_batch(x, sparse_layer.nr_nets)\n",
    "                x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt]) #[batchsize, hidden*inpt_dim, feature_dim]\n",
    "            # vector\n",
    "            else:\n",
    "                embedded_inpt = embed_vector(x, sparse_layer.nr_nets)\n",
    "                x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.hidden_dim, x.shape[1]).sum(dim=1)\n",
    "            print(\"out\", x.shape)\n",
    "        \n",
    "        # output layer\n",
    "        sparse_layer = self.sparse_layers[-1]\n",
    "        if len(x.shape) > 1:\n",
    "            embedded_inpt = embed_batch(x, sparse_layer.nr_nets)\n",
    "            x = torch.stack([sparse_layer(inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1) for inpt in embedded_inpt]) #[batchsize, hidden*inpt_dim, feature_dim]\n",
    "        else:\n",
    "            embedded_inpt = embed_vector(x, sparse_layer.nr_nets)\n",
    "            x = sparse_layer(embedded_inpt.T).sum(dim=1).view(self.out_dim, x.shape[1]).sum(dim=1)\n",
    "        print(\"out\", x.shape)\n",
    "        return x\n",
    "\n",
    "data_dim = np.prod(dataset[0][0].shape)\n",
    "metanet = SparseNetwork(data_dim, depth=3, width=5, out=10)\n",
    "batchx, batchy = next(iter(d))\n",
    "batchx.shape, batchy.shape\n",
    "metanet(batchx)"
   ]
  }
 ],
 "metadata": {