diff --git a/network.py b/network.py
index b09a9fc..eed4bda 100644
--- a/network.py
+++ b/network.py
@@ -487,25 +487,42 @@ class MetaNet(nn.Module):
 
 class MetaNetCompareBaseline(nn.Module):
 
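+    """Static linear stack used to sanity-check weights extracted from a trained MetaNet."""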
-    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None):
+    def __init__(self, interface=4, depth=3, width=4, out=1, activation=None, residual_skip=True):
         super().__init__()
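+        # When enabled, an identity skip connection is added around every pair of hidden layers.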
+        self.residual_skip = residual_skip
         self.activation = activation
         self.out = out
         self.interface = interface
         self.width = width
         self.depth = depth
-
-        self._meta_layer_list = nn.ModuleList()
-
-        self._meta_layer_list.append(nn.Linear(self.interface, self.width, bias=False))
-        self._meta_layer_list.extend([nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)])
-        self._meta_layer_list.append(nn.Linear(self.width, self.out, bias=False))
+
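+        # The first and last layers are kept separate so that forward() enumerates
+        # only the hidden layers, keeping indices aligned with the skip pattern.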
+        self._first_layer = nn.Linear(self.interface, self.width, bias=False)
+        self._meta_layer_list = nn.ModuleList(
+            [nn.Linear(self.width, self.width, bias=False) for _ in range(self.depth - 2)]
+        )
+        self._last_layer = nn.Linear(self.width, self.out, bias=False)
 
     def forward(self, x):
-        tensor = x
-        for meta_layer in self._meta_layer_list:
+        tensor = self._first_layer(x)
+        for idx, meta_layer in enumerate(self._meta_layer_list, start=1):
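+            # Remember the input of each odd-indexed hidden layer as the residual...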
+            if idx % 2 == 1 and self.residual_skip:
+                residual = tensor.clone()
             tensor = meta_layer(tensor)
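+            # ...and add it back after the following even-indexed layer (each skip spans two layers).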
+            if idx % 2 == 0 and self.residual_skip:
+                tensor = tensor + residual
+        tensor = self._last_layer(tensor)
         return tensor
+
+    @property
+    def all_layers(self):
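+        """All layers in forward-execution order (first, hidden layers, last)."""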
+        return (self._first_layer, *self._meta_layer_list, self._last_layer)
 
 
 if __name__ == '__main__':
diff --git a/sanity_check_weights.py b/sanity_check_weights.py
index d46457f..e6449f2 100644
--- a/sanity_check_weights.py
+++ b/sanity_check_weights.py
@@ -26,7 +26,9 @@ def extract_weights_from_model(model:MetaNet)->dict:
 
 
 def test_weights_as_model(model, new_weights:dict, data):
-    TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out)
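+    # residual_skip is assumed to match the skip topology of the trained MetaNet checkpoint.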
+    TransferNet = MetaNetCompareBaseline(model.interface, depth=model.depth, width=model.width, out=model.out,
+                                         residual_skip=True)
 
     with torch.no_grad():
         for weights, parameters in zip(new_weights.values(), TransferNet.parameters()):
@@ -37,7 +39,7 @@
     with tqdm(desc='Test Batch: ') as pbar:
         for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='MetaNet Sanity Check'):
             y = TransferNet(batch_x)
-            loss = loss_fn(y, batch_y)
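+            # The sanity check reports accuracy only; no loss computation is needed here.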
             acc = metric(y.cpu(), batch_y.cpu())
             pbar.set_postfix_str(f'Acc: {acc}')
             pbar.update()
@@ -52,13 +54,13 @@
     DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     WORKER = 0
     BATCHSIZE = 500
-    MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Normalize((0.1307,), (0.3081,)), Flatten(start_dim=0)])
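+    # No Normalize here: inputs are assumed to match the raw-scale preprocessing used when the checkpoint was trained.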
+    MNIST_TRANSFORM = Compose([Resize((15, 15)), ToTensor(), Flatten(start_dim=0)])
     torch.manual_seed(42)
     data_path = Path('data')
     data_path.mkdir(exist_ok=True, parents=True)
     mnist_test = MNIST(str(data_path), transform=MNIST_TRANSFORM, download=True, train=False)
     d_test = DataLoader(mnist_test, batch_size=BATCHSIZE, shuffle=False, drop_last=True, num_workers=WORKER)
-    loss_fn = nn.CrossEntropyLoss()
     
     model = torch.load(Path('experiments/output/trained_model_ckpt_e50.tp'), map_location=DEVICE).eval()
     weights = extract_weights_from_model(model)