From 5b2b5b5beb83925ac72c15de8ede92fcd45e4f0a Mon Sep 17 00:00:00 2001
From: Steffen Illium
Date: Wed, 23 Feb 2022 18:32:36 +0100
Subject: [PATCH] residual skip in metacomparebaseline

---
 network.py              | 26 +++++++++++++++++---------
 sanity_check_weights.py |  7 +++----
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/network.py b/network.py
index b09a9fc..eed4bda 100644
--- a/network.py
+++ b/network.py
@@ -487,25 +487,33 @@ class MetaNet(nn.Module):
 
 class MetaNetCompareBaseline(nn.Module):
 
-    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None):
+    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True):
         super().__init__()
+        self.residual_skip = residual_skip
         self.activation = activation
         self.out = out
         self.interface = interface
         self.width = width
         self.depth = depth
-
-        self._meta_layer_list = nn.ModuleList()
-
-        self._meta_layer_list.append(nn.Linear(self.interface, self.width, bias=False))
-        self._meta_layer_list.extend([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
-        self._meta_layer_list.append(nn.Linear(self.width, self.out, bias=False))
+
+        self._first_layer = nn.Linear(self.interface, self.width, bias=False)
+        self._meta_layer_list = nn.ModuleList([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
+        self._last_layer = nn.Linear(self.width, self.out, bias=False)
 
     def forward(self, x):
-        tensor = x
-        for meta_layer in self._meta_layer_list:
+        tensor = self._first_layer(x)
+        for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
+            if idx % 2 == 1 and self.residual_skip:
+                x = tensor.clone()
             tensor = meta_layer(tensor)
+            if idx % 2 == 0 and self.residual_skip:
+                tensor = tensor + x
+        tensor = self._last_layer(tensor)
         return tensor
+
+    @property
+    def all_layers(self):
+        return (x for x in (self._first_layer, *self._meta_layer_list, self._last_layer))
 
 
 if __name__ == '__main__':

diff --git a/sanity_check_weights.py b/sanity_check_weights.py
index d46457f..e6449f2 100644
--- a/sanity_check_weights.py
+++ b/sanity_check_weights.py
@@ -26,7 +26,8 @@ def extract_weights_from_model(model:MetaNet)->dict:
 
 
 def test_weights_as_model(model, new_weights:dict, data):
-    TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out)
+    TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out,
+                                         residual_skip=True)
 
     with torch.no_grad():
         for weights, parameters in zip(new_weights.values(), TransferNet.parameters()):
@@ -37,7 +38,6 @@ def test_weights_as_model(model, new_weights:dict, data):
         with tqdm(desc='Test Batch: ') as pbar:
             for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
                 y = TransferNet(batch_x)
-                loss = loss_fn(y, batch_y)
                 acc = metric(y.cpu(), batch_y.cpu())
                 pbar.set_postfix_str(f'Acc: {acc}')
                 pbar.update()
@@ -52,13 +52,12 @@ if __name__ == '__main__':
     DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     WORKER = 0
     BATCHSIZE = 500
-    MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Normalize((0.1307,), (0.3081,)), Flatten(start_dim=0)])
+    MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Flatten(start_dim=0)])
     torch.manual_seed(42)
     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)
     mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
     d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True,
                         num_workers=WORKER)
-    loss_fn = nn.CrossEntropyLoss()
     model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval()
     weights = extract_weights_from_model(model)
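
Note (not part of the commit): below is a minimal, self-contained sketch of the
residual-skip pattern this patch introduces in MetaNetCompareBaseline.forward().
The layer shapes mirror the patch, but the variable name `skip`, the choice of
depth=5, and the random test input are illustrative assumptions; the patch itself
reuses the parameter name `x` as the skip buffer.

    import torch
    import torch.nn as nn

    torch.manual_seed(42)

    # Same layer layout as the patched MetaNetCompareBaseline: one input
    # projection, (depth - 2) hidden width-x-width layers, one output layer.
    interface, width, depth, out = 4, 4, 5, 1
    first = nn.Linear(interface, width, bias=False)
    hidden = nn.ModuleList([nn.Linear(width, width, bias=False) for _ in range(depth - 2)])
    last = nn.Linear(width, out, bias=False)

    x = torch.randn(8, interface)           # hypothetical test batch
    tensor = first(x)
    for idx, layer in enumerate(hidden, start=1):
        if idx % 2 == 1:                    # entering a block: remember its input
            skip = tensor.clone()
        tensor = layer(tensor)
        if idx % 2 == 0:                    # leaving the block: add the input back
            tensor = tensor + skip
    print(last(tensor).shape)               # torch.Size([8, 1])

The idx % 2 bookkeeping pairs consecutive hidden layers into two-layer residual
blocks. One consequence worth noting: when depth - 2 is odd, the loop ends on an
odd idx, so the last saved skip is never added back and the trailing hidden layer
runs without a residual connection.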