From 61ae8c2ee5574bb6af40fcffe2ecf1341082a1fe Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Fri, 4 Jun 2021 14:13:38 +0200 Subject: [PATCH] robustness fixed --- journal_robustness.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/journal_robustness.py b/journal_robustness.py index c6c6ff6..1dda85b 100644 --- a/journal_robustness.py +++ b/journal_robustness.py @@ -131,6 +131,8 @@ class RobustnessComparisonExperiment: for i, fixpoint in enumerate(self.id_functions): # 1 / n row_headers.append(fixpoint.name) for seed in range(seeds): # n / 1 + setting = seed if self.is_synthetic else i + for noise_level in range(noise_levels): steps = 0 clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size, @@ -145,22 +147,22 @@ class RobustnessComparisonExperiment: target_data_pre_application = clone.create_target_weights(clone_weight_pre_application) clone.self_application(1, self.log_step_size) - time_to_vergence[i][noise_level] += 1 + time_to_vergence[setting][noise_level] += 1 # -> after clone_weight_post_application = clone.input_weight_matrix() target_data_post_application = clone.create_target_weights(clone_weight_post_application) absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item() - setting = seed if self.is_synthetic else i if is_identity_function(clone): - time_as_fixpoint[i][noise_level] += 1 + time_as_fixpoint[setting][noise_level] += 1 # When this raises a Type Error, we found a second order fixpoint! steps += 1 df.loc[df.shape[0]] = [setting, noise_level, steps, absolute_loss, - time_to_vergence[i][noise_level], time_as_fixpoint[i][noise_level]] + time_to_vergence[setting][noise_level], + time_as_fixpoint[setting][noise_level]] pbar.update(1) # Get the measuremts at the highest time_time_to_vergence @@ -189,10 +191,10 @@ class RobustnessComparisonExperiment: col_headers = [str(f"10e-{d}") for d in range(noise_levels)] print(f"\nAppplications steps until divergence / zero: ") - print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) + # print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) print(f"\nTime as fixpoint: ") - print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) + # print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) return time_as_fixpoint, time_to_vergence