DQN Monitor fixed

This commit is contained in:
steffen-illium 2021-06-02 13:36:20 +02:00
parent 017a94d6b7
commit d8e6bfc9a9
2 changed files with 5 additions and 4 deletions

View File

@@ -72,7 +72,8 @@ class MonitorCallback(BaseCallback):
self._monitor_dict[self.num_timesteps] = {key: val for key, val in info.items()
if key not in ['terminal_observation', 'episode']}
for env_idx, done in enumerate(self.locals.get('dones', [])):
for env_idx, done in list(enumerate(self.locals.get('dones', []))) + \
list(enumerate(self.locals.get('done', []))):
if done:
env_monitor_df = pd.DataFrame.from_dict(self._monitor_dict, orient='index')
self._monitor_dict = dict()

View File

@@ -41,8 +41,6 @@ def combine_runs(run_path: Union[str, PathLike]):
value_vars=columns, var_name="Measurement",
value_name="Score")
df_melted = df_melted[df_melted['Episode'] % skip_n == 0]
#df_melted['Episode'] = df_melted['Episode'] * skip_n # only needed for old version
prepare_plot(run_path / f'{run_path.name}_monitor_lineplot.png', df_melted)
print('Plotting done.')
@@ -51,12 +49,14 @@ def combine_runs(run_path: Union[str, PathLike]):
if __name__ == '__main__':
from stable_baselines3 import PPO, DQN, A2C
from algorithms.dqn_reg import RegDQN
dirt_props = DirtProperties()
time_stamp = int(time.time())
out_path = None
for modeL_type in [A2C, PPO, DQN]:
for modeL_type in [RegDQN, DQN]:
for seed in range(5):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,