Added parent-vs-children plot, changed x-tick labels to base10 notation.

This commit is contained in:
Maximilian Zorn 2021-09-07 18:15:06 +02:00
parent 6c1a964f31
commit e51d7ad0b9

View File

@ -14,6 +14,8 @@ from sklearn.metrics import mean_squared_error as MSE
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
import torch
import torch.nn.functional as F
def prng():
return random.random()
@ -259,10 +261,12 @@ if __name__ == "__main__":
df["relative_distance"] = [ (df.loc[i]["MAE_pre"] - df.loc[i]["MAE_post"])/df.loc[i]["noise"] for i in range(len(df))]
df["class"] = [ "approaching" if df.loc[i]["relative_distance"] > 0 else "distancing" if df.loc[i]["relative_distance"] < 0 else "stationary" for i in range(len(df))]
# Countplot of all fixpoint clone after training per class. Uncomment and manually adjust xticklabels if x-ax size gets too small.
ax = sns.catplot(kind="count", data=df, x="noise", hue="class", height=5.27, aspect=11.7/5.27)
# Countplot of all fixpoint clone after training per class.
ax = sns.catplot(kind="count", data=df, x="noise", hue="class", height=5.27, aspect=11.7/5.27, legend=False)
ax.set_axis_labels("Noise Levels", "Clone Fixpoints After Training Count ", fontsize=15)
#ax.set_xticklabels(labels=('10e-10', '10e-9', '10e-8', '10e-7', '10e-6', '10e-5', '10e-4', '10e-3', '10e-2', '10e-1'), fontsize=15)
ax.set_xticklabels(labels=('$\mathregular{10^{-10}}$', '$\mathregular{10^{-9}}$', '$\mathregular{10^{-8}}$', '$\mathregular{10^{-7}}$', '$\mathregular{10^{-6}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-4}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-2}}$', '$\mathregular{10^{-1}}$'), fontsize=15)
plt.legend(bbox_to_anchor=(0.01, 0.85), loc=2, borderaxespad=0.)
plt.legend(fontsize='large')
plt.savefig(f"{directory}/clone_status_after_countplot_{ST_name_hash}.png")
plt.clf()
@ -276,3 +280,32 @@ if __name__ == "__main__":
ax.set_xticklabels(labels=('after noise application', 'after training'), fontsize=15)
plt.savefig(f"{directory}/before_after_distance_catplot_{ST_name_hash}.png")
plt.clf()
# Catplot of children L1 Prediction "progress" compared to parents. Computes one round of accuracy first. If net is a parent net (not a clone), then we reset weights to timestep of cloning first (from the weight history). So 5k (end) -> 2.5k training (in this experiment, so careful with len(history)/2, this might only work here!)
df_acc = pd.DataFrame(columns=["name", "noise", "l1_acc", "Network Type"])
for i in range(len(exp_list)):
noise = exp_list[i].noise
print(f"\nNoise: {noise}")
for network in exp_list[i].nets:
is_parent = "clone" not in network.name
if is_parent:
network.apply_weights(torch.tensor(network.s_train_weights_history[int(len(network.s_train_weights_history)/2)][0]))
input_data = network.input_weight_matrix()
target_data = network.create_target_weights(input_data)
predicted_values = network(input_data)
mse_loss = F.mse_loss(target_data, predicted_values).item()
l1_loss = F.l1_loss(target_data, predicted_values).item()
df_acc.loc[len(df_acc)+1] = [network.name, noise, l1_loss, "parents" if is_parent else "children"]
print("MSE:", mse_loss, "\t", "L1: ", l1_loss, "\t", network.name)
# Note: If there are outliers then showfliers=False is necessary or it will zoom way to far out. If parent and children accuracy is too far apart this plot might not work (only shows either parents or part of the children).
ax = sns.catplot(data=df_acc, y="l1_acc", x="noise", hue="Network Type", kind="box", legend=False, showfliers=False, height=5.27, aspect=11.7/5.27, sharey=False)
ax.map(plt.axhline, y=10**-6, ls='--')
ax.map(plt.axhline, y=10**-7, ls='--')
ax.set_axis_labels("Noise levels", "L1 Prediction Loss After Training", fontsize=15)
ax.set_xticklabels(labels=('$\mathregular{10^{-10}}$', '$\mathregular{10^{-9}}$', '$\mathregular{10^{-8}}$', '$\mathregular{10^{-7}}$', '$\mathregular{10^{-6}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-4}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-2}}$', '$\mathregular{10^{-1}}$'), fontsize=15)
plt.legend(bbox_to_anchor=(0.01, 0.85), loc=2, borderaxespad=0.)
plt.legend(fontsize='large')
plt.savefig(f"{directory}/parent_vs_children_accuracy_{ST_name_hash}.png")
plt.clf()