MetaNetworks
This commit is contained in:
@@ -282,7 +282,7 @@ if __name__ == "__main__":
|
||||
plt.savefig(f"{directory}/before_after_distance_catplot_{ST_name_hash}.png")
|
||||
plt.clf()
|
||||
|
||||
# Catplot of children L1 Prediction "progress" compared to parents. Computes one round of accuracy first. If net is a parent net (not a clone), then we reset weights to timestep of cloning first (from the weight history). So 5k (end) -> 2.5k training (in this experiment, so careful with len(history)/2, this might only work here!)
|
||||
# Catplot of child_nets L1 Prediction "progress" compared to parents. Computes one round of accuracy first. If net is a parent net (not a clone), then we reset weights to timestep of cloning first (from the weight history). So 5k (end) -> 2.5k training (in this experiment, so careful with len(history)/2, this might only work here!)
|
||||
df_acc = pd.DataFrame(columns=["name", "noise", "l1_acc", "Network Type"])
|
||||
for i in range(len(exp_list)):
|
||||
noise = exp_list[i].noise
|
||||
@@ -297,10 +297,10 @@ if __name__ == "__main__":
|
||||
mse_loss = F.mse_loss(target_data, predicted_values).item()
|
||||
l1_loss = F.l1_loss(target_data, predicted_values).item()
|
||||
|
||||
df_acc.loc[len(df_acc)+1] = [network.name, noise, l1_loss, "parents" if is_parent else "children"]
|
||||
df_acc.loc[len(df_acc)+1] = [network.name, noise, l1_loss, "parents" if is_parent else "child_nets"]
|
||||
print("MSE:", mse_loss, "\t", "L1: ", l1_loss, "\t", network.name)
|
||||
|
||||
# Note: If there are outliers then showfliers=False is necessary or it will zoom way to far out. If parent and children accuracy is too far apart this plot might not work (only shows either parents or part of the children).
|
||||
# Note: If there are outliers then showfliers=False is necessary or it will zoom way to far out. If parent and child_nets accuracy is too far apart this plot might not work (only shows either parents or part of the child_nets).
|
||||
ax = sns.catplot(data=df_acc, y="l1_acc", x="noise", hue="Network Type", kind="box", legend=False, showfliers=False, height=5.27, aspect=11.7/5.27, sharey=False)
|
||||
ax.map(plt.axhline, y=10**-6, ls='--')
|
||||
ax.map(plt.axhline, y=10**-7, ls='--')
|
||||
|
||||
Reference in New Issue
Block a user