From dd2458da4ae1c0ef7465129a40ae3946a5b862c5 Mon Sep 17 00:00:00 2001
From: Steffen Illium <steffen.illium@ifi.lmu.de>
Date: Mon, 7 Mar 2022 11:29:06 +0100
Subject: [PATCH] adjustments for cuda and auto ckpt cleanup

---
 experiments/meta_task_utility.py | 14 ++++++++++++--
 experiments/robustness_tester.py |  2 +-
 meta_task_exp.py                 | 16 ++++++++--------
 sanity_check_weights.py          | 12 ++++++------
 4 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/experiments/meta_task_utility.py b/experiments/meta_task_utility.py
index f4cb9d4..d1c345e 100644
--- a/experiments/meta_task_utility.py
+++ b/experiments/meta_task_utility.py
@@ -70,7 +70,7 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False):
         ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp'
     else:
         if isinstance(epoch_n, str):
-            ckpt_path = Path(out_path) / f'{epoch_n}_{FINAL_CHECKPOINT_NAME}'
+            ckpt_path = Path(out_path) / f'{Path(FINAL_CHECKPOINT_NAME).stem}_{epoch_n}.tp'
         else:
             ckpt_path = Path(out_path) / FINAL_CHECKPOINT_NAME
     ckpt_path.parent.mkdir(exist_ok=True, parents=True)
@@ -113,10 +113,20 @@ def new_storage_df(identifier, weight_count):
         return pd.DataFrame(columns=['Epoch', 'Weight', *(f'weight_{x}' for x in range(weight_count))])
 
 
-def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, final_model=False,
+def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, keep_n=5, final_model=False,
                             validation_metric=torchmetrics.Accuracy):
     out_path = Path(out_path)
     ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model)
+    # Clean up Checkpoints
+    if keep_n > 0:
+        all_ckpts = sorted(list(ckpt_path.parent.iterdir()))
+        while len(all_ckpts) > keep_n:
+            all_ckpts.pop(0).unlink()
+    elif keep_n == 0:
+        pass
+    else:
+        raise ValueError(f'"keep_n" cannot be negative, but was: {keep_n}')
+
     result = validate(ckpt_path, valid_loader, metric_class=validation_metric)
     return result
 
diff --git a/experiments/robustness_tester.py b/experiments/robustness_tester.py
index 7e78f42..750ea7e 100644
--- a/experiments/robustness_tester.py
+++ b/experiments/robustness_tester.py
@@ -77,7 +77,7 @@ def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10):
                             # When this raises a Type Error, we found a second order fixpoint!
                         steps += 1
 
-                        df.loc[df.shape[0]] = [setting, f'$\mathregular{{10^{{-{noise_level}}}}}$',
+                        df.loc[df.shape[0]] = [f'{setting}_{seed}', fr'$\mathregular{{10^{{-{noise_level}}}}}$',
                                                steps, absolute_loss,
                                                time_to_vergence[setting][noise_level],
                                                time_as_fixpoint[setting][noise_level]]
diff --git a/meta_task_exp.py b/meta_task_exp.py
index 279f5c3..42283b8 100644
--- a/meta_task_exp.py
+++ b/meta_task_exp.py
@@ -36,7 +36,7 @@ else:
 from network import MetaNet, FixTypes
 from functionalities_test import test_for_fixpoints
 
-utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0), AddGaussianNoise()])
+utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)])  # , AddGaussianNoise()])
 WORKER = 10 if not debug else 2
 debug = False
 BATCHSIZE = 2000 if not debug else 50
@@ -60,16 +60,16 @@ plot_loader = DataLoader(plot_dataset, batch_size=BATCHSIZE, shuffle=False,
 
 if __name__ == '__main__':
 
-    training = False
-    plotting = False
-    robustnes = True    # EXPENSIV!!!!!!!
-    n_st = 300          # per batch !!
+    training = True
+    plotting = True
+    robustnes = True
+    n_st = 1            # per batch !!
     activation = None   # nn.ReLU()
 
     for weight_hidden_size in [3]:
 
         weight_hidden_size = weight_hidden_size
-        residual_skip = True
+        residual_skip = False
         n_seeds = 3
         depth = 5
         width = 3
@@ -84,7 +84,7 @@ if __name__ == '__main__':
         st_str = f'_nst_{n_st}'
 
         config_str = f'{res_str}{ac_str}{st_str}'
-        exp_path = Path('output') / f'mn_st_{EPOCH}_{weight_hidden_size}{config_str}_gauss'
+        exp_path = Path('output') / f'mn_st_{EPOCH}_{weight_hidden_size}{config_str}'
 
         for seed in range(n_seeds):
             seed_path = exp_path / str(seed)
@@ -161,7 +161,7 @@ if __name__ == '__main__':
                         except RuntimeError:
                             pass
 
-                        accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, epoch).item()
+                        accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, epoch, keep_n=5).item()
                         validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE,
                                               Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy)
                         train_store.loc[train_store.shape[0]] = validation_log
diff --git a/sanity_check_weights.py b/sanity_check_weights.py
index 4caa5ef..834bdbf 100644
--- a/sanity_check_weights.py
+++ b/sanity_check_weights.py
@@ -15,13 +15,12 @@ from functionalities_test import epsilon_error_margin as e
 from network import MetaNet, MetaNetCompareBaseline
 
 
-def extract_weights_from_model(model:MetaNet)->dict:
-    inpt = torch.zeros(5)
+def extract_weights_from_model(model: MetaNet) -> dict:
+    inpt = torch.zeros(5, device=next(model.parameters()).device, dtype=torch.float)
     inpt[-1] = 1
-    inpt.long()
 
     weights = defaultdict(list)
-    layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]]
+    layers = [layer.particles for layer in model.all_layers]
     for i, layer in enumerate(layers):
         for net in layer:
             weights[i].append(net(inpt).detach())
@@ -29,9 +28,10 @@ def extract_weights_from_model(model:MetaNet)->dict:
 
 
 def test_weights_as_model(meta_net, new_weights, data, metric_class=torchmetrics.Accuracy):
+    meta_net_device = next(meta_net.parameters()).device
     transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth,
                                           width=meta_net.width, out=meta_net.out,
-                                          residual_skip=meta_net.residual_skip)
+                                          residual_skip=meta_net.residual_skip).to(meta_net_device)
     with torch.no_grad():
         new_weight_values = list(new_weights.values())
         old_parameters = list(transfer_net.parameters())
@@ -45,7 +45,7 @@ def test_weights_as_model(meta_net, new_weights, data, metric_class=torchmetrics
         net.eval()
         metric = metric_class()
         for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='Test Batch: '):
-            y = net(batch_x)
+            y = net(batch_x.to(meta_net_device))
             metric(y.cpu(), batch_y.cpu())
 
         # metric on all batches using custom accumulation