robustness cleaned

steffen-illium
2021-06-04 14:03:03 +02:00
parent c9efe0a31b
commit 9abde030af


@@ -4,7 +4,6 @@ import pandas as pd
 import torch
 import random
 import copy
-import numpy as np
 from pathlib import Path
 from tqdm import tqdm
@@ -21,6 +20,7 @@ from matplotlib import pyplot as plt
 def prng():
     return random.random()
 
 
 def generate_perfekt_synthetic_fixpoint_weights():
     return torch.tensor([[1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
                          [1.0], [0.0], [0.0], [0.0], [1.0], [0.0], [0.0], [0.0],
@@ -28,15 +28,32 @@ def generate_perfekt_synthetic_fixpoint_weights():
                          ], dtype=torch.float32)
 
+PALETTE = 10 * (
+    "#377eb8",
+    "#4daf4a",
+    "#984ea3",
+    "#e41a1c",
+    "#ff7f00",
+    "#a65628",
+    "#f781bf",
+    "#888888",
+    "#a6cee3",
+    "#b2df8a",
+    "#cab2d6",
+    "#fb9a99",
+    "#fdbf6f",
+)
+
 
 class RobustnessComparisonExperiment:
 
     @staticmethod
     def apply_noise(network, noise: int):
-        """ Changing the weights of a network to values + noise """
+        # Changing the weights of a network to values + noise
         for layer_id, layer_name in enumerate(network.state_dict()):
             for line_id, line_values in enumerate(network.state_dict()[layer_name]):
                 for weight_id, weight_value in enumerate(network.state_dict()[layer_name][line_id]):
-                    #network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
+                    # network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
                     if prng() < 0.5:
                         network.state_dict()[layer_name][line_id][weight_id] = weight_value + noise
                     else:
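
Note on apply_noise: it perturbs every scalar through three nested Python loops and relies on state_dict() handing back live tensor references; the else branch is cut off by the hunk, but the prng() < 0.5 test suggests noise is added to half the weights and subtracted from the rest. A vectorized sketch of that reading, assuming network is a torch.nn.Module (apply_noise_vectorized and the subtraction branch are assumptions, not repo code):

import torch

def apply_noise_vectorized(network, noise: float):
    # Same idea without per-weight Python loops: pick a random sign
    # per weight (the prng() < 0.5 test above) and shift in place.
    with torch.no_grad():
        for param in network.parameters():
            mask = torch.rand_like(param) < 0.5
            param[mask] += noise    # the if-branch shown in the hunk
            param[~mask] -= noise   # assumed: the truncated else-branch
    return network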
@@ -55,7 +72,7 @@ class RobustnessComparisonExperiment:
         self.epochs = epochs
         self.ST_steps = st_steps
         self.loss_history = []
-        self.synthetic = synthetic
+        self.is_synthetic = synthetic
         self.fixpoint_counters = {
             "identity_func": 0,
             "divergent": 0,
@@ -71,14 +88,14 @@ class RobustnessComparisonExperiment:
         self.id_functions = []
         self.nets = self.populate_environment()
         self.count_fixpoints()
-        self.time_to_vergence, self.time_as_fixpoint = self.test_robustness()
+        self.time_to_vergence, self.time_as_fixpoint = self.test_robustness(
+            seeds=population_size if self.is_synthetic else 1)
         self.save()
 
     def populate_environment(self):
-        loop_population_size = tqdm(range(self.population_size))
         nets = []
-        if self.synthetic:
+        if self.is_synthetic:
             ''' Either use perfect / hand-constructed fixpoint ... '''
             net_name = f"net_{str(0)}_synthetic"
             net = Net(self.net_input_size, self.net_hidden_size, self.net_out_size, net_name)
@@ -86,6 +103,7 @@ class RobustnessComparisonExperiment:
             nets.append(net)
 
         else:
+            loop_population_size = tqdm(range(self.population_size))
             for i in loop_population_size:
                 loop_population_size.set_description("Populating experiment %s" % i)
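
The new seeds argument encodes the experiment's two modes: a single synthetic net probed with population_size noise seeds, or a population of trained nets probed once each. A toy restatement of the invariant that test_robustness asserts below (check_mode and its numbers are illustrative, not repo code):

def check_mode(num_id_functions: int, seeds: int) -> str:
    # Mirrors the assert in test_robustness: exactly one of the two
    # dimensions may exceed 1.
    if num_id_functions == 1 and seeds > 1:
        return "synthetic: one net, many noise seeds"
    if num_id_functions > 1 and seeds == 1:
        return "trained: many nets, one seed each"
    raise AssertionError("exactly one of (nets, seeds) may be > 1")

print(check_mode(1, 100))  # synthetic mode
print(check_mode(25, 1))   # multi-network mode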
@@ -99,58 +117,61 @@ class RobustnessComparisonExperiment:
 
     def test_robustness(self, print_it=True, noise_levels=10, seeds=10):
         assert (len(self.id_functions) == 1 and seeds > 1) or (len(self.id_functions) > 1 and seeds == 1)
-        is_synthetic = True if len(self.id_functions) > 1 and seeds == 1 else False
-        avg_time_to_vergence = [[0 for _ in range(noise_levels)] for _ in
-                                range(seeds if is_synthetic else len(self.id_functions))]
-        avg_time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in
-                                range(seeds if is_synthetic else len(self.id_functions))]
+        time_to_vergence = [[0 for _ in range(noise_levels)] for _ in
+                            range(seeds if self.is_synthetic else len(self.id_functions))]
+        time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in
+                            range(seeds if self.is_synthetic else len(self.id_functions))]
         row_headers = []
-        data_pos = 0
         # This checks whether to use the synthetic setting with multiple seeds
         #  or the multi-network setting with a single seed
-        df = pd.DataFrame(columns=['setting', 'noise_level', 'application_step', 'absolute_loss', 'time_to_vergence'])
-        for i, fixpoint in enumerate(self.id_functions):  # 1 / n
-            row_headers.append(fixpoint.name)
-            for seed in range(seeds):  # n / 1
-                for noise_level in range(noise_levels):
-                    self_application_steps = 0
-                    clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
-                                f"{fixpoint.name}_clone_noise10e-{noise_level}")
-                    clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
-                    rand_noise = prng() * pow(10, -noise_level)  # n / 1
-                    clone = self.apply_noise(clone, rand_noise)
-
-                    while not is_zero_fixpoint(clone) and not is_divergent(clone):
-                        # -> before
-                        clone_weight_pre_application = clone.input_weight_matrix()
-                        target_data_pre_application = clone.create_target_weights(clone_weight_pre_application)
-
-                        clone.self_application(1, self.log_step_size)
-                        avg_time_to_vergence[i][noise_level] += 1
-                        # -> after
-                        clone_weight_post_application = clone.input_weight_matrix()
-                        target_data_post_application = clone.create_target_weights(clone_weight_post_application)
-
-                        absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item()
-                        setting = i if is_synthetic else seed
-
-                        if is_identity_function(clone):
-                            avg_time_as_fixpoint[i][noise_level] += 1
-                            # When this raises a Type Error, we found a second order fixpoint!
-                            self_application_steps += 1
-                        else:
-                            self_application_steps = pd.NA  # Not a Number!
-                        df.loc[df.shape[0]] = [setting, noise_level, self_application_steps,
-                                               absolute_loss, avg_time_to_vergence[i][noise_level]]
+        df = pd.DataFrame(columns=['setting', 'noise_level', 'steps', 'absolute_loss', 'time_to_vergence', 'time_as_fixpoint'])
+        with tqdm(total=max(len(self.id_functions), seeds)) as pbar:
+            for i, fixpoint in enumerate(self.id_functions):  # 1 / n
+                row_headers.append(fixpoint.name)
+                for seed in range(seeds):  # n / 1
+                    for noise_level in range(noise_levels):
+                        steps = 0
+                        clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
+                                    f"{fixpoint.name}_clone_noise10e-{noise_level}")
+                        clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
+                        rand_noise = prng() * pow(10, -noise_level)  # n / 1
+                        clone = self.apply_noise(clone, rand_noise)
+
+                        while not is_zero_fixpoint(clone) and not is_divergent(clone):
+                            # -> before
+                            clone_weight_pre_application = clone.input_weight_matrix()
+                            target_data_pre_application = clone.create_target_weights(clone_weight_pre_application)
+
+                            clone.self_application(1, self.log_step_size)
+                            time_to_vergence[i][noise_level] += 1
+                            # -> after
+                            clone_weight_post_application = clone.input_weight_matrix()
+                            target_data_post_application = clone.create_target_weights(clone_weight_post_application)
+
+                            absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item()
+                            setting = seed if self.is_synthetic else i
+
+                            if is_identity_function(clone):
+                                time_as_fixpoint[i][noise_level] += 1
+                                # When this raises a Type Error, we found a second order fixpoint!
+                                steps += 1
+
+                            df.loc[df.shape[0]] = [setting, noise_level, steps, absolute_loss,
+                                                   time_to_vergence[i][noise_level], time_as_fixpoint[i][noise_level]]
+                    pbar.update(1)
 
-        # calculate the average:
-        # df = df.replace([np.inf, -np.inf], np.nan)
-        # df = df.dropna()
-        bf = sns.boxplot(data=df, y='self_application_steps', x='noise_level', )
+        # Get the measurements at the highest time_to_vergence
+        df_sorted = df.sort_values('steps', ascending=False).drop_duplicates(['setting', 'noise_level'])
+        df_melted = df_sorted.reset_index().melt(id_vars=['setting', 'noise_level', 'steps'],
+                                                 value_vars=['time_to_vergence', 'time_as_fixpoint'],
+                                                 var_name="Measurement",
+                                                 value_name="Steps")
+        # Plotting
+        sns.set(style='whitegrid')
+        bf = sns.boxplot(data=df_melted, y='Steps', x='noise_level', hue='Measurement', palette=PALETTE)
         bf.set_title('Robustness as self application steps per noise level')
         plt.tight_layout()
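
The new sort_values/drop_duplicates/melt chain keeps, per (setting, noise_level), only the final measurement (the row with the highest step count) and reshapes the two counters into long form so seaborn can split the boxes by hue. A standalone sketch with invented numbers (pandas only, no repo code):

import pandas as pd

df = pd.DataFrame({
    'setting':          [0, 0, 0, 1],
    'noise_level':      [0, 0, 1, 0],
    'steps':            [3, 7, 2, 5],
    'time_to_vergence': [4, 8, 3, 6],
    'time_as_fixpoint': [3, 7, 1, 5],
})
# Highest 'steps' first, so drop_duplicates keeps the last measurement
# taken for each (setting, noise_level) pair.
df_sorted = df.sort_values('steps', ascending=False).drop_duplicates(['setting', 'noise_level'])
# Wide -> long: one row per counter, ready for hue='Measurement'.
df_melted = df_sorted.reset_index().melt(id_vars=['setting', 'noise_level', 'steps'],
                                         value_vars=['time_to_vergence', 'time_as_fixpoint'],
                                         var_name='Measurement', value_name='Steps')
print(df_melted[['noise_level', 'Measurement', 'Steps']])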
@@ -158,6 +179,7 @@ class RobustnessComparisonExperiment:
         # bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box',
         #                  col='noise_level', col_wrap=3, showfliers=False)
         directory = Path('output') / 'robustness'
+        directory.mkdir(parents=True, exist_ok=True)
         filename = f"absolute_loss_perapplication_boxplot_grid.png"
         filepath = directory / filename
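
The added mkdir guards whatever savefig call follows: writing into a missing output/robustness directory would raise FileNotFoundError. Standard pathlib behaviour, shown in isolation:

from pathlib import Path

directory = Path('output') / 'robustness'
# parents=True also creates 'output'; exist_ok=True keeps the call
# idempotent across repeated experiment runs.
directory.mkdir(parents=True, exist_ok=True)
assert directory.is_dir()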
@@ -167,12 +189,12 @@ class RobustnessComparisonExperiment:
         col_headers = [str(f"10e-{d}") for d in range(noise_levels)]
 
         print(f"\nApplication steps until divergence / zero: ")
-        print(tabulate(avg_time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
+        print(tabulate(time_to_vergence, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
         print(f"\nTime as fixpoint: ")
-        print(tabulate(avg_time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
-        return avg_time_as_fixpoint, avg_time_to_vergence
+        print(tabulate(time_as_fixpoint, showindex=row_headers, headers=col_headers, tablefmt='orgtbl'))
+        return time_as_fixpoint, time_to_vergence
 
     def count_fixpoints(self):
         exp_details = f"ST steps: {self.ST_steps}"
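
For reference, tabulate renders each counter matrix as an org-mode table with the net names as the index column; a minimal standalone example (the counter values are invented):

from tabulate import tabulate

time_to_vergence = [[100, 42, 7], [98, 40, 5]]   # hypothetical counters
row_headers = ['net_0', 'net_1']
col_headers = [f"10e-{d}" for d in range(3)]
print(tabulate(time_to_vergence, showindex=row_headers,
               headers=col_headers, tablefmt='orgtbl'))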
@@ -198,7 +220,7 @@ if __name__ == "__main__":
     ST_steps = 1000
     ST_epochs = 5
     ST_log_step_size = 10
-    ST_population_size = 5
+    ST_population_size = 100
     ST_net_hidden_size = 2
     ST_net_learning_rate = 0.04
     ST_name_hash = random.getrandbits(32)