functionalities_test.py updated

This commit is contained in:
steffen-illium
2021-05-16 14:51:21 +02:00
parent b1472479cb
commit 36377ee27d
7 changed files with 66 additions and 97 deletions

View File

@@ -1,5 +1,5 @@
from mixed_setting_exp import run_mixed_experiment
from robustness_exp import run_robustness_experiment
from self_application_exp import run_SA_experiment
from self_train_exp import run_ST_experiment
from soup_exp import run_soup_experiment
from .mixed_setting_exp import run_mixed_experiment
from .robustness_exp import run_robustness_experiment
from .self_application_exp import run_SA_experiment
from .self_train_exp import run_ST_experiment
from .soup_exp import run_soup_experiment

View File

@@ -5,7 +5,7 @@ from pathlib import Path
from visualization import line_chart_fixpoints, bar_chart_fixpoints
def summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, directory_name,
def summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, directory,
summary_pre_title):
avg_fixpoint_counters = {
"avg_identity_func": 0,
@@ -36,7 +36,7 @@ def summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_
# Plotting the summary
source_checker = "summary"
exp_details = f"{summary_pre_title}: {runs} runs & {epochs} epochs each."
bar_chart_fixpoints(avg_fixpoint_counters, population_size, directory_name, net_learning_rate, exp_details,
bar_chart_fixpoints(avg_fixpoint_counters, population_size, directory, net_learning_rate, exp_details,
source_checker)

View File

@@ -71,12 +71,10 @@ class MixedSettingExperiment:
input_data = net.input_weight_matrix()
target_data = net.create_target_weights(input_data)
net.self_train(1, self.log_step_size, self.net_learning_rate, input_data, target_data)
input_data = net.input_weight_matrix()
net.self_application(input_data, self.SA_steps, self.log_step_size)
net.self_application(self.SA_steps, self.log_step_size)
elif self.train_nets == "after_SA":
input_data = net.input_weight_matrix()
net.self_application(input_data, self.SA_steps, self.log_step_size)
net.self_application(self.SA_steps, self.log_step_size)
for _ in range(self.ST_steps_between_SA):
input_data = net.input_weight_matrix()
target_data = net.create_target_weights(input_data)

View File

@@ -1,5 +1,6 @@
import os.path
import pickle
from pathlib import Path
from tqdm import tqdm
@@ -82,13 +83,13 @@ class SelfTrainExperiment:
def run_ST_experiment(population_size, batch_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
epochs, runs, run_name, name_hash):
experiments = {}
check_folder("self_training")
logging_directory = Path('output') / 'self_training'
logging_directory.mkdir(parents=True, exist_ok=True)
# Running the experiments
for i in range(runs):
ST_directory_name = f"experiments/self_training/{run_name}_run_{i}_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
experiment_name = f"{run_name}_run_{i}_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
this_exp_directory = logging_directory / experiment_name
ST_experiment = SelfTrainExperiment(
population_size,
batch_size,
@@ -97,17 +98,19 @@ def run_ST_experiment(population_size, batch_size, net_input_size, net_hidden_si
net_out_size,
net_learning_rate,
epochs,
ST_directory_name
this_exp_directory
)
pickle.dump(ST_experiment, open(f"{ST_directory_name}/full_experiment_pickle.p", "wb"))
with (this_exp_directory / 'full_experiment_pickle.p').open('wb') as f:
pickle.dump(ST_experiment, f)
experiments[i] = ST_experiment
# Building a summary of all the runs
directory_name = f"experiments/self_training/summary_{run_name}_{runs}_runs_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
os.mkdir(directory_name)
summary_name = f"/summary_{run_name}_{runs}_runs_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
summary_directory_name = logging_directory / summary_name
summary_directory_name.mkdir(parents=True, exist_ok=True)
summary_pre_title = "ST"
summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, directory_name,
summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, summary_directory_name,
summary_pre_title)
if __name__ == '__main__':