Added distinction time-as-fixpoint and time-to-vergence to be tracked.
This commit is contained in:
		| @@ -13,6 +13,8 @@ | ||||
|      | ||||
|     - see `journal_robustness.py` for robustness test modeled after cristians robustness-exp (with the exeption that we put noise on the weights). Has `synthetic` bool to switch to hand-modeled perfect fixpoint instead of naturally trained ones.  | ||||
|  | ||||
|     - Also added two difference between the "time-as-fixpoint" and "time-to-verge" (i.e. to divergence / zero). | ||||
|  | ||||
|     - We might need to consult about the "average loss per application step", as I think application loss get gradually higher the worse the weights get. So the average might not tell us much here. | ||||
|  | ||||
| - [ ] Adjust Self Training so that it favors second order fixpoints-> Second order test implementation (?) | ||||
|   | ||||
| @@ -70,7 +70,7 @@ class RobustnessComparisonExperiment: | ||||
|         self.id_functions = [] | ||||
|         self.populate_environment() | ||||
|         self.count_fixpoints() | ||||
|         self.data = self.test_robustness() | ||||
|         self.time_to_vergence, self.time_as_fixpoint = self.test_robustness() | ||||
|  | ||||
|         self.save() | ||||
|  | ||||
| @@ -82,13 +82,13 @@ class RobustnessComparisonExperiment: | ||||
|              | ||||
|             if self.synthetic: | ||||
|                 ''' Either use perfect / hand-constructed fixpoint ... ''' | ||||
|                 net_name = f"ST_net_{str(i)}_synthetic" | ||||
|                 net_name = f"net_{str(i)}_synthetic" | ||||
|                 net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name) | ||||
|                 net.apply_weights(generate_fixpoint_weights()) | ||||
|              | ||||
|             else: | ||||
|                 ''' .. or use natural approach to train fixpoints from random initialisation. ''' | ||||
|                 net_name = f"ST_net_{str(i)}" | ||||
|                 net_name = f"net_{str(i)}" | ||||
|                 net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name) | ||||
|                 for _ in range(self.epochs): | ||||
|                     for _ in range(self.ST_steps): | ||||
| @@ -98,7 +98,8 @@ class RobustnessComparisonExperiment: | ||||
|  | ||||
|  | ||||
|     def test_robustness(self, print_it=True): | ||||
|         data = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] | ||||
|         time_to_vergence = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] | ||||
|         time_as_fixpoint = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] | ||||
|         avg_loss_per_application = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] | ||||
|         noise_range = range(10) | ||||
|         row_headers = [] | ||||
| @@ -115,19 +116,24 @@ class RobustnessComparisonExperiment: | ||||
|                 clone = self.apply_noise(clone, rand_noise) | ||||
|                  | ||||
|                 while not is_zero_fixpoint(clone) and not is_divergent(clone): | ||||
|                     if is_identity_function(clone): | ||||
|                         time_as_fixpoint[i][noise_level] += 1     | ||||
|                      | ||||
|                     # Todo: what kind of comparison between application? -> before    | ||||
|                     clone.self_application(1, self.log_step_size)                     | ||||
|                     data[i][noise_level] += 1 | ||||
|                     time_to_vergence[i][noise_level] += 1 | ||||
|                     # -> after | ||||
|          | ||||
|         if print_it: | ||||
|             print(f"Number appplications steps: ") | ||||
|             col_headers = [str(f"10e-{d}") for d in noise_range] | ||||
|             print(tabulate(data, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|          | ||||
|             # other tables here | ||||
|              | ||||
|         return data | ||||
|             print(f"\nAppplications steps until divergence / zero: ") | ||||
|             print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|              | ||||
|             print(f"\nTime as fixpoint: ") | ||||
|             print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|  | ||||
|         return time_as_fixpoint, time_to_vergence | ||||
|  | ||||
|  | ||||
|     def count_fixpoints(self): | ||||
| @@ -156,7 +162,7 @@ if __name__ == "__main__": | ||||
|     ST_steps = 1000 | ||||
|     ST_epochs = 5 | ||||
|     ST_log_step_size = 10 | ||||
|     ST_population_size = 3 | ||||
|     ST_population_size = 5 | ||||
|     ST_net_hidden_size = 2 | ||||
|     ST_net_learning_rate = 0.04 | ||||
|     ST_name_hash = random.getrandbits(32) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Maximilian Zorn
					Maximilian Zorn