diff --git a/sanity_check_weights.py b/sanity_check_weights.py
index 32f789b..7fc8f18 100644
--- a/sanity_check_weights.py
+++ b/sanity_check_weights.py
@@ -15,15 +15,15 @@ def extract_weights_from_model(model:MetaNet)->dict:
     inpt[-1] = 1
     inpt.long()
 
-    weights = {i:[] for i in range(len(model._meta_layer_list))}
-    layers = [layer.particles for layer in model._meta_layer_list]
+    weights = {i:[] for i in range(model.depth)}
+    layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]]
     for i,layer in enumerate(layers):
         for net in layer:
             weights[i].append(net(inpt).detach())
     return weights
 
-def test_weights_as_model(weights:dict, data):
-    TransferNet = MetaNetCompareBaseline(model.interface, depth=5, width=6, out=10)
+def test_weights_as_model(model, weights:dict, data):
+    TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out)
     with torch.no_grad():
         for i, weight_set in weights.items():
             TransferNet._meta_layer_list[i].weight = torch.nn.Parameter(torch.tensor(weight_set).view(list(TransferNet.parameters())[i].shape))
@@ -55,8 +55,8 @@ if __name__ == '__main__':
     mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
     d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
     loss_fn = nn.CrossEntropyLoss()
-    model_path = (Path() / r'experiments\output\mn_st_40_6_res_Tsk_0.85\trained_model_ckpt_e40.tp')
-    model = torch.load(model_path, map_location=DEVICE).eval()
-    weights = extract_weights_from_model(model)
-    test_weights_as_model(weights, d_test)
+
+    model = torch.load("0039_model_ckpt.tp", map_location=DEVICE).eval()
+    weights = extract_weights_from_model(model)
+    test_weights_as_model(model, weights, d_test)
 
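The reshaping step this patch relies on (collect one flat list of predicted weight values per layer, then view it into the destination parameter's shape) is easy to mis-index, so below is a minimal stand-alone sketch of the same weight-transfer pattern. It uses plain torch.nn.Linear layers as hypothetical stand-ins for MetaNet/MetaNetCompareBaseline, whose definitions are not part of this diff, and random numbers in place of the particles' predicted weights.

# Sketch only: nn.Linear layers stand in for MetaNetCompareBaseline's meta layer
# list; the random values stand in for extract_weights_from_model() output.
import torch
import torch.nn as nn

# A toy "baseline" network: three layers, 4 inputs, width 6, 3 outputs.
baseline = nn.ModuleList([nn.Linear(4, 6), nn.Linear(6, 6), nn.Linear(6, 3)])

# One flat list of weight values per layer, in the same order the diff collects
# them (one scalar per destination weight entry); biases are left untouched here.
weights = {i: torch.randn(layer.weight.numel()).tolist()
           for i, layer in enumerate(baseline)}

with torch.no_grad():
    for i, weight_set in weights.items():
        target_shape = baseline[i].weight.shape
        # Same move as in test_weights_as_model(): flat list -> tensor -> view
        # into the destination shape -> wrap as a fresh Parameter.
        baseline[i].weight = nn.Parameter(torch.tensor(weight_set).view(target_shape))

# Quick check: a forward pass still works with the transplanted weights.
out = nn.Sequential(*baseline)(torch.randn(2, 4))
assert out.shape == (2, 3)

In the patch itself, the substantive changes are that TransferNet's depth, width and out now come from the loaded model instead of the hard-coded 5/6/10, that the first and last meta layers are included in the extraction, and that test_weights_as_model() receives the model explicitly rather than relying on the module-level variable.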