robustness fixed
This commit is contained in:
@ -131,6 +131,8 @@ class RobustnessComparisonExperiment:
|
|||||||
for i, fixpoint in enumerate(self.id_functions): # 1 / n
|
for i, fixpoint in enumerate(self.id_functions): # 1 / n
|
||||||
row_headers.append(fixpoint.name)
|
row_headers.append(fixpoint.name)
|
||||||
for seed in range(seeds): # n / 1
|
for seed in range(seeds): # n / 1
|
||||||
|
setting = seed if self.is_synthetic else i
|
||||||
|
|
||||||
for noise_level in range(noise_levels):
|
for noise_level in range(noise_levels):
|
||||||
steps = 0
|
steps = 0
|
||||||
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
|
clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
|
||||||
@ -145,22 +147,22 @@ class RobustnessComparisonExperiment:
|
|||||||
target_data_pre_application = clone.create_target_weights(clone_weight_pre_application)
|
target_data_pre_application = clone.create_target_weights(clone_weight_pre_application)
|
||||||
|
|
||||||
clone.self_application(1, self.log_step_size)
|
clone.self_application(1, self.log_step_size)
|
||||||
time_to_vergence[i][noise_level] += 1
|
time_to_vergence[setting][noise_level] += 1
|
||||||
# -> after
|
# -> after
|
||||||
clone_weight_post_application = clone.input_weight_matrix()
|
clone_weight_post_application = clone.input_weight_matrix()
|
||||||
target_data_post_application = clone.create_target_weights(clone_weight_post_application)
|
target_data_post_application = clone.create_target_weights(clone_weight_post_application)
|
||||||
|
|
||||||
absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item()
|
absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item()
|
||||||
|
|
||||||
setting = seed if self.is_synthetic else i
|
|
||||||
|
|
||||||
if is_identity_function(clone):
|
if is_identity_function(clone):
|
||||||
time_as_fixpoint[i][noise_level] += 1
|
time_as_fixpoint[setting][noise_level] += 1
|
||||||
# When this raises a Type Error, we found a second order fixpoint!
|
# When this raises a Type Error, we found a second order fixpoint!
|
||||||
steps += 1
|
steps += 1
|
||||||
|
|
||||||
df.loc[df.shape[0]] = [setting, noise_level, steps, absolute_loss,
|
df.loc[df.shape[0]] = [setting, noise_level, steps, absolute_loss,
|
||||||
time_to_vergence[i][noise_level], time_as_fixpoint[i][noise_level]]
|
time_to_vergence[setting][noise_level],
|
||||||
|
time_as_fixpoint[setting][noise_level]]
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
# Get the measuremts at the highest time_time_to_vergence
|
# Get the measuremts at the highest time_time_to_vergence
|
||||||
@ -189,10 +191,10 @@ class RobustnessComparisonExperiment:
|
|||||||
col_headers = [str(f"10e-{d}") for d in range(noise_levels)]
|
col_headers = [str(f"10e-{d}") for d in range(noise_levels)]
|
||||||
|
|
||||||
print(f"\nAppplications steps until divergence / zero: ")
|
print(f"\nAppplications steps until divergence / zero: ")
|
||||||
print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
# print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
||||||
|
|
||||||
print(f"\nTime as fixpoint: ")
|
print(f"\nTime as fixpoint: ")
|
||||||
print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
# print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
|
||||||
|
|
||||||
return time_as_fixpoint, time_to_vergence
|
return time_as_fixpoint, time_to_vergence
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user