Sanity check shape error fix.

This commit is contained in:
Maximilian Zorn
2022-02-21 20:14:07 +01:00
parent f25cee5203
commit f033a2448e

View File

@@ -15,15 +15,15 @@ def extract_weights_from_model(model:MetaNet)->dict:
inpt[-1] = 1
inpt.long()
weights = {i:[] for i in range(len(model._meta_layer_list))}
layers = [layer.particles for layer in model._meta_layer_list]
weights = {i:[] for i in range(model.depth)}
layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]]
for i,layer in enumerate(layers):
for net in layer:
weights[i].append(net(inpt).detach())
return weights
def test_weights_as_model(weights:dict, data):
TransferNet = MetaNetCompareBaseline(model.interface, depth=5, width=6, out=10)
def test_weights_as_model(model, weights:dict, data):
TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out)
with torch.no_grad():
for i, weight_set in weights.items():
TransferNet._meta_layer_list[i].weight = torch.nn.Parameter(torch.tensor(weight_set).view(list(TransferNet.parameters())[i].shape))
@@ -56,7 +56,7 @@ if __name__ == '__main__':
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
loss_fn = nn.CrossEntropyLoss()
model = torch.load("mn_st_40_6_res_Tsk_0.85", map_location=DEVICE).eval()
model = torch.load("0039_model_ckpt.tp", map_location=DEVICE).eval()
weights = extract_weights_from_model(model)
test_weights_as_model(weights, d_test)
test_weights_as_model(model, weights, d_test)