functionalities_test.py updated
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from mixed_setting_exp import run_mixed_experiment
|
||||
from robustness_exp import run_robustness_experiment
|
||||
from self_application_exp import run_SA_experiment
|
||||
from self_train_exp import run_ST_experiment
|
||||
from soup_exp import run_soup_experiment
|
||||
from .mixed_setting_exp import run_mixed_experiment
|
||||
from .robustness_exp import run_robustness_experiment
|
||||
from .self_application_exp import run_SA_experiment
|
||||
from .self_train_exp import run_ST_experiment
|
||||
from .soup_exp import run_soup_experiment
|
@@ -5,7 +5,7 @@ from pathlib import Path
|
||||
from visualization import line_chart_fixpoints, bar_chart_fixpoints
|
||||
|
||||
|
||||
def summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, directory_name,
|
||||
def summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, directory,
|
||||
summary_pre_title):
|
||||
avg_fixpoint_counters = {
|
||||
"avg_identity_func": 0,
|
||||
@@ -36,7 +36,7 @@ def summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_
|
||||
# Plotting the summary
|
||||
source_checker = "summary"
|
||||
exp_details = f"{summary_pre_title}: {runs} runs & {epochs} epochs each."
|
||||
bar_chart_fixpoints(avg_fixpoint_counters, population_size, directory_name, net_learning_rate, exp_details,
|
||||
bar_chart_fixpoints(avg_fixpoint_counters, population_size, directory, net_learning_rate, exp_details,
|
||||
source_checker)
|
||||
|
||||
|
||||
|
@@ -71,12 +71,10 @@ class MixedSettingExperiment:
|
||||
input_data = net.input_weight_matrix()
|
||||
target_data = net.create_target_weights(input_data)
|
||||
net.self_train(1, self.log_step_size, self.net_learning_rate, input_data, target_data)
|
||||
input_data = net.input_weight_matrix()
|
||||
net.self_application(input_data, self.SA_steps, self.log_step_size)
|
||||
net.self_application(self.SA_steps, self.log_step_size)
|
||||
|
||||
elif self.train_nets == "after_SA":
|
||||
input_data = net.input_weight_matrix()
|
||||
net.self_application(input_data, self.SA_steps, self.log_step_size)
|
||||
net.self_application(self.SA_steps, self.log_step_size)
|
||||
for _ in range(self.ST_steps_between_SA):
|
||||
input_data = net.input_weight_matrix()
|
||||
target_data = net.create_target_weights(input_data)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import os.path
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -82,13 +83,13 @@ class SelfTrainExperiment:
|
||||
def run_ST_experiment(population_size, batch_size, net_input_size, net_hidden_size, net_out_size, net_learning_rate,
|
||||
epochs, runs, run_name, name_hash):
|
||||
experiments = {}
|
||||
|
||||
check_folder("self_training")
|
||||
logging_directory = Path('output') / 'self_training'
|
||||
logging_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Running the experiments
|
||||
for i in range(runs):
|
||||
ST_directory_name = f"experiments/self_training/{run_name}_run_{i}_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
|
||||
experiment_name = f"{run_name}_run_{i}_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
this_exp_directory = logging_directory / experiment_name
|
||||
ST_experiment = SelfTrainExperiment(
|
||||
population_size,
|
||||
batch_size,
|
||||
@@ -97,17 +98,19 @@ def run_ST_experiment(population_size, batch_size, net_input_size, net_hidden_si
|
||||
net_out_size,
|
||||
net_learning_rate,
|
||||
epochs,
|
||||
ST_directory_name
|
||||
this_exp_directory
|
||||
)
|
||||
pickle.dump(ST_experiment, open(f"{ST_directory_name}/full_experiment_pickle.p", "wb"))
|
||||
with (this_exp_directory / 'full_experiment_pickle.p').open('wb') as f:
|
||||
pickle.dump(ST_experiment, f)
|
||||
experiments[i] = ST_experiment
|
||||
|
||||
# Building a summary of all the runs
|
||||
directory_name = f"experiments/self_training/summary_{run_name}_{runs}_runs_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
os.mkdir(directory_name)
|
||||
summary_name = f"/summary_{run_name}_{runs}_runs_{str(population_size)}_nets_{epochs}_epochs_{str(name_hash)}"
|
||||
summary_directory_name = logging_directory / summary_name
|
||||
summary_directory_name.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
summary_pre_title = "ST"
|
||||
summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, directory_name,
|
||||
summary_fixpoint_experiment(runs, population_size, epochs, experiments, net_learning_rate, summary_directory_name,
|
||||
summary_pre_title)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user