journal_robustness.py redone, now is sensitive to seeds and plots

This commit is contained in:
steffen-illium 2021-05-23 15:49:48 +02:00
parent 55bdd706b6
commit 5e5511caf8
2 changed files with 10 additions and 8 deletions

View File

@ -9,10 +9,7 @@ import numpy as np
from pathlib import Path from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
from tabulate import tabulate from tabulate import tabulate
from sklearn.metrics import mean_absolute_error as MAE
from sklearn.metrics import mean_squared_error as MSE
from journal_basins import mean_invariate_manhattan_distance as MIM
from functionalities_test import is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent from functionalities_test import is_identity_function, is_zero_fixpoint, test_for_fixpoints, is_divergent
from network import Net from network import Net
from torch.nn import functional as F from torch.nn import functional as F
@ -153,7 +150,11 @@ class RobustnessComparisonExperiment:
# sns.set(rc={'figure.figsize': (10, 50)}) # sns.set(rc={'figure.figsize': (10, 50)})
bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box', bx = sns.catplot(data=df[df['absolute_loss'] < 1], y='absolute_loss', x='application_step', kind='box',
col='noise_level', col_wrap=3, showfliers=False) col='noise_level', col_wrap=3, showfliers=False)
plt.show() directory = Path('output') / 'robustness'
filename = f"absolute_loss_perapplication_boxplot_grid.png"
filepath = directory / filename
plt.savefig(str(filepath))
if print_it: if print_it:
col_headers = [str(f"10e-{d}") for d in range(noise_levels)] col_headers = [str(f"10e-{d}") for d in range(noise_levels)]

View File

@ -131,10 +131,10 @@ def plot_3d(matrices_weights_history, directory: Union[str, Path], population_si
for j in range(start_log_time, len(weight_matrix_pca)): for j in range(start_log_time, len(weight_matrix_pca)):
xdata.append(weight_matrix_pca[j][0]) xdata.append(weight_matrix_pca[j][0])
ydata.append(weight_matrix_pca[j][1]) ydata.append(weight_matrix_pca[j][1])
zdata = np.arange(start_time, len(ydata)*batch_size+start_time, batch_size).tolist() zdata = np.arange(start_time, len(ydata)*batch_size+start_time, batch_size)
ax.plot3D(xdata, ydata, zdata, label=f"net {i}") ax.plot3D(xdata, ydata, zdata, label=f"net {i}")
ax.scatter(np.array(xdata), np.array(ydata), np.array(zdata), s=7) ax.scatter(np.asarray(xdata), np.asarray(ydata), zdata, s=7)
steps = mpatches.Patch(color="white", label=f"{z_axis_legend}: {len(matrices_weights_history)} steps") steps = mpatches.Patch(color="white", label=f"{z_axis_legend}: {len(matrices_weights_history)} steps")
population_size = mpatches.Patch(color="white", label=f"Population: {population_size} networks") population_size = mpatches.Patch(color="white", label=f"Population: {population_size} networks")
@ -181,7 +181,8 @@ def plot_3d_self_train(nets_array: List, exp_name: str, directory: Union[str, Pa
z_axis_legend = "epochs" z_axis_legend = "epochs"
return plot_3d(matrices_weights_history, directory, len(nets_array), z_axis_legend, exp_name, "", batch_size, plot_pca_together=plot_pca_together) return plot_3d(matrices_weights_history, directory, len(nets_array), z_axis_legend, exp_name, "", batch_size,
plot_pca_together=plot_pca_together)
def plot_3d_self_application(nets_array: List, exp_name: str, directory_name: Union[str, Path], batch_size: int) -> None: def plot_3d_self_application(nets_array: List, exp_name: str, directory_name: Union[str, Path], batch_size: int) -> None:
@ -212,7 +213,7 @@ def plot_3d_soup(nets_list, exp_name, directory: Union[str, Path]):
# will send forward the number "1" for batch size with the variable <irrelevant_batch_size>. # will send forward the number "1" for batch size with the variable <irrelevant_batch_size>.
irrelevant_batch_size = 1 irrelevant_batch_size = 1
plot_3d_self_train(nets_list, exp_name, directory, irrelevant_batch_size) plot_3d_self_train(nets_list, exp_name, directory, irrelevant_batch_size, False)
def line_chart_fixpoints(fixpoint_counters_history: list, epochs: int, ST_steps_between_SA: int, def line_chart_fixpoints(fixpoint_counters_history: list, epochs: int, ST_steps_between_SA: int,