From 54590eb147e729ccdf6cc25cc59c9bb664951e16 Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Sun, 23 May 2021 10:33:54 +0200 Subject: [PATCH] application losses --- journal_robustness.py | 94 +++++++++++++++++++++++-------------------- 1 file changed, 51 insertions(+), 43 deletions(-) diff --git a/journal_robustness.py b/journal_robustness.py index 51b82d5..96c7942 100644 --- a/journal_robustness.py +++ b/journal_robustness.py @@ -12,17 +12,18 @@ from sklearn.metrics import mean_squared_error as MSE from journal_basins import mean_invariate_manhattan_distance as MIM from functionalities_test import is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent from network import Net +from torch.nn import functional as F from visualization import plot_loss, bar_chart_fixpoints def prng(): return random.random() -def generate_fixpoint_weights(): - return torch.tensor([ [1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], - [1.0], [0.0], [0.0], [0.0], - [1.0], [0.0] - ], dtype=torch.float32) +def generate_perfekt_synthetic_fixpoint_weights(): + return torch.tensor([[1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], + [1.0], [0.0], [0.0], [0.0], + [1.0], [0.0] + ], dtype=torch.float32) class RobustnessComparisonExperiment: @@ -39,7 +40,7 @@ class RobustnessComparisonExperiment: network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise else: network.state_dict()[layer_name][line_id][weight_id] = weight_value - noise - + return network def __init__(self, population_size, log_step_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate, @@ -53,7 +54,6 @@ class RobustnessComparisonExperiment: self.epochs = epochs self.ST_steps = st_steps self.loss_history = [] - self.nets = [] self.synthetic = synthetic self.fixpoint_counters = { "identity_func": 0, @@ -68,7 +68,7 @@ class RobustnessComparisonExperiment: self.directory.mkdir(parents=True, exist_ok=True) self.id_functions = [] - self.populate_environment() + self.nets = self.populate_environment() self.count_fixpoints() self.time_to_vergence, self.time_as_fixpoint = self.test_robustness() @@ -76,64 +76,72 @@ class RobustnessComparisonExperiment: def populate_environment(self): loop_population_size = tqdm(range(self.population_size)) + nets = [] for i in loop_population_size: loop_population_size.set_description("Populating experiment %s" % i) - + if self.synthetic: ''' Either use perfect / hand-constructed fixpoint ... ''' net_name = f"net_{str(i)}_synthetic" net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name) - net.apply_weights(generate_fixpoint_weights()) - + net.apply_weights(generate_perfekt_synthetic_fixpoint_weights()) + else: ''' .. or use natural approach to train fixpoints from random initialisation. ''' net_name = f"net_{str(i)}" net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name) for _ in range(self.epochs): - for _ in range(self.ST_steps): - net.self_train(1, self.log_step_size, self.net_learning_rate) - - self.nets.append(net) - + net.self_train(self.ST_steps, self.log_step_size, self.net_learning_rate) + nets.append(net) + return nets def test_robustness(self, print_it=True): - time_to_vergence = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] - time_as_fixpoint = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] + avg_time_to_vergence = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] + avg_time_as_fixpoint = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] avg_loss_per_application = [[0 for _ in range(10)] for _ in range(len(self.id_functions))] noise_range = range(10) row_headers = [] - + for i, fixpoint in enumerate(self.id_functions): row_headers.append(fixpoint.name) - for noise_level in noise_range: - application_losses = [] + for seed in range(10): + for noise_level in noise_range: + application_losses = [] + + clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size, + f"{fixpoint.name}_clone_noise10e-{noise_level}") + clone.load_state_dict(copy.deepcopy(fixpoint.state_dict())) + rand_noise = prng() * pow(10, -noise_level) + clone = self.apply_noise(clone, rand_noise) + + while not is_zero_fixpoint(clone) and not is_divergent(clone): + if is_identity_function(clone): + avg_time_as_fixpoint[i][noise_level] += 1 + + # -> before + clone_weight_pre_application = clone.input_weight_matrix() + target_data_pre_application = clone.create_target_weights(clone_weight_pre_application) + + clone.self_application(1, self.log_step_size) + avg_time_to_vergence[i][noise_level] += 1 + # -> after + clone_weight_post_application = clone.input_weight_matrix() + target_data_post_application = clone.create_target_weights(clone_weight_post_application) + + application_losses.append(F.l1_loss(target_data_pre_application, target_data_post_application)) + - clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size, - f"{fixpoint.name}_clone_noise10e-{noise_level}") - clone.load_state_dict(copy.deepcopy(fixpoint.state_dict())) - rand_noise = prng() * pow(10, -noise_level) - clone = self.apply_noise(clone, rand_noise) - - while not is_zero_fixpoint(clone) and not is_divergent(clone): - if is_identity_function(clone): - time_as_fixpoint[i][noise_level] += 1 - - # Todo: what kind of comparison between application? -> before - clone.self_application(1, self.log_step_size) - time_to_vergence[i][noise_level] += 1 - # -> after - if print_it: col_headers = [str(f"10e-{d}") for d in noise_range] - - print(f"\nAppplications steps until divergence / zero: ") - print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) - - print(f"\nTime as fixpoint: ") - print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) - return time_as_fixpoint, time_to_vergence + print(f"\nAppplications steps until divergence / zero: ") + print(tabulate(avg_time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) + + print(f"\nTime as fixpoint: ") + print(tabulate(avg_time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl')) + + return avg_time_as_fixpoint, avg_time_to_vergence def count_fixpoints(self):