In [222]:
from network import Net
import torch
from typing import List
from functionalities_test import is_identity_function

In [255]:
# Small population used to prototype the combined (all-nets-at-once) training.
# NOTE(review): the meaning of the Net(4,2,1) arguments is defined in network.py;
# the comments further down in this notebook treat it as a 4-2-1 layout -- confirm.
nr_nets = 5
nets = [Net(4,2,1) for _ in range(nr_nets)]

# Default MSELoss reduction (mean) here.
loss_fn = torch.nn.MSELoss()

# A single SGD instance is handed the parameters of every net, in net order.
all_params = [param for net in nets for param in net.parameters()]
optimizer = torch.optim.SGD(all_params, lr=0.004, momentum=0.9)

In [256]:
# Number of nets for which is_identity_function(net) is truthy.
sum(is_identity_function(net) for net in nets)

0

In [247]:
# Per-net input matrices placed side by side (hstack concatenates along columns).
# The output recorded for this cell was (14, 20) for X and (14, 5) for Y, i.e.
# presumably (weights per net, 4 * nr_nets) and (weights per net, nr_nets)
# for five Net(4,2,1) instances -- TODO confirm against network.py.
X = torch.hstack(tuple(net.input_weight_matrix() for net in nets))
# Matching targets, one column block per net, in the same net order as X.
Y = torch.hstack(tuple(net.create_target_weights(net.input_weight_matrix()) for net in nets))
(X.shape, Y.shape)

(torch.Size([14, 20]), torch.Size([14, 5]))

In [270]:
def construct_sparse_tensor_layer(nets:List["Net"], layer_idx:int) -> torch.Tensor:
    """Pack one weight layer of every net into a single block-diagonal sparse matrix.

    For a layer of shape (nr_cells, nr_weights) and N nets, the result has shape
    (nr_cells * N, nr_weights * N). The weights of net ``k`` occupy rows
    ``[k * nr_cells, (k + 1) * nr_cells)`` and columns
    ``[k * nr_weights, (k + 1) * nr_weights)``; everything else is zero:

        [W_net0] [  0   ] ... [  0   ]
        [  0   ] [W_net1] ... [  0   ]
          ...
        [  0   ] [  0   ] ... [W_netN]

    E.g. for layers of shape (2x4), (2x2), (1x2):
        layer 1: (2x4) -> (2*N, 4*N)
        layer 2: (2x2) -> (2*N, 2*N)
        layer 3: (1x2) -> (1*N, 2*N)

    Args:
        nets: Non-empty list of nets exposing ``parameters()``. The parameter at
            position ``layer_idx`` must be 2-D and have the same shape in every net.
        layer_idx: Position of the layer in ``list(net.parameters())``.

    Returns:
        A sparse COO tensor with ``requires_grad=True``. Its values are taken
        directly from the nets' parameter tensors (not copied into a new leaf),
        so a backward pass through it fills the ``.grad`` of those parameters.

    Raises:
        AssertionError: if ``layer_idx`` is not smaller than the number of layers.
        IndexError: if ``nets`` is empty.
    """
    nr_layers = len(list(nets[0].parameters()))
    # Strict bound: layer_idx == nr_layers is already one past the last layer.
    assert layer_idx < nr_layers, f"layer_idx {layer_idx} out of range for {nr_layers} layers"

    layers = [list(net.parameters())[layer_idx] for net in nets]
    # Use the nets actually passed in, not a notebook-level counter.
    n_nets = len(layers)
    n_cells, n_weights = layers[0].shape

    # Row-major flattening gives net -> cell -> weight order. Stacking the
    # parameter tensors themselves keeps the autograd link back to each net.
    values = torch.stack(layers).reshape(-1)
    device = values.device

    # Row of each value: every (net, cell) pair owns one row, repeated once per weight.
    rows = torch.arange(n_nets * n_cells, device=device).repeat_interleave(n_weights)
    # Column of each value: the owning net's column offset plus the weight's
    # position inside its cell.
    col_offsets = torch.arange(n_nets, device=device).repeat_interleave(n_cells * n_weights) * n_weights
    cols = col_offsets + torch.arange(n_weights, device=device).repeat(n_nets * n_cells)
    indices = torch.stack((rows, cols))

    return torch.sparse_coo_tensor(
        indices,
        values,
        (n_cells * n_nets, n_weights * n_nets),
        requires_grad=True,
    )


