robustness fixed

This commit is contained in:
steffen-illium
2021-06-04 14:13:38 +02:00
parent 800a2c8f6b
commit 61ae8c2ee5

View File

@ -131,6 +131,8 @@ class RobustnessComparisonExperiment:
for i, fixpoint in enumerate(self.id_functions): # 1 / n
row_headers.append(fixpoint.name)
for seed in range(seeds): # n / 1
setting = seed if self.is_synthetic else i
for noise_level in range(noise_levels):
steps = 0
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
@ -145,22 +147,22 @@ class RobustnessComparisonExperiment:
target_data_pre_application = clone.create_target_weights(clone_weight_pre_application)
clone.self_application(1, self.log_step_size)
time_to_vergence[i][noise_level] += 1
time_to_vergence[setting][noise_level] += 1
# -> after
clone_weight_post_application = clone.input_weight_matrix()
target_data_post_application = clone.create_target_weights(clone_weight_post_application)
absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item()
setting = seed if self.is_synthetic else i
if is_identity_function(clone):
time_as_fixpoint[i][noise_level] += 1
time_as_fixpoint[setting][noise_level] += 1
# When this raises a Type Error, we found a second order fixpoint!
steps += 1
df.loc[df.shape[0]] = [setting, noise_level, steps, absolute_loss,
time_to_vergence[i][noise_level], time_as_fixpoint[i][noise_level]]
time_to_vergence[setting][noise_level],
time_as_fixpoint[setting][noise_level]]
pbar.update(1)
# Get the measuremts at the highest time_time_to_vergence
@ -189,10 +191,10 @@ class RobustnessComparisonExperiment:
col_headers = [str(f"10e-{d}") for d in range(noise_levels)]
print(f"\nAppplications steps until divergence / zero: ")
print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
# print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
print(f"\nTime as fixpoint: ")
print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
# print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
return time_as_fixpoint, time_to_vergence