93 lines
3.6 KiB
Python
93 lines
3.6 KiB
Python
from collections import defaultdict
|
|
|
|
from tqdm import tqdm
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import Flatten
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torchvision.datasets import MNIST, CIFAR10
|
|
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, Grayscale
|
|
import torchmetrics
|
|
|
|
from functionalities_test import epsilon_error_margin as e
|
|
from network import MetaNet, MetaNetCompareBaseline
|
|
|
|
|
|
def extract_weights_from_model(model: "MetaNet") -> dict:
    """Collect the output of every particle for one fixed probe input.

    The probe is a length-5 float vector that is zero everywhere except for a
    trailing 1. Each particle of the first meta layer, of the layers in
    ``_meta_layer_list`` and of the last meta layer (in that order) is called
    once with this probe.

    Args:
        model: Object exposing ``_meta_layer_first``, ``_meta_layer_list`` and
            ``_meta_layer_last``; every layer must offer an iterable
            ``particles`` of callables that accept the probe tensor.

    Returns:
        A plain ``dict`` mapping the 0-based layer position to the list of
        detached particle outputs of that layer. A layer without particles
        gets no entry.
    """
    # The probe stays a float tensor. The previous `inpt.long()` call threw
    # its result away, so particles have always been fed floats; the dead
    # statement is dropped instead of being "repaired" into a dtype change.
    inpt = torch.zeros(5)
    inpt[-1] = 1

    layers = [
        layer.particles
        for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]
    ]

    weights = defaultdict(list)
    for i, layer in enumerate(layers):
        for net in layer:
            # detach() so the stored values carry no autograd history.
            weights[i].append(net(inpt).detach())
    return dict(weights)
|
|
|
|
|
|
def test_weights_as_model(meta_net, new_weights: dict, data):
    """Rebuild ``meta_net`` as a baseline network and compare both networks.

    A ``MetaNetCompareBaseline`` with the same ``interface``, ``depth``,
    ``width`` and ``out`` as ``meta_net`` is created and its parameters are
    overwritten, in iteration order, with the values of ``new_weights``.
    Afterwards the function

    1. pushes one random probe through ``all_layers`` of both networks and
       prints, per layer, whether the activations agree within the absolute
       tolerance ``e`` (``epsilon_error_margin``), and
    2. writes the accuracy of each network on ``data`` to the console.

    Both networks are left in eval mode. Nothing is returned.

    Args:
        meta_net: Network exposing ``interface``, ``depth``, ``width``,
            ``out`` and ``all_layers``; it must be callable on a batch.
        new_weights: Mapping whose values, in insertion order, hold the new
            content for each parameter of the baseline network. Every value
            must be accepted by ``torch.Tensor(...)`` and contain as many
            elements as the parameter it replaces.
        data: Sized iterable of ``(batch_x, batch_y)`` pairs, e.g. a
            ``DataLoader``.

    Raises:
        AssertionError: If ``new_weights`` does not have exactly one entry
            per baseline parameter.
    """
    transfer_net = MetaNetCompareBaseline(
        meta_net.interface,
        depth=meta_net.depth,
        width=meta_net.width,
        out=meta_net.out,
        residual_skip=True,
    )

    new_weight_values = list(new_weights.values())
    old_parameters = list(transfer_net.parameters())
    assert len(new_weight_values) == len(old_parameters), (
        f'Got {len(new_weight_values)} weight groups for {len(old_parameters)} parameters.'
    )

    # Put BOTH networks into eval mode before comparing them; otherwise the
    # layer-wise check would run meta_net in whatever mode the caller left it.
    networks = [meta_net, transfer_net]
    for net in networks:
        net.eval()

    with torch.no_grad():
        # In-place overwrite keeps the parameter objects registered in the module.
        for weights, parameters in zip(new_weight_values, old_parameters):
            parameters[:] = torch.Tensor(weights).view(parameters.shape)[:]

        # Test if the margin of error is similar: feed the same random probe
        # (one row of 15 * 15 features) through both networks layer by layer.
        rand = torch.randn((1, 15 * 15))
        activations = []  # one list of per-layer outputs per network
        for net in networks:
            tensor = rand.clone()
            per_layer = []
            for layer in net.all_layers:
                tensor = layer(tensor)
                per_layer.append(tensor.detach())
            activations.append(per_layer)

    # Keeping one list per network (instead of grouping by class name) keeps
    # the pairing below intact even if both networks share a class.
    all_close = {
        f'layer_{idx}': torch.allclose(y1, y2, rtol=0, atol=e)
        for idx, (y1, y2) in enumerate(zip(*activations))
    }
    print(f'Cummulative differences per layer is smaller then {e}:\n {all_close}')
    # TODO: also report the absolute per-layer error, not only the verdict.

    for net in networks:
        # NOTE(review): recent torchmetrics releases require a `task`
        # argument here - confirm against the pinned version.
        metric = torchmetrics.Accuracy()
        # Pure evaluation: no autograd graph is needed for the forward passes.
        with torch.no_grad(), tqdm(data, total=len(data), desc='MetaNet Sanity Check') as pbar:
            for batch_x, batch_y in pbar:
                y = net(batch_x)
                acc = metric(y.cpu(), batch_y.cpu())
                pbar.set_postfix_str(f'Acc: {acc}')

        # Accuracy accumulated by `metric` over all batches seen above.
        acc = metric.compute()
        tqdm.write(f"Avg. accuracy on {net.__class__.__name__}: {acc}")
|
|
|
|
|
|
if __name__ == '__main__':
    # Fix the seed up front so every later random draw is reproducible.
    torch.manual_seed(42)

    if torch.cuda.is_available():
        DEVICE = torch.device('cuda')
    else:
        DEVICE = torch.device('cpu')
    WORKER = 0
    BATCHSIZE = 500

    # Resize each image to 15x15, convert it to a tensor and flatten it from
    # dimension 0 on.
    transform_steps = [Resize((15, 15)), ToTensor(), Flatten(start_dim=0)]
    MNIST_TRANSFORM = Compose(transform_steps)

    data_root = Path('data')
    data_root.mkdir(exist_ok=True, parents=True)

    test_set = MNIST(
        str(data_root),
        transform=MNIST_TRANSFORM,
        download=True,
        train=False,
    )
    test_loader = DataLoader(
        test_set,
        batch_size=BATCHSIZE,
        shuffle=False,
        drop_last=True,
        num_workers=WORKER,
    )

    # NOTE(review): torch.load unpickles arbitrary objects - only load
    # checkpoints from a trusted source. Newer PyTorch releases default to
    # weights_only=True, which rejects fully pickled modules; confirm the
    # pinned version.
    # NOTE(review): the checkpoint is mapped to DEVICE while the probe
    # tensors and batches used by the functions in this file are created on
    # the CPU - confirm this works when CUDA is available.
    checkpoint = Path('experiments/output/trained_model_ckpt_e50.tp')
    meta_net = torch.load(checkpoint, map_location=DEVICE)
    meta_net = meta_net.eval()

    particle_outputs = extract_weights_from_model(meta_net)
    test_weights_as_model(meta_net, particle_outputs, test_loader)
|
|
|