From 1e87c4807f39ecc7f629ba32da032678e23aa534 Mon Sep 17 00:00:00 2001 From: steffen-illium Date: Fri, 4 Jun 2021 15:42:31 +0200 Subject: [PATCH] pomdp=None and omit agent slice now working --- environments/factory/base_factory.py | 10 +++------- main.py | 6 +++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index 4954c69..850f794 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -110,12 +110,13 @@ class BaseFactory(gym.Env): @property def observation_space(self): + agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0 if self.pomdp_radius: - agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0 return spaces.Box(low=0, high=1, shape=(self._state.shape[0] - agent_slice, self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), dtype=np.float32) else: - space = spaces.Box(low=0, high=1, shape=self._state.shape, dtype=np.float32) + shape = [x-agent_slice if idx == 0 else x for idx, x in enumerate(self._state.shape)] + space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) return space @property @@ -193,14 +194,9 @@ class BaseFactory(gym.Env): abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs obs = obs_padded else: - assert not self.omit_agent_slice_in_obs obs = self._state if self.omit_agent_slice_in_obs: - if obs.shape != (3, 5, 5): - print('Shiiiiiit') obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]] - if obs_new.shape != self.observation_space.shape: - print('Shiiiiiit') return obs_new else: return obs diff --git a/main.py b/main.py index 87d808b..89c9781 100644 --- a/main.py +++ b/main.py @@ -61,7 +61,7 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List monitor_df = pickle.load(f) monitor_df['run'] = run - monitor_df['model'] = path.name.split('_')[0] + monitor_df['model'] = path.name.split('_')[1] monitor_df = monitor_df.fillna(0) df_list.append(monitor_df) @@ -86,7 +86,7 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List if __name__ == '__main__': - # compare_runs(Path('debug_out'), 1622650432, 'step_reward') + # compare_runs(Path('debug_out') / 'PPO_1622800949', 1622800949, 'step_reward') # exit() from stable_baselines3 import PPO, DQN, A2C @@ -103,7 +103,7 @@ if __name__ == '__main__': for coef in [0.01, 0.1, 0.25]: for seed in range(3): - env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400, + env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400, allow_diagonal_movement=True, allow_no_op=False, verbose=False, omit_agent_slice_in_obs=True) env.save_params(Path('debug_out', 'yaml.txt'))