Added parent-vs-children plot, changed x-tick labels to base10 notation.
This commit is contained in:
parent
6c1a964f31
commit
e51d7ad0b9
@ -14,6 +14,8 @@ from sklearn.metrics import mean_squared_error as MSE
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
from matplotlib import pyplot as plt
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def prng():
|
||||
return random.random()
|
||||
@ -259,10 +261,12 @@ if __name__ == "__main__":
|
||||
df["relative_distance"] = [ (df.loc[i]["MAE_pre"] - df.loc[i]["MAE_post"])/df.loc[i]["noise"] for i in range(len(df))]
|
||||
df["class"] = [ "approaching" if df.loc[i]["relative_distance"] > 0 else "distancing" if df.loc[i]["relative_distance"] < 0 else "stationary" for i in range(len(df))]
|
||||
|
||||
# Countplot of all fixpoint clone after training per class. Uncomment and manually adjust xticklabels if x-ax size gets too small.
|
||||
ax = sns.catplot(kind="count", data=df, x="noise", hue="class", height=5.27, aspect=11.7/5.27)
|
||||
# Countplot of all fixpoint clone after training per class.
|
||||
ax = sns.catplot(kind="count", data=df, x="noise", hue="class", height=5.27, aspect=11.7/5.27, legend=False)
|
||||
ax.set_axis_labels("Noise Levels", "Clone Fixpoints After Training Count ", fontsize=15)
|
||||
#ax.set_xticklabels(labels=('10e-10', '10e-9', '10e-8', '10e-7', '10e-6', '10e-5', '10e-4', '10e-3', '10e-2', '10e-1'), fontsize=15)
|
||||
ax.set_xticklabels(labels=('$\mathregular{10^{-10}}$', '$\mathregular{10^{-9}}$', '$\mathregular{10^{-8}}$', '$\mathregular{10^{-7}}$', '$\mathregular{10^{-6}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-4}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-2}}$', '$\mathregular{10^{-1}}$'), fontsize=15)
|
||||
plt.legend(bbox_to_anchor=(0.01, 0.85), loc=2, borderaxespad=0.)
|
||||
plt.legend(fontsize='large')
|
||||
plt.savefig(f"{directory}/clone_status_after_countplot_{ST_name_hash}.png")
|
||||
plt.clf()
|
||||
|
||||
@ -276,3 +280,32 @@ if __name__ == "__main__":
|
||||
ax.set_xticklabels(labels=('after noise application', 'after training'), fontsize=15)
|
||||
plt.savefig(f"{directory}/before_after_distance_catplot_{ST_name_hash}.png")
|
||||
plt.clf()
|
||||
|
||||
# Catplot of children L1 Prediction "progress" compared to parents. Computes one round of accuracy first. If net is a parent net (not a clone), then we reset weights to timestep of cloning first (from the weight history). So 5k (end) -> 2.5k training (in this experiment, so careful with len(history)/2, this might only work here!)
|
||||
df_acc = pd.DataFrame(columns=["name", "noise", "l1_acc", "Network Type"])
|
||||
for i in range(len(exp_list)):
|
||||
noise = exp_list[i].noise
|
||||
print(f"\nNoise: {noise}")
|
||||
for network in exp_list[i].nets:
|
||||
is_parent = "clone" not in network.name
|
||||
if is_parent:
|
||||
network.apply_weights(torch.tensor(network.s_train_weights_history[int(len(network.s_train_weights_history)/2)][0]))
|
||||
input_data = network.input_weight_matrix()
|
||||
target_data = network.create_target_weights(input_data)
|
||||
predicted_values = network(input_data)
|
||||
mse_loss = F.mse_loss(target_data, predicted_values).item()
|
||||
l1_loss = F.l1_loss(target_data, predicted_values).item()
|
||||
|
||||
df_acc.loc[len(df_acc)+1] = [network.name, noise, l1_loss, "parents" if is_parent else "children"]
|
||||
print("MSE:", mse_loss, "\t", "L1: ", l1_loss, "\t", network.name)
|
||||
|
||||
# Note: If there are outliers then showfliers=False is necessary or it will zoom way to far out. If parent and children accuracy is too far apart this plot might not work (only shows either parents or part of the children).
|
||||
ax = sns.catplot(data=df_acc, y="l1_acc", x="noise", hue="Network Type", kind="box", legend=False, showfliers=False, height=5.27, aspect=11.7/5.27, sharey=False)
|
||||
ax.map(plt.axhline, y=10**-6, ls='--')
|
||||
ax.map(plt.axhline, y=10**-7, ls='--')
|
||||
ax.set_axis_labels("Noise levels", "L1 Prediction Loss After Training", fontsize=15)
|
||||
ax.set_xticklabels(labels=('$\mathregular{10^{-10}}$', '$\mathregular{10^{-9}}$', '$\mathregular{10^{-8}}$', '$\mathregular{10^{-7}}$', '$\mathregular{10^{-6}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-4}}$', '$\mathregular{10^{-5}}$', '$\mathregular{10^{-2}}$', '$\mathregular{10^{-1}}$'), fontsize=15)
|
||||
plt.legend(bbox_to_anchor=(0.01, 0.85), loc=2, borderaxespad=0.)
|
||||
plt.legend(fontsize='large')
|
||||
plt.savefig(f"{directory}/parent_vs_children_accuracy_{ST_name_hash}.png")
|
||||
plt.clf()
|
Loading…
x
Reference in New Issue
Block a user