Added plot variations for basin exp.

This commit is contained in:
Maximilian Zorn
2021-05-27 16:02:41 +02:00
parent 5e5511caf8
commit 32ebb729e8
2 changed files with 70 additions and 20 deletions

View File

@ -78,3 +78,18 @@ def test_for_fixpoints(fixpoint_counter: Dict, nets: List, id_functions=None):
def changing_rate(x_new, x_old): def changing_rate(x_new, x_old):
return x_new - x_old return x_new - x_old
def test_status(net: Net) -> Net:
if is_divergent(net):
net.is_fixpoint = "divergent"
elif is_identity_function(net): # is default value
net.is_fixpoint = "identity_func"
elif is_zero_fixpoint(net):
net.is_fixpoint = "fix_zero"
elif is_secondary_fixpoint(net):
net.is_fixpoint = "fix_sec"
else:
net.is_fixpoint = "other_func"
return net

View File

@ -1,18 +1,21 @@
import os import os
from pathlib import Path from pathlib import Path
import pickle import pickle
from torch import mean
from tqdm import tqdm from tqdm import tqdm
import random import random
import copy import copy
from functionalities_test import is_identity_function from functionalities_test import is_identity_function, test_status
from network import Net from network import Net
from visualization import plot_3d_self_train, plot_loss from visualization import plot_3d_self_train, plot_loss
import numpy as np import numpy as np
from tabulate import tabulate from tabulate import tabulate
from sklearn.metrics import mean_absolute_error as MAE from sklearn.metrics import mean_absolute_error as MAE
from sklearn.metrics import mean_squared_error as MSE from sklearn.metrics import mean_squared_error as MSE
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
def prng(): def prng():
return random.random() return random.random()
@ -120,8 +123,8 @@ class SpawnExperiment:
self.spawn_and_continue() self.spawn_and_continue()
self.weights_evolution_3d_experiment() self.weights_evolution_3d_experiment()
# self.visualize_loss() # self.visualize_loss()
self.distance_matrix = distance_matrix(self.nets) self.distance_matrix = distance_matrix(self.nets, print_it=False)
self.parent_clone_distances = distance_from_parent(self.nets) self.parent_clone_distances = distance_from_parent(self.nets, print_it=False)
self.save() self.save()
@ -136,13 +139,13 @@ class SpawnExperiment:
for _ in range(self.ST_steps): for _ in range(self.ST_steps):
net.self_train(1, self.log_step_size, self.net_learning_rate) net.self_train(1, self.log_step_size, self.net_learning_rate)
# print(f"\nLast weight matrix (epoch: {self.epochs}):\n
# {net.input_weight_matrix()}\nLossHistory: {net.loss_history[-10:]}")
self.nets.append(net) self.nets.append(net)
def spawn_and_continue(self, number_clones: int = None): def spawn_and_continue(self, number_clones: int = None):
number_clones = number_clones or self.nr_clones number_clones = number_clones or self.nr_clones
df = pd.DataFrame(columns=['parent', 'MAE_pre','MAE_post', 'MSE_pre', 'MSE_post', 'MIM_pre', 'MIM_post', 'noise', 'status_post'])
# For every initial net {i} after populating (that is fixpoint after first epoch); # For every initial net {i} after populating (that is fixpoint after first epoch);
for i in range(self.population_size): for i in range(self.population_size):
net = self.nets[i] net = self.nets[i]
@ -168,26 +171,46 @@ class SpawnExperiment:
clone = self.apply_noise(clone, rand_noise) clone = self.apply_noise(clone, rand_noise)
clone.s_train_weights_history = copy.deepcopy(net.s_train_weights_history) clone.s_train_weights_history = copy.deepcopy(net.s_train_weights_history)
clone.number_trained = copy.deepcopy(net.number_trained) clone.number_trained = copy.deepcopy(net.number_trained)
# Pre Training distances (after noise application of course)
clone_pre_weights = clone.create_target_weights(clone.input_weight_matrix())
MAE_pre = MAE(net_target_data, clone_pre_weights)
MSE_pre = MSE(net_target_data, clone_pre_weights)
MIM_pre = mean_invariate_manhattan_distance(net_target_data, clone_pre_weights)
# Then finish training each clone {j} (for remaining epoch-1 * ST_steps) # Then finish training each clone {j} (for remaining epoch-1 * ST_steps) ..
# and add to nets for plotting if they are fixpoints themselves;
for _ in range(self.epochs - 1): for _ in range(self.epochs - 1):
for _ in range(self.ST_steps): for _ in range(self.ST_steps):
clone.self_train(1, self.log_step_size, self.net_learning_rate) clone.self_train(1, self.log_step_size, self.net_learning_rate)
# Post Training distances for comparison
clone_post_weights = clone.create_target_weights(clone.input_weight_matrix())
MAE_post = MAE(net_target_data, clone_post_weights)
MSE_post = MSE(net_target_data, clone_post_weights)
MIM_post = mean_invariate_manhattan_distance(net_target_data, clone_post_weights)
# .. log to data-frame and add to nets for 3d plotting if they are fixpoints themselves.
test_status(clone)
if is_identity_function(clone): if is_identity_function(clone):
input_data = clone.input_weight_matrix() print(f"Clone {j} (of net_{i}) is fixpoint."
target_data = clone.create_target_weights(input_data) f"\nMSE({i},{j}): {MSE_post}"
print(f"Clone {j} (of net_{i}) is fixpoint. \nMSE(j,i): " f"\nMAE({i},{j}): {MAE_post}"
f"{MSE(net_target_data, target_data)}, \nMAE(j,i): {MAE(net_target_data, target_data)}\n") f"\nMIM({i},{j}): {MIM_post}\n")
self.nets.append(clone) self.nets.append(clone)
df.loc[clone.name] = [net.name, MAE_pre, MAE_post, MSE_pre, MSE_post, MIM_pre, MIM_post, self.noise, clone.is_fixpoint]
# Finally take parent net {i} and finish it's training for comparison to clone development. # Finally take parent net {i} and finish it's training for comparison to clone development.
for _ in range(self.epochs - 1): for _ in range(self.epochs - 1):
for _ in range(self.ST_steps): for _ in range(self.ST_steps):
net.self_train(1, self.log_step_size, self.net_learning_rate) net.self_train(1, self.log_step_size, self.net_learning_rate)
net_weights_after = net.create_target_weights(net.input_weight_matrix())
print(f"Parent net's distance to original position."
f"\nMSE(OG,new): {MAE(net_target_data, net_weights_after)}"
f"\nMAE(OG,new): {MSE(net_target_data, net_weights_after)}"
f"\nMIM(OG,new): {mean_invariate_manhattan_distance(net_target_data, net_weights_after)}\n")
else: self.df = df
print("No fixpoints found.")
def weights_evolution_3d_experiment(self): def weights_evolution_3d_experiment(self):
exp_name = f"ST_{str(len(self.nets))}_nets_3d_weights_PCA" exp_name = f"ST_{str(len(self.nets))}_nets_3d_weights_PCA"
@ -217,15 +240,16 @@ if __name__ == "__main__":
ST_log_step_size = 10 ST_log_step_size = 10
# Define number of networks & their architecture # Define number of networks & their architecture
nr_clones = 10 nr_clones = 5
ST_population_size = 3 ST_population_size = 1
ST_net_hidden_size = 2 ST_net_hidden_size = 2
ST_net_learning_rate = 0.04 ST_net_learning_rate = 0.04
ST_name_hash = random.getrandbits(32) ST_name_hash = random.getrandbits(32)
print(f"Running the Spawn experiment:") print(f"Running the Spawn experiment:")
for noise_factor in [1]: exp_list = []
SpawnExperiment( for noise_factor in range(2,5):
exp = SpawnExperiment(
population_size=ST_population_size, population_size=ST_population_size,
log_step_size=ST_log_step_size, log_step_size=ST_log_step_size,
net_input_size=NET_INPUT_SIZE, net_input_size=NET_INPUT_SIZE,
@ -236,5 +260,16 @@ if __name__ == "__main__":
st_steps=ST_steps, st_steps=ST_steps,
nr_clones=nr_clones, nr_clones=nr_clones,
noise=pow(10, -noise_factor), noise=pow(10, -noise_factor),
directory=Path('output') / 'spawn_basin' / f'{ST_name_hash}_10e-{noise_factor}' directory=Path('output') / 'spawn_basin' / f'{ST_name_hash}' / f'10e-{noise_factor}'
) )
exp_list.append(exp)
# Boxplot with counts of nr_fixpoints, nr_other, nr_etc. on y-axis
df = pd.concat([exp.df for exp in exp_list])
sns.countplot(data=df, x="noise", hue="status_post")
plt.savefig(f"output/spawn_basin/{ST_name_hash}/fixpoint_status_countplot.png")
# Catplot (either kind="point" or "box") that shows before-after training distances to parent
mlt = df[["MIM_pre", "MIM_post", "noise"]].melt("noise", var_name="time", value_name='Average Distance')
sns.catplot(data=mlt, x="time", y="Average Distance", col="noise", kind="point", col_wrap=5, sharey=False)
plt.savefig(f"output/spawn_basin/{ST_name_hash}/clone_distance_catplot.png")