Merge remote-tracking branch 'origin/swarm' into swarm
# Conflicts: # sanity_check_weights.py
This commit is contained in:
@@ -15,15 +15,15 @@ def extract_weights_from_model(model:MetaNet)->dict:
|
|||||||
inpt[-1] = 1
|
inpt[-1] = 1
|
||||||
inpt.long()
|
inpt.long()
|
||||||
|
|
||||||
weights = {i:[] for i in range(len(model._meta_layer_list))}
|
weights = {i:[] for i in range(model.depth)}
|
||||||
layers = [layer.particles for layer in model._meta_layer_list]
|
layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]]
|
||||||
for i,layer in enumerate(layers):
|
for i,layer in enumerate(layers):
|
||||||
for net in layer:
|
for net in layer:
|
||||||
weights[i].append(net(inpt).detach())
|
weights[i].append(net(inpt).detach())
|
||||||
return weights
|
return weights
|
||||||
|
|
||||||
def test_weights_as_model(weights:dict, data):
|
def test_weights_as_model(model, weights:dict, data):
|
||||||
TransferNet = MetaNetCompareBaseline(model.interface, depth=5, width=6, out=10)
|
TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i, weight_set in weights.items():
|
for i, weight_set in weights.items():
|
||||||
TransferNet._meta_layer_list[i].weight = torch.nn.Parameter(torch.tensor(weight_set).view(list(TransferNet.parameters())[i].shape))
|
TransferNet._meta_layer_list[i].weight = torch.nn.Parameter(torch.tensor(weight_set).view(list(TransferNet.parameters())[i].shape))
|
||||||
@@ -55,8 +55,8 @@ if __name__ == '__main__':
|
|||||||
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
|
mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
|
||||||
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
|
d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
|
||||||
loss_fn = nn.CrossEntropyLoss()
|
loss_fn = nn.CrossEntropyLoss()
|
||||||
model_path = (Path() / r'experiments\output\mn_st_40_6_res_Tsk_0.85\trained_model_ckpt_e40.tp')
|
|
||||||
model = torch.load(model_path, map_location=DEVICE).eval()
|
model = torch.load("0039_model_ckpt.tp", map_location=DEVICE).eval()
|
||||||
weights = extract_weights_from_model(model)
|
weights = extract_weights_from_model(model)
|
||||||
test_weights_as_model(weights, d_test)
|
test_weights_as_model(model, weights, d_test)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user