# Build one combined sparse matrix per parameter tensor of the nets: entry k of
# `modules` holds layer k of every net (the first layer displayed below has
# 2 rows x 4 columns per net).
nr_layers = len(list(nets[0].parameters()))
modules = []
for layer_idx in range(nr_layers):
    modules.append(construct_sparse_tensor_layer(nets, layer_idx))
modules

[tensor(indices=tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,
                          3,  3,  4,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,
                          7,  7,  7,  7,  8,  8,  8,  8,  9,  9,  9,  9],
                        [ 0,  1,  2,  3,  0,  1,  2,  3,  4,  5,  6,  7,  4,  5,
                          6,  7,  8,  9, 10, 11,  8,  9, 10, 11, 12, 13, 14, 15,
                         12, 13, 14, 15, 16, 17, 18, 19, 16, 17, 18, 19]]),
        values=tensor([-0.0282,  0.1612,  0.0717, -0.1370, -0.0789,  0.0990,
                       -0.0642,  0.1385, -0.1046,  0.1522, -0.0691,  0.0848,
                       -0.1419, -0.0465, -0.0385,  0.1453, -0.0263,  0.1401,
                        0.0758,  0.1022,  0.1218, -0.1423,  0.0556,  0.0150,
                        0.0598, -0.0347, -0.0717,  0.1173,  0.0126, -0.0164,
                       -0.0359,  0.0895,  0.1545, -0.1091,  0.0925,  0.0687,
                        0.1330,  0.1297,  0.0305,  0.1811]),
   

In [295]:
# Combined self-training: push all nets through one chain of sparse matmuls
# instead of training them one by one.
nr_nets = 5000
nets = [Net(4,2,1) for _ in range(nr_nets)]
print(f"before: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns")

# Summed (not averaged) squared error over all elements.
loss_fn = torch.nn.MSELoss(reduction="sum")
# A single SGD instance is handed the parameters of every net.
all_params = [param for net in nets for param in net.parameters()]
optimizer = torch.optim.SGD(all_params, lr=0.004, momentum=0.9)

# The number of parameter tensors per net does not change while training.
nr_layers = len(list(nets[0].parameters()))

for train_iteration in range(1):
    optimizer.zero_grad()

    # Per-net inputs/targets side by side, then transposed so the nets' blocks
    # run along the rows that the block-diagonal layers multiply from the left.
    # NOTE(review): presumably X is (4 * nr_nets, weights per net) and
    # Y is (nr_nets, weights per net) after the transpose -- TODO confirm.
    stacked_inputs = torch.hstack( [net.input_weight_matrix() for net in nets] )
    X = stacked_inputs.requires_grad_(True).T
    stacked_targets = torch.hstack( [net.create_target_weights(net.input_weight_matrix()) for net in nets] )
    Y = stacked_targets.requires_grad_(True).T

    # The sparse layers are rebuilt on every iteration from the nets' current parameters.
    modules = [ construct_sparse_tensor_layer(nets, layer_idx) for layer_idx in range(nr_layers) ]

    # Forward pass: three chained sparse @ dense products, one per layer.
    X1 = torch.sparse.mm(modules[0], X)
    X2 = torch.sparse.mm(modules[1], X1)
    X3 = torch.sparse.mm(modules[2], X2)

    loss = loss_fn(X3, Y)
    loss.backward()
    # NOTE(review): this step only changes the nets if the sparse layers are
    # still connected to the nets' parameters in the autograd graph; the
    # recorded run shows 0/5000 both before and after -- verify.
    optimizer.step()

print(f"after {train_iteration+1} iterations of combined self_train: {sum([is_identity_function(net) for net in nets])}/{len(nets)} identity_fns")

before: 0/5000 identity_fns
after 1 iterations of combined self_train: 0/5000 identity_fns
