application losses
This commit is contained in:
		| @@ -12,17 +12,18 @@ from sklearn.metrics import mean_squared_error as MSE | ||||
| from journal_basins import mean_invariate_manhattan_distance as MIM | ||||
| from functionalities_test import is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent | ||||
| from network import Net | ||||
| from torch.nn import functional as F | ||||
| from visualization import plot_loss, bar_chart_fixpoints | ||||
|  | ||||
|  | ||||
| def prng(): | ||||
|     return random.random() | ||||
|  | ||||
| def generate_fixpoint_weights(): | ||||
|     return torch.tensor([ [1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], | ||||
|                           [1.0], [0.0], [0.0], [0.0], | ||||
|                           [1.0], [0.0] | ||||
|                         ], dtype=torch.float32) | ||||
| def generate_perfekt_synthetic_fixpoint_weights(): | ||||
|     return torch.tensor([[1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], | ||||
|                          [1.0], [0.0], [0.0], [0.0], | ||||
|                          [1.0], [0.0] | ||||
|                          ], dtype=torch.float32) | ||||
|  | ||||
|  | ||||
| class RobustnessComparisonExperiment: | ||||
| @@ -39,7 +40,7 @@ class RobustnessComparisonExperiment: | ||||
|                         network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise | ||||
|                     else: | ||||
|                         network.state_dict()[layer_name][line_id][weight_id] = weight_value - noise | ||||
|                      | ||||
|  | ||||
|         return network | ||||
|  | ||||
|     def __init__(self, population_size, log_step_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate, | ||||
| @@ -53,7 +54,6 @@ class RobustnessComparisonExperiment: | ||||
|         self.epochs = epochs | ||||
|         self.ST_steps = st_steps | ||||
|         self.loss_history = [] | ||||
|         self.nets = [] | ||||
|         self.synthetic = synthetic | ||||
|         self.fixpoint_counters = { | ||||
|             "identity_func": 0, | ||||
| @@ -68,7 +68,7 @@ class RobustnessComparisonExperiment: | ||||
|         self.directory.mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
|         self.id_functions = [] | ||||
|         self.populate_environment() | ||||
|         self.nets = self.populate_environment() | ||||
|         self.count_fixpoints() | ||||
|         self.time_to_vergence, self.time_as_fixpoint = self.test_robustness() | ||||
|  | ||||
| @@ -76,64 +76,72 @@ class RobustnessComparisonExperiment: | ||||
|  | ||||
|     def populate_environment(self): | ||||
|         loop_population_size = tqdm(range(self.population_size)) | ||||
|         nets = [] | ||||
|  | ||||
|         for i in loop_population_size: | ||||
|             loop_population_size.set_description("Populating experiment %s" % i) | ||||
|              | ||||
|  | ||||
|             if self.synthetic: | ||||
|                 ''' Either use perfect / hand-constructed fixpoint ... ''' | ||||
|                 net_name = f"net_{str(i)}_synthetic" | ||||
|                 net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name) | ||||
|                 net.apply_weights(generate_fixpoint_weights()) | ||||
|              | ||||
|                 net.apply_weights(generate_perfekt_synthetic_fixpoint_weights()) | ||||
|  | ||||
|             else: | ||||
|                 ''' .. or use natural approach to train fixpoints from random initialisation. ''' | ||||
|                 net_name = f"net_{str(i)}" | ||||
|                 net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name) | ||||
|                 for _ in range(self.epochs): | ||||
|                     for _ in range(self.ST_steps): | ||||
|                         net.self_train(1, self.log_step_size, self.net_learning_rate) | ||||
|  | ||||
|             self.nets.append(net) | ||||
|  | ||||
|                     net.self_train(self.ST_steps, self.log_step_size, self.net_learning_rate) | ||||
|             nets.append(net) | ||||
|         return nets | ||||
|  | ||||
|     def test_robustness(self, print_it=True): | ||||
|         time_to_vergence = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] | ||||
|         time_as_fixpoint = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] | ||||
|         avg_time_to_vergence = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] | ||||
|         avg_time_as_fixpoint = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] | ||||
|         avg_loss_per_application = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] | ||||
|         noise_range = range(10) | ||||
|         row_headers = [] | ||||
|          | ||||
|  | ||||
|         for i, fixpoint in enumerate(self.id_functions): | ||||
|             row_headers.append(fixpoint.name) | ||||
|             for noise_level in noise_range: | ||||
|                 application_losses = [] | ||||
|             for seed in range(10): | ||||
|                 for noise_level in noise_range: | ||||
|                     application_losses = [] | ||||
|  | ||||
|                     clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size, | ||||
|                                 f"{fixpoint.name}_clone_noise10e-{noise_level}") | ||||
|                     clone.load_state_dict(copy.deepcopy(fixpoint.state_dict())) | ||||
|                     rand_noise = prng() * pow(10, -noise_level) | ||||
|                     clone = self.apply_noise(clone, rand_noise) | ||||
|  | ||||
|                     while not is_zero_fixpoint(clone) and not is_divergent(clone): | ||||
|                         if is_identity_function(clone): | ||||
|                             avg_time_as_fixpoint[i][noise_level] += 1 | ||||
|  | ||||
|                         # -> before | ||||
|                         clone_weight_pre_application = clone.input_weight_matrix() | ||||
|                         target_data_pre_application = clone.create_target_weights(clone_weight_pre_application) | ||||
|  | ||||
|                         clone.self_application(1, self.log_step_size) | ||||
|                         avg_time_to_vergence[i][noise_level] += 1 | ||||
|                         # -> after | ||||
|                         clone_weight_post_application = clone.input_weight_matrix() | ||||
|                         target_data_post_application = clone.create_target_weights(clone_weight_post_application) | ||||
|  | ||||
|                         application_losses.append(F.l1_loss(target_data_pre_application, target_data_post_application)) | ||||
|  | ||||
|  | ||||
|                 clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size, | ||||
|                             f"{fixpoint.name}_clone_noise10e-{noise_level}") | ||||
|                 clone.load_state_dict(copy.deepcopy(fixpoint.state_dict())) | ||||
|                 rand_noise = prng() * pow(10, -noise_level) | ||||
|                 clone = self.apply_noise(clone, rand_noise) | ||||
|                  | ||||
|                 while not is_zero_fixpoint(clone) and not is_divergent(clone): | ||||
|                     if is_identity_function(clone): | ||||
|                         time_as_fixpoint[i][noise_level] += 1     | ||||
|                      | ||||
|                     # Todo: what kind of comparison between application? -> before    | ||||
|                     clone.self_application(1, self.log_step_size)                     | ||||
|                     time_to_vergence[i][noise_level] += 1 | ||||
|                     # -> after | ||||
|          | ||||
|         if print_it: | ||||
|             col_headers = [str(f"10e-{d}") for d in noise_range] | ||||
|              | ||||
|             print(f"\nAppplications steps until divergence / zero: ") | ||||
|             print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|              | ||||
|             print(f"\nTime as fixpoint: ") | ||||
|             print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|  | ||||
|         return time_as_fixpoint, time_to_vergence | ||||
|             print(f"\nAppplications steps until divergence / zero: ") | ||||
|             print(tabulate(avg_time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|  | ||||
|             print(f"\nTime as fixpoint: ") | ||||
|             print(tabulate(avg_time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) | ||||
|  | ||||
|         return avg_time_as_fixpoint, avg_time_to_vergence | ||||
|  | ||||
|  | ||||
|     def count_fixpoints(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 steffen-illium
					steffen-illium