robustness fixed
This commit is contained in:
		| @@ -131,6 +131,8 @@ class RobustnessComparisonExperiment: | ||||
|             for i, fixpoint in enumerate(self.id_functions):  # 1 / n | ||||
|                 row_headers.append(fixpoint.name) | ||||
|                 for seed in range(seeds):  # n / 1 | ||||
|                     setting = seed if self.is_synthetic else i | ||||
|  | ||||
|                     for noise_level in range(noise_levels): | ||||
|                         steps = 0 | ||||
|                         clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size, | ||||
| @@ -145,22 +147,22 @@ class RobustnessComparisonExperiment: | ||||
|                             target_data_pre_application = clone.create_target_weights(clone_weight_pre_application) | ||||
|  | ||||
|                             clone.self_application(1, self.log_step_size) | ||||
|                             time_to_vergence[i][noise_level] += 1 | ||||
|                             time_to_vergence[setting][noise_level] += 1 | ||||
|                             # -> after | ||||
|                             clone_weight_post_application = clone.input_weight_matrix() | ||||
|                             target_data_post_application = clone.create_target_weights(clone_weight_post_application) | ||||
|  | ||||
|                             absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item() | ||||
|  | ||||
|                             setting = seed if self.is_synthetic else i | ||||
|  | ||||
|                             if is_identity_function(clone): | ||||
|                                 time_as_fixpoint[i][noise_level] += 1 | ||||
|                                 time_as_fixpoint[setting][noise_level] += 1 | ||||
|                                 # When this raises a Type Error, we found a second order fixpoint! | ||||
|                             steps += 1 | ||||
|  | ||||
|                             df.loc[df.shape[0]] = [setting, noise_level, steps, absolute_loss, | ||||
|                                                    time_to_vergence[i][noise_level], time_as_fixpoint[i][noise_level]] | ||||
|                                                    time_to_vergence[setting][noise_level], | ||||
|                                                    time_as_fixpoint[setting][noise_level]] | ||||
|                     pbar.update(1) | ||||
|  | ||||
|         # Get the measuremts at the highest time_time_to_vergence | ||||
| @@ -189,10 +191,10 @@ class RobustnessComparisonExperiment: | ||||
|             col_headers = [str(f"10e-{d}") for d in range(noise_levels)] | ||||
|  | ||||
|             print(f"\nAppplications steps until divergence / zero: ") | ||||
|             print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|             # print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|  | ||||
|             print(f"\nTime as fixpoint: ") | ||||
|             print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|             # print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|  | ||||
|         return time_as_fixpoint, time_to_vergence | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 steffen-illium
					steffen-illium