243 lines
9.1 KiB
Plaintext
243 lines
9.1 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 222,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from network import Net\n",
|
|
"import torch\n",
|
|
"from typing import List\n",
|
|
"from functionalities_test import is_identity_function"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 255,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"nr_nets = 5\n",
|
|
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
|
|
"\n",
|
|
"loss_fn = torch.nn.MSELoss()\n",
|
|
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 256,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0"
|
|
]
|
|
},
|
|
"execution_count": 256,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"sum([is_identity_function(net) for net in nets])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 247,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(torch.Size([14, 20]), torch.Size([14, 5]))"
|
|
]
|
|
},
|
|
"execution_count": 247,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"X = torch.hstack( [net.input_weight_matrix() for net in nets] ) #(nr_nets*nr_weights, nr_weights)\n",
|
|
"Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ) #(nr_nets*nr_weights,1)\n",
|
|
"X.shape, Y.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 270,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"[tensor(indices=tensor([[ 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3,\n",
|
|
" 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6,\n",
|
|
" 7, 7, 7, 7, 8, 8, 8, 8, 9, 9, 9, 9],\n",
|
|
" [ 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 4, 5,\n",
|
|
" 6, 7, 8, 9, 10, 11, 8, 9, 10, 11, 12, 13, 14, 15,\n",
|
|
" 12, 13, 14, 15, 16, 17, 18, 19, 16, 17, 18, 19]]),\n",
|
|
" values=tensor([-0.0282, 0.1612, 0.0717, -0.1370, -0.0789, 0.0990,\n",
|
|
" -0.0642, 0.1385, -0.1046, 0.1522, -0.0691, 0.0848,\n",
|
|
" -0.1419, -0.0465, -0.0385, 0.1453, -0.0263, 0.1401,\n",
|
|
" 0.0758, 0.1022, 0.1218, -0.1423, 0.0556, 0.0150,\n",
|
|
" 0.0598, -0.0347, -0.0717, 0.1173, 0.0126, -0.0164,\n",
|
|
" -0.0359, 0.0895, 0.1545, -0.1091, 0.0925, 0.0687,\n",
|
|
" 0.1330, 0.1297, 0.0305, 0.1811]),\n",
|
|
" size=(10, 20), nnz=40, layout=torch.sparse_coo, requires_grad=True),\n",
|
|
" tensor(indices=tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9,\n",
|
|
" 9],\n",
|
|
" [0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9, 8,\n",
|
|
" 9]]),\n",
|
|
" values=tensor([-0.1608, 0.0952, 0.0369, 0.0105, -0.0277, 0.0216,\n",
|
|
" 0.0991, 0.1250, 0.0618, 0.2241, 0.0602, 0.1144,\n",
|
|
" -0.0330, -0.1240, 0.0456, -0.1208, -0.1859, 0.1333,\n",
|
|
" 0.1235, -0.1774]),\n",
|
|
" size=(10, 10), nnz=20, layout=torch.sparse_coo, requires_grad=True),\n",
|
|
" tensor(indices=tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4],\n",
|
|
" [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]),\n",
|
|
" values=tensor([ 0.0585, 0.1856, -0.0987, 0.2342, -0.0376, 0.0765,\n",
|
|
" -0.1395, 0.1574, -0.0103, -0.0933]),\n",
|
|
" size=(5, 10), nnz=10, layout=torch.sparse_coo, requires_grad=True)]"
|
|
]
|
|
},
|
|
"execution_count": 270,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"def construct_sparse_tensor_layer(nets:List[Net], layer_idx:int) -> torch.Tensor:\n",
|
|
" assert layer_idx <= len(list(nets[0].parameters()))\n",
|
|
" values = []\n",
|
|
" indices = []\n",
|
|
" for net_idx,net in enumerate(nets):\n",
|
|
" layer = list(net.parameters())[layer_idx]\n",
|
|
" \n",
|
|
" for cell_idx,cell in enumerate(layer):\n",
|
|
" # E.g., position of cell weights (with 2 cells per hidden layer) in first sparse layer of N nets: \n",
|
|
" # [4x2 weights_net0] [4x2x(n-1) 0s]\n",
|
|
" # [4x2 weights] [4x2 weights_net0] [4x2x(n-2) 0s]\n",
|
|
" # ... etc\n",
|
|
" # [4x2x(n-1) 0s] [4x2 weights_netN]\n",
|
|
" # -> 4x2 weights on the diagonal = [shifted Nr_cellss*B down for AxB cells, and Nr_nets(*A weights)to the right] \n",
|
|
" for i in range(len(cell)):\n",
|
|
" indices.append([len(layer)*net_idx + cell_idx, net_idx*len(cell) + i ])\n",
|
|
" #indices.append([2*net_idx + cell_idx, net_idx*len(cell) + i ])\n",
|
|
"\n",
|
|
" [values.append(weight) for weight in cell]\n",
|
|
" # for i in range(4):\n",
|
|
" # indices.append([idx+idx+1, i+(idx*4)])\n",
|
|
" #for l in next(net.parameters()):\n",
|
|
" #[values.append(w) for w in l]\n",
|
|
" #print(indices, values)\n",
|
|
"\n",
|
|
" #s = torch.sparse_coo_tensor(list(zip(*indices)), values, (2*nr_nets, 4*nr_nets))\n",
|
|
" # sparse tensor dimension = (nr_cells*nr_nets , nr_weights/cell * nr_nets), i.e.,\n",
|
|
" # layer 1: (2x4) -> (2*N, 4*N)\n",
|
|
" # layer 2: (2x2) -> (2*N, 2*N)\n",
|
|
" # layer 3: (1x2) -> (2*N, 1*N)\n",
|
|
" s = torch.sparse_coo_tensor(list(zip(*indices)), values, (len(layer)*nr_nets, len(cell)*nr_nets),requires_grad=True)\n",
|
|
" #print(s.to_dense())\n",
|
|
" #print(s.to_dense().shape)\n",
|
|
" return s\n",
|
|
"\n",
|
|
"\n",
|
|
"# for each net append to the combined sparse tensor\n",
|
|
"# construct sparse tensor for each layer, with Nets of (4,2,1), each net appends\n",
|
|
"# - [4x2] weights in the first (input) layer\n",
|
|
"# - [2x2] weights in the second (hidden) layer\n",
|
|
"# - [2x1] weights in the third (output) layer\n",
|
|
"modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
|
|
"modules\n",
|
|
"#for layer_idx in range(len(list(nets[0].parameters()))):\n",
|
|
"# sparse_tensor = construct_sparse_tensor_layer(nets, layer_idx)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 295,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"before: 0/5000 identity_fns\n",
|
|
"after 1 iterations of combined self_train: 0/5000 identity_fns\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"nr_nets = 5000\n",
|
|
"nets = [Net(4,2,1) for _ in range(nr_nets)]\n",
|
|
"print(f\"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")\n",
|
|
"\n",
|
|
"loss_fn = torch.nn.MSELoss(reduction=\"sum\")\n",
|
|
"optimizer = torch.optim.SGD([param for net in nets for param in net.parameters()], lr=0.004, momentum=0.9)\n",
|
|
"\n",
|
|
"\n",
|
|
"for train_iteration in range(1):\n",
|
|
" optimizer.zero_grad() \n",
|
|
" X = torch.hstack( [net.input_weight_matrix() for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights, nr_weights)\n",
|
|
" Y = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] ).requires_grad_(True).T #(nr_nets*nr_weights,1)\n",
|
|
" #print(\"X \", X.shape, \"Y\", Y.shape)\n",
|
|
"\n",
|
|
" modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(len(list(nets[0].parameters()))) ]\n",
|
|
"\n",
|
|
" X1 = torch.sparse.mm(modules[0], X)\n",
|
|
" #print(\"X1\", X1.shape, X1)\n",
|
|
"\n",
|
|
" X2 = torch.sparse.mm(modules[1], X1)\n",
|
|
" #print(\"X2\", X2.shape)\n",
|
|
"\n",
|
|
" X3 = torch.sparse.mm(modules[2], X2)\n",
|
|
" #print(\"X3\", X3.shape)\n",
|
|
"\n",
|
|
" loss = loss_fn(X3, Y)\n",
|
|
" #print(loss)\n",
|
|
" loss.backward()\n",
|
|
" optimizer.step()\n",
|
|
"\n",
|
|
"print(f\"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"interpreter": {
|
|
"hash": "8bcba732c17ca4dacffea8ad1176c852d4229b36b9060a5f633fff752e5396ea"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3.8.12 64-bit ('masterthesis': conda)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.12"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|