From 8810955e86d41ebcbf184060472b81297b34411b Mon Sep 17 00:00:00 2001 From: romue Date: Wed, 2 Jun 2021 14:02:17 +0200 Subject: [PATCH] removed test script from main --- main.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 42dbae5..00fcb29 100644 --- a/main.py +++ b/main.py @@ -41,6 +41,8 @@ def combine_runs(run_path: Union[str, PathLike]): value_vars=columns, var_name="Measurement", value_name="Score") df_melted = df_melted[df_melted['Episode'] % skip_n == 0] + #df_melted['Episode'] = df_melted['Episode'] * skip_n # only needed for old version + prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted) print('Plotting done.') @@ -49,16 +51,12 @@ def combine_runs(run_path: Union[str, PathLike]): if __name__ == '__main__': from stable_baselines3 import PPO, DQN, A2C - from algorithms.dqn_reg import RegDQN - dirt_props = DirtProperties() time_stamp = int(time.time()) out_path = None - combine_runs(Path(__file__).parent / 'debug_out'/ 'A2C_1622571986') - exit() - for modeL_type in [RegDQN, DQN]: + for modeL_type in [A2C, PPO, DQN]: for seed in range(5): env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,