robustness
experiments/robustness_tester.py (new file, 119 lines)
@@ -0,0 +1,119 @@
import pickle

import pandas as pd
import torch
import random
import copy

from pathlib import Path

from tqdm import tqdm
from functionalities_test import is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent
from network import Net
from torch.nn import functional as F
from visualization import plot_loss, bar_chart_fixpoints
import seaborn as sns
from matplotlib import pyplot as plt


def prng():
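    """Thin wrapper around random.random(); kept as this module's PRNG entry point."""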
    return random.random()


def generate_perfekt_synthetic_fixpoint_weights():
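    """Hand-crafted weight column intended to act as a perfect identity fixpoint
    for a Net of matching size."""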
    return torch.tensor([[1.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
                         [1.0], [0.0], [0.0], [0.0],
                         [1.0], [0.0]
                         ], dtype=torch.float32)


# 13 distinct colors, tiled 10x so seaborn never runs out of hues.
PALETTE = 10 * (
    "#377eb8",
    "#4daf4a",
    "#984ea3",
    "#e41a1c",
    "#ff7f00",
    "#a65628",
    "#f781bf",
    "#888888",
    "#a6cee3",
    "#b2df8a",
    "#cab2d6",
    "#fb9a99",
    "#fdbf6f",
)


def test_robustness(networks: list, exp_path, noise_levels=10, seeds=10, log_step_size=10):
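    """Clone each network, perturb the clone's weights with noise of magnitude
    10**-noise_level, then self-apply it until it diverges or collapses to the
    zero fixpoint, counting the total steps taken (time to convergence) and how
    many of those steps it still passed the identity-function test (time as
    fixpoint)."""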
    time_to_vergence = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
    time_as_fixpoint = [[0 for _ in range(noise_levels)] for _ in range(len(networks))]
    row_headers = []

    df = pd.DataFrame(columns=['setting', 'Noise Level', 'Self Train Steps', 'absolute_loss',
                               'Time to convergence', 'Time as fixpoint'])
    with tqdm(total=max(len(networks), seeds)) as pbar:
        for setting, fixpoint in enumerate(networks):  # 1 / n
            row_headers.append(fixpoint.name)
            for seed in range(seeds):  # n / 1
                for noise_level in range(noise_levels):
                    steps = 0
                    # Stays NaN if the clone is already zero/divergent before the first application.
                    absolute_loss = float('nan')
                    clone = Net(fixpoint.input_size, fixpoint.hidden_size, fixpoint.out_size,
                                f"{fixpoint.name}_clone_noise_1e-{noise_level}")
                    clone.load_state_dict(copy.deepcopy(fixpoint.state_dict()))
                    clone = clone.apply_noise(pow(10, -noise_level))

                    while not is_zero_fixpoint(clone) and not is_divergent(clone):
                        # -> before
                        clone_weight_pre_application = clone.input_weight_matrix()
                        target_data_pre_application = clone.create_target_weights(clone_weight_pre_application)

                        clone.self_application(1, log_step_size)
                        time_to_vergence[setting][noise_level] += 1
                        # -> after
                        clone_weight_post_application = clone.input_weight_matrix()
                        target_data_post_application = clone.create_target_weights(clone_weight_post_application)

                        absolute_loss = F.l1_loss(target_data_pre_application, target_data_post_application).item()

                        if is_identity_function(clone):
                            time_as_fixpoint[setting][noise_level] += 1
                        # When this raises a TypeError, we've found a second-order fixpoint!
                        steps += 1

                    df.loc[df.shape[0]] = [setting, rf'$\mathregular{{10^{{-{noise_level}}}}}$',
                                           steps, absolute_loss,
                                           time_to_vergence[setting][noise_level],
                                           time_as_fixpoint[setting][noise_level]]
                pbar.update(1)

    # Get the measurements at the highest time_to_vergence
    df_sorted = df.sort_values('Self Train Steps', ascending=False).drop_duplicates(['setting', 'Noise Level'])
    df_melted = df_sorted.reset_index().melt(id_vars=['setting', 'Noise Level', 'Self Train Steps'],
                                             value_vars=['Time to convergence', 'Time as fixpoint'],
                                             var_name="Measurement",
                                             value_name="Steps").sort_values('Noise Level')
    # Plotting
    # plt.rcParams.update({
    #     "text.usetex": True,
    #     "font.family": "sans-serif",
    #     "font.size": 12,
    #     "font.weight": 'bold',
    #     "font.sans-serif": ["Helvetica"]})
    plt.clf()
    sns.set(style='whitegrid', font_scale=1)
    _ = sns.boxplot(data=df_melted, y='Steps', x='Noise Level', hue='Measurement', palette=PALETTE)
    plt.tight_layout()

    # sns.set(rc={'figure.figsize': (10, 50)})
    # bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box',
    #                  col='noise_level', col_wrap=3, showfliers=False)

    filename = "absolute_loss_perapplication_boxplot_grid_wild.png"
    filepath = exp_path / filename
    plt.savefig(str(filepath))
    plt.close('all')
    return time_as_fixpoint, time_to_vergence


if __name__ == "__main__":
    raise NotImplementedError('Get out of here!')

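For reference, a minimal sketch of how the tester gets driven (this mirrors the call added in the diff below; `MetaNet.particles`, the `*e{EPOCH}.tp` checkpoint naming, and the seed-path layout come from that diff, and the concrete path and epoch here are hypothetical):

    from pathlib import Path
    import torch
    from experiments.robustness_tester import test_robustness

    seed_path = Path('output/exp/0')                    # hypothetical seed directory
    model_path = next(seed_path.glob('*e100.tp'))       # final-epoch checkpoint, assuming EPOCH = 100
    model = torch.load(model_path, map_location='cpu')  # a trained MetaNet
    time_as_fixpoint, time_to_vergence = test_robustness(list(model.particles), seed_path)
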
@@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm

 from experiments.meta_task_small_utility import AddTaskDataset, train_task
+from experiments.robustness_tester import test_robustness
 from network import MetaNet
 from functionalities_test import test_for_fixpoints, FixTypes as ft
 from experiments.meta_task_utility import new_storage_df, flat_for_store, plot_training_result, \
@@ -29,12 +30,12 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 if __name__ == '__main__':

-    training = True
-    plotting = True
+    training = False
+    plotting = False
     n_st = 700
     activation = None  # nn.ReLU()

-    for weight_hidden_size in [3, 4, 5, 6]:
+    for weight_hidden_size in [3, 4, 5]:

         tsk_threshold = 0.85
         weight_hidden_size = weight_hidden_size
@@ -62,7 +63,6 @@ if __name__ == '__main__':
         for seed in range(n_seeds):
             seed_path = exp_path / str(seed)

-            model_path = seed_path / '0000_trained_model.zip'
             df_store_path = seed_path / 'train_store.csv'
             weight_store_path = seed_path / 'weight_store.csv'
             srnn_parameters = dict()
@@ -73,7 +73,7 @@ if __name__ == '__main__':

             if training:
                 # Check if files do exist on project location, warn and break.
-                for path in [model_path, df_store_path, weight_store_path]:
+                for path in [df_store_path, weight_store_path]:
                     assert not path.exists(), f'Path "{path}" already exists. Check your configuration!'

                 train_data = AddTaskDataset()
@@ -189,6 +189,7 @@ if __name__ == '__main__':
                 exit(1)

             try:
+                # noinspection PyUnboundLocalVariable
                 run_particle_dropout_and_plot(model_path, valid_loader=vali_load, metric_class=VALIDATION_METRIC)
             except ValueError as e:
                 print('ERROR:', e)
@@ -203,6 +204,12 @@ if __name__ == '__main__':
                 plot_grouped_3d_trajectories_by_layer(model_path, weight_store_path, status_type=ft.other_func)
             except ValueError as e:
                 print('ERROR:', e)
+            try:
+                model_path = next(seed_path.glob(f'*e{EPOCH}.tp'))
+                model = torch.load(model_path, map_location='cpu')
+                test_robustness(list(model.particles), seed_path)
+            except ValueError as e:
+                print('ERROR:', e)

         if n_seeds >= 2:
             combined_df_store_path = exp_path.parent / f'comb_train_{exp_path.stem[:-1]}n.csv'