Fixed exp pickle save. Added exp results. Fixed soup plot issue.
This commit is contained in:
@@ -126,7 +126,7 @@ class SpawnExperiment:
|
||||
# self.visualize_loss()
|
||||
self.distance_matrix = distance_matrix(self.nets, print_it=False)
|
||||
self.parent_clone_distances = distance_from_parent(self.nets, print_it=False)
|
||||
self.save()
|
||||
|
||||
|
||||
def populate_environment(self):
|
||||
loop_population_size = tqdm(range(self.population_size))
|
||||
@@ -155,7 +155,7 @@ class SpawnExperiment:
|
||||
# We set parent start_time to just before this epoch ended, so plotting is zoomed in. Comment out to
|
||||
# to see full trajectory (but the clones will be very hard to see).
|
||||
# Make one target to compare distances to clones later when they have trained.
|
||||
net.start_time = self.ST_steps - 150
|
||||
net.start_time = self.ST_steps - 350
|
||||
net_input_data = net.input_weight_matrix()
|
||||
net_target_data = net.create_target_weights(net_input_data)
|
||||
|
||||
@@ -169,7 +169,7 @@ class SpawnExperiment:
|
||||
for j in range(number_clones):
|
||||
clone = Net(net.input_size, net.hidden_size, net.out_size,
|
||||
f"ST_net_{str(i)}_clone_{str(j)}", start_time=self.ST_steps)
|
||||
clone.load_state_dict(copy.deepcopy(net.state_dict()))
|
||||
clone.load_state_dict(copy.deepcopy(net.state_dict()))
|
||||
rand_noise = prng() * self.noise
|
||||
clone = self.apply_noise(clone, rand_noise)
|
||||
clone.s_train_weights_history = copy.deepcopy(net.s_train_weights_history)
|
||||
@@ -225,9 +225,6 @@ class SpawnExperiment:
|
||||
self.loss_history.append(net_loss_history)
|
||||
plot_loss(self.loss_history, self.directory)
|
||||
|
||||
def save(self):
|
||||
pickle.dump(self, open(f"{self.directory}/experiment_pickle.p", "wb"))
|
||||
print(f"\nSaved experiment to {self.directory}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -243,15 +240,15 @@ if __name__ == "__main__":
|
||||
ST_log_step_size = 10
|
||||
|
||||
# Define number of networks & their architecture
|
||||
nr_clones = 5
|
||||
ST_population_size = 2
|
||||
nr_clones = 10
|
||||
ST_population_size = 1
|
||||
ST_net_hidden_size = 2
|
||||
ST_net_learning_rate = 0.04
|
||||
ST_name_hash = random.getrandbits(32)
|
||||
|
||||
print(f"Running the Spawn experiment:")
|
||||
exp_list = []
|
||||
for noise_factor in range(2, 4):
|
||||
for noise_factor in range(2, 3):
|
||||
exp = SpawnExperiment(
|
||||
population_size=ST_population_size,
|
||||
log_step_size=ST_log_step_size,
|
||||
@@ -267,18 +264,30 @@ if __name__ == "__main__":
|
||||
)
|
||||
exp_list.append(exp)
|
||||
|
||||
# Boxplot with counts of nr_fixpoints, nr_other, nr_etc. on y-axis
|
||||
directory = Path('output') / 'spawn_basin' / f'{ST_name_hash}'
|
||||
pickle.dump(exp_list, open(f"{directory}/experiment_pickle_{ST_name_hash}.p", "wb"))
|
||||
print(f"\nSaved experiment to {directory}.")
|
||||
|
||||
# Concat all dataframes, and add columns depending on where clone weights end up after training (rel. to parent)
|
||||
df = pd.concat([exp.df for exp in exp_list])
|
||||
sns.countplot(data=df, x="noise", hue="status_post")
|
||||
plt.savefig(f"output/spawn_basin/{ST_name_hash}/fixpoint_status_countplot.png")
|
||||
df = df.dropna().reset_index()
|
||||
df["relative_distance"] = [ (df.loc[i]["MAE_pre"] - df.loc[i]["MAE_post"])/df.loc[i]["noise"] for i in range(len(df))]
|
||||
df["class"] = [ "approaching" if df.loc[i]["relative_distance"] > 0 else "distancing" if df.loc[i]["relative_distance"] < 0 else "stationary" for i in range(len(df))]
|
||||
|
||||
# Catplot (either kind="point" or "box") that shows before-after training distances to parent
|
||||
mlt = df[["MIM_pre", "MIM_post", "noise"]].melt("noise", var_name="time", value_name='Average Distance')
|
||||
sns.catplot(data=mlt, x="time", y="Average Distance", col="noise", kind="point", col_wrap=5, sharey=False)
|
||||
plt.savefig(f"output/spawn_basin/{ST_name_hash}/clone_distance_catplot.png")
|
||||
|
||||
mlt = df.melt(id_vars=["name", "noise"], value_vars=["MAE_pre", "MAE_post"], var_name="State", value_name="Distance")
|
||||
ax = sns.catplot(data=mlt, x="State", y="Distance", col="noise", hue="name", kind="point", sharey=False, palette="Greens", legend=False)
|
||||
ax.map(sns.boxplot, "State", "Distance", "noise", linewidth=0.8, order=["MAE_pre", "MAE_post"])
|
||||
plt.savefig(f"output/spawn_basin/{ST_name_hash}/before_after_distance_catplot.png")
|
||||
# Countplot of all fixpoint clone after training per class. Uncomment and manually adjust xticklabels if x-ax size gets too small.
|
||||
ax = sns.catplot(kind="count", data=df, x="noise", hue="class", height=5.27, aspect=11.7/5.27)
|
||||
ax.set_axis_labels("Noise Levels", "Clone Fixpoints After Training Count ", fontsize=15)
|
||||
#ax.set_xticklabels(labels=('10e-10', '10e-9', '10e-8', '10e-7', '10e-6', '10e-5', '10e-4', '10e-3', '10e-2', '10e-1'), fontsize=15)
|
||||
plt.savefig(f"{directory}/clone_status_after_countplot_{ST_name_hash}.png")
|
||||
plt.clf()
|
||||
|
||||
# Catplot of before-after comparison of the clone's weights. Colors links depending on class (approaching, distancing, stationary (i.e., MAE=0)). Blue, orange and green are based on countplot above, should be save for colorblindness (see https://gist.github.com/mwaskom/b35f6ebc2d4b340b4f64a4e28e778486)-
|
||||
mlt = df.melt(id_vars=["name", "noise", "class"], value_vars=["MAE_pre", "MAE_post"], var_name="State", value_name="Distance")
|
||||
P = ["blue" if mlt.loc[i]["class"] == "approaching" else "orange" if mlt.loc[i]["class"] == "distancing" else "green" for i in range(len(mlt))]
|
||||
P = sns.color_palette(P, as_cmap=False)
|
||||
ax = sns.catplot(data=mlt, x="State", y="Distance", col="noise", hue="name", kind="point", palette=P, col_wrap=min(5, len(exp_list)), sharey=False, legend=False)
|
||||
ax.map(sns.boxplot, "State", "Distance", "noise", linewidth=0.8, order=["MAE_pre", "MAE_post"], whis=[0, 100])
|
||||
ax.set_axis_labels("", "Manhattan Distance To Parent Weights", fontsize=15)
|
||||
ax.set_xticklabels(labels=('after noise application', 'after training'), fontsize=15)
|
||||
plt.savefig(f"{directory}/before_after_distance_catplot_{ST_name_hash}.png")
|
||||
plt.clf()
|
||||
|
||||
Reference in New Issue
Block a user