Merge remote-tracking branch 'origin/swarm' into swarm

# Conflicts:
#	sanity_check_weights.py
This commit is contained in:
Steffen Illium
2022-02-22 09:55:10 +01:00

View File

@@ -15,15 +15,15 @@ def extract_weights_from_model(model:MetaNet)->dict:
inpt[-1] = 1
inpt.long()
weights = {i:[] for i in range(len(model._meta_layer_list))}
layers = [layer.particles for layer in model._meta_layer_list]
weights = {i:[] for i in range(model.depth)}
layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]]
for i,layer in enumerate(layers):
for net in layer:
weights[i].append(net(inpt).detach())
return weights
def test_weights_as_model(weights:dict, data):
TransferNet = MetaNetCompareBaseline(model.interface, depth=5, width=6, out=10)
def test_weights_as_model(model, weights:dict, data):
TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out)
with torch.no_grad():
for i, weight_set in weights.items():
TransferNet._meta_layer_list[i].weight = torch.nn.Parameter(torch.tensor(weight_set).view(list(TransferNet.parameters())[i].shape))
@@ -55,8 +55,8 @@ if __name__ == '__main__':
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
loss_fn = nn.CrossEntropyLoss()
model_path = (Path() / r'experiments\output\mn_st_40_6_res_Tsk_0.85\trained_model_ckpt_e40.tp')
model = torch.load(model_path, map_location=DEVICE).eval()
weights = extract_weights_from_model(model)
test_weights_as_model(weights, d_test)
model = torch.load("0039_model_ckpt.tp", map_location=DEVICE).eval()
weights = extract_weights_from_model(model)
test_weights_as_model(model, weights, d_test)