From dd2458da4ae1c0ef7465129a40ae3946a5b862c5 Mon Sep 17 00:00:00 2001 From: Steffen Illium <steffen.illium@ifi.lmu.de> Date: Mon, 7 Mar 2022 11:29:06 +0100 Subject: [PATCH] adjustments for cuda and auto ckpt cleanup --- experiments/meta_task_utility.py | 14 ++++++++++++-- experiments/robustness_tester.py | 2 +- meta_task_exp.py | 16 ++++++++-------- sanity_check_weights.py | 12 ++++++------ 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/experiments/meta_task_utility.py b/experiments/meta_task_utility.py index f4cb9d4..d1c345e 100644 --- a/experiments/meta_task_utility.py +++ b/experiments/meta_task_utility.py @@ -70,7 +70,7 @@ def set_checkpoint(model, out_path, epoch_n, final_model=False): ckpt_path = Path(out_path) / 'ckpt' / f'{epoch_n.zfill(4)}_model_ckpt.tp' else: if isinstance(epoch_n, str): - ckpt_path = Path(out_path) / f'{epoch_n}_{FINAL_CHECKPOINT_NAME}' + ckpt_path = Path(out_path) / f'{Path(FINAL_CHECKPOINT_NAME).stem}_{epoch_n}.tp' else: ckpt_path = Path(out_path) / FINAL_CHECKPOINT_NAME ckpt_path.parent.mkdir(exist_ok=True, parents=True) @@ -113,10 +113,20 @@ def new_storage_df(identifier, weight_count): return pd.DataFrame(columns=['Epoch', 'Weight', *(f'weight_{x}' for x in range(weight_count))]) -def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, final_model=False, +def checkpoint_and_validate(model, valid_loader, out_path, epoch_n, keep_n=5, final_model=False, validation_metric=torchmetrics.Accuracy): out_path = Path(out_path) ckpt_path = set_checkpoint(model, out_path, epoch_n, final_model=final_model) + # Clean up Checkpoints + if keep_n > 0: + all_ckpts = sorted(list(ckpt_path.parent.iterdir())) + while len(all_ckpts) > keep_n: + all_ckpts.pop(0).unlink() + elif keep_n == 0: + pass + else: + raise ValueError(f'"keep_n" cannot be negative, but was: {keep_n}') + result = validate(ckpt_path, valid_loader, metric_class=validation_metric) return result diff --git a/experiments/robustness_tester.py b/experiments/robustness_tester.py index 7e78f42..750ea7e 100644 --- a/experiments/robustness_tester.py +++ b/experiments/robustness_tester.py @@ -77,7 +77,7 @@ def test_robustness(model_path, noise_levels=10, seeds=10, log_step_size=10): # When this raises a Type Error, we found a second order fixpoint! steps += 1 - df.loc[df.shape[0]] = [setting, f'$\mathregular{{10^{{-{noise_level}}}}}$', + df.loc[df.shape[0]] = [f'{setting}_{seed}', fr'$\mathregular{{10^{{-{noise_level}}}}}$', steps, absolute_loss, time_to_vergence[setting][noise_level], time_as_fixpoint[setting][noise_level]] diff --git a/meta_task_exp.py b/meta_task_exp.py index 279f5c3..42283b8 100644 --- a/meta_task_exp.py +++ b/meta_task_exp.py @@ -36,7 +36,7 @@ else: from network import MetaNet, FixTypes from functionalities_test import test_for_fixpoints -utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0), AddGaussianNoise()]) +utility_transforms = Compose([ToTensor(), ToFloat(), Resize((15, 15)), Flatten(start_dim=0)]) # , AddGaussianNoise()]) WORKER = 10 if not debug else 2 debug = False BATCHSIZE = 2000 if not debug else 50 @@ -60,16 +60,16 @@ plot_loader = DataLoader(plot_dataset, batch_size=BATCHSIZE, shuffle=False, if __name__ == '__main__': - training = False - plotting = False - robustnes = True # EXPENSIV!!!!!!! - n_st = 300 # per batch !! + training = True + plotting = True + robustnes = True + n_st = 1 # per batch !! activation = None # nn.ReLU() for weight_hidden_size in [3]: weight_hidden_size = weight_hidden_size - residual_skip = True + residual_skip = False n_seeds = 3 depth = 5 width = 3 @@ -84,7 +84,7 @@ if __name__ == '__main__': st_str = f'_nst_{n_st}' config_str = f'{res_str}{ac_str}{st_str}' - exp_path = Path('output') / f'mn_st_{EPOCH}_{weight_hidden_size}{config_str}_gauss' + exp_path = Path('output') / f'mn_st_{EPOCH}_{weight_hidden_size}{config_str}' for seed in range(n_seeds): seed_path = exp_path / str(seed) @@ -161,7 +161,7 @@ if __name__ == '__main__': except RuntimeError: pass - accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, epoch).item() + accuracy = checkpoint_and_validate(metanet, valid_loader, seed_path, epoch, keep_n=5).item() validation_log = dict(Epoch=int(epoch), Batch=BATCHSIZE, Metric=f'Test {VAL_METRIC_NAME}', Score=accuracy) train_store.loc[train_store.shape[0]] = validation_log diff --git a/sanity_check_weights.py b/sanity_check_weights.py index 4caa5ef..834bdbf 100644 --- a/sanity_check_weights.py +++ b/sanity_check_weights.py @@ -15,13 +15,12 @@ from functionalities_test import epsilon_error_margin as e from network import MetaNet, MetaNetCompareBaseline -def extract_weights_from_model(model:MetaNet)->dict: - inpt = torch.zeros(5) +def extract_weights_from_model(model: MetaNet) -> dict: + inpt = torch.zeros(5, device=next(model.parameters()).device, dtype=torch.float) inpt[-1] = 1 - inpt.long() weights = defaultdict(list) - layers = [layer.particles for layer in [model._meta_layer_first, *model._meta_layer_list, model._meta_layer_last]] + layers = [layer.particles for layer in model.all_layers] for i, layer in enumerate(layers): for net in layer: weights[i].append(net(inpt).detach()) @@ -29,9 +28,10 @@ def extract_weights_from_model(model:MetaNet)->dict: def test_weights_as_model(meta_net, new_weights, data, metric_class=torchmetrics.Accuracy): + meta_net_device = next(meta_net.parameters()).device transfer_net = MetaNetCompareBaseline(meta_net.interface, depth=meta_net.depth, width=meta_net.width, out=meta_net.out, - residual_skip=meta_net.residual_skip) + residual_skip=meta_net.residual_skip).to(meta_net_device) with torch.no_grad(): new_weight_values = list(new_weights.values()) old_parameters = list(transfer_net.parameters()) @@ -45,7 +45,7 @@ def test_weights_as_model(meta_net, new_weights, data, metric_class=torchmetrics net.eval() metric = metric_class() for batch, (batch_x, batch_y) in tqdm(enumerate(data), total=len(data), desc='Test Batch: '): - y = net(batch_x) + y = net(batch_x.to(meta_net_device)) metric(y.cpu(), batch_y.cpu()) # metric on all batches using custom accumulation