Merge remote-tracking branch 'origin/swarm' into swarm
# Conflicts: # sanity_check_weights.py
This commit is contained in:
@@ -15,15 +15,15 @@ def extract_weights_from_model(model:MetaNet)->dict:
|
||||
inpt[-1] = 1
|
||||
inpt.long()
|
||||
|
||||
weights = {i:[] for i in range(len(model._meta_layer_list))}
|
||||
layers = [layer.particles for layer in model._meta_layer_list]
|
||||
weights = {i:[] for i in range(model.depth)}
|
||||
layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]]
|
||||
for i,layer in enumerate(layers):
|
||||
for net in layer:
|
||||
weights[i].append(net(inpt).detach())
|
||||
return weights
|
||||
|
||||
def test_weights_as_model(weights:dict, data):
|
||||
TransferNet = MetaNetCompareBaseline(model.interface, depth=5, width=6, out=10)
|
||||
def test_weights_as_model(model, weights:dict, data):
|
||||
TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out)
|
||||
with torch.no_grad():
|
||||
for i, weight_set in weights.items():
|
||||
TransferNet._meta_layer_list[i].weight = torch.nn.Parameter(torch.tensor(weight_set).view(list(TransferNet.parameters())[i].shape))
|
||||
@@ -55,8 +55,8 @@ if __name__ == '__main__':
|
||||
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
|
||||
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
model_path = (Path() / r'experiments\output\mn_st_40_6_res_Tsk_0.85\trained_model_ckpt_e40.tp')
|
||||
model = torch.load(model_path, map_location=DEVICE).eval()
|
||||
weights = extract_weights_from_model(model)
|
||||
test_weights_as_model(weights, d_test)
|
||||
|
||||
model = torch.load("0039_model_ckpt.tp", map_location=DEVICE).eval()
|
||||
weights = extract_weights_from_model(model)
|
||||
test_weights_as_model(model, weights, d_test)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user