From f033a2448e1196c168776bdfeaaef72e0d3084eb Mon Sep 17 00:00:00 2001 From: Maximilian Zorn Date: Mon, 21 Feb 2022 20:14:07 +0100 Subject: [PATCH] Sanity check shape error fix. --- sanity_check_weights.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sanity_check_weights.py b/sanity_check_weights.py index 9610f1b..7fc8f18 100644 --- a/sanity_check_weights.py +++ b/sanity_check_weights.py @@ -15,15 +15,15 @@ def extract_weights_from_model(model:MetaNet)->dict: inpt[-1] = 1 inpt.long() - weights = {i:[] for i in range(len(model._meta_layer_list))} - layers = [layer.particles for layer in model._meta_layer_list] + weights = {i:[] for i in range(model.depth)} + layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]] for i,layer in enumerate(layers): for net in layer: weights[i].append(net(inpt).detach()) return weights -def test_weights_as_model(weights:dict, data): - TransferNet = MetaNetCompareBaseline(model.interface, depth=5, width=6, out=10) +def test_weights_as_model(model, weights:dict, data): + TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out) with torch.no_grad(): for i, weight_set in weights.items(): TransferNet._meta_layer_list[i].weight = torch.nn.Parameter(torch.tensor(weight_set).view(list(TransferNet.parameters())[i].shape)) @@ -56,7 +56,7 @@ if __name__ == '__main__': d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER) loss_fn = nn.CrossEntropyLoss() - model = torch.load("mn_st_40_6_res_Tsk_0.85", map_location=DEVICE).eval() + model = torch.load("0039_model_ckpt.tp", map_location=DEVICE).eval() weights = extract_weights_from_model(model) - test_weights_as_model(weights, d_test) + test_weights_as_model(model, weights, d_test)