Merge remote-tracking branch 'origin/journal' into journal

# Conflicts:
#	journal_basins.py
This commit is contained in:
Steffen Illium 2021-09-13 16:06:03 +02:00
commit 21dd572969

View File

@ -14,6 +14,8 @@ from sklearn.metrics import mean_squared_error as MSE
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
import torch
import torch.nn.functional as F
def prng():
return random.random()
@ -259,10 +261,12 @@ if __name__ == "__main__":
df["relative_distance"] = [ (df.loc[i]["MAE_pre"] - df.loc[i]["MAE_post"])/df.loc[i]["noise"] for i in range(len(df))]
df["class"] = [ "approaching" if df.loc[i]["relative_distance"] > 0 else "distancing" if df.loc[i]["relative_distance"] < 0 else "stationary" for i in range(len(df))]
# Countplot of all fixpoint clone after training per class. Uncomment and manually adjust xticklabels if x-ax size gets too small.
ax = sns.catplot(kind="count", data=df, x="noise", hue="class", height=5.27, aspect=11.7/5.27)
# Countplot of all fixpoint clone after training per class.
ax = sns.catplot(kind="count", data=df, x="noise", hue="class", height=5.27, aspect=11.7/5.27, legend=False)
ax.set_axis_labels("Noise Levels", "Clone Fixpoints After Training Count ", fontsize=15)
ax.set_xticklabels(labels=('10e-10', '10e-9', '10e-8', '10e-7', '10e-6', '10e-5', '10e-4', '10e-3', '10e-2', '10e-1'), fontsize=15)
ax.set_xticklabels(labels=('$\mathregular{10^{-10}}$', '$\mathregular{10^{-9}}$', '$\mathregular{10^{-8}}$', '$\mathregular{10^{-7}}$', '$\mathregular{10^{-6}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-4}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-2}}$', '$\mathregular{10^{-1}}$'), fontsize=15)
plt.legend(bbox_to_anchor=(0.01, 0.85), loc=2, borderaxespad=0.)
plt.legend(fontsize='large')
plt.savefig(f"{directory}/clone_status_after_countplot_{ST_name_hash}.png")
plt.clf()
@ -274,6 +278,35 @@ if __name__ == "__main__":
ax.map(sns.boxplot, "State", "Distance", "noise", linewidth=0.8, order=["MAE_pre", "MAE_post"], whis=[0, 100])
ax.set_axis_labels("", "Manhattan Distance To Parent Weights", fontsize=15)
ax.set_xticklabels(labels=('after noise application', 'after training'), fontsize=15)
plt.ticklabel_format(style='sci', axis='x')
# plt.ticklabel_format(style='sci', axis='x')
plt.savefig(f"{directory}/before_after_distance_catplot_{ST_name_hash}.png")
plt.clf()
# Catplot of children L1 Prediction "progress" compared to parents. Computes one round of accuracy first. If net is a parent net (not a clone), then we reset weights to timestep of cloning first (from the weight history). So 5k (end) -> 2.5k training (in this experiment, so careful with len(history)/2, this might only work here!)
df_acc = pd.DataFrame(columns=["name", "noise", "l1_acc", "Network Type"])
for i in range(len(exp_list)):
noise = exp_list[i].noise
print(f"\nNoise: {noise}")
for network in exp_list[i].nets:
is_parent = "clone" not in network.name
if is_parent:
network.apply_weights(torch.tensor(network.s_train_weights_history[int(len(network.s_train_weights_history)/2)][0]))
input_data = network.input_weight_matrix()
target_data = network.create_target_weights(input_data)
predicted_values = network(input_data)
mse_loss = F.mse_loss(target_data, predicted_values).item()
l1_loss = F.l1_loss(target_data, predicted_values).item()
df_acc.loc[len(df_acc)+1] = [network.name, noise, l1_loss, "parents" if is_parent else "children"]
print("MSE:", mse_loss, "\t", "L1: ", l1_loss, "\t", network.name)
# Note: If there are outliers then showfliers=False is necessary or it will zoom way to far out. If parent and children accuracy is too far apart this plot might not work (only shows either parents or part of the children).
ax = sns.catplot(data=df_acc, y="l1_acc", x="noise", hue="Network Type", kind="box", legend=False, showfliers=False, height=5.27, aspect=11.7/5.27, sharey=False)
ax.map(plt.axhline, y=10**-6, ls='--')
ax.map(plt.axhline, y=10**-7, ls='--')
ax.set_axis_labels("Noise levels", "L1 Prediction Loss After Training", fontsize=15)
ax.set_xticklabels(labels=('$\mathregular{10^{-10}}$', '$\mathregular{10^{-9}}$', '$\mathregular{10^{-8}}$', '$\mathregular{10^{-7}}$', '$\mathregular{10^{-6}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-4}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-2}}$', '$\mathregular{10^{-1}}$'), fontsize=15)
plt.legend(bbox_to_anchor=(0.01, 0.85), loc=2, borderaxespad=0.)
plt.legend(fontsize='large')
plt.savefig(f"{directory}/parent_vs_children_accuracy_{ST_name_hash}.png")
plt.clf()