diff --git a/environments/factory/base_factory.py b/environments/factory/base_factory.py index bb0135a..00cb173 100644 --- a/environments/factory/base_factory.py +++ b/environments/factory/base_factory.py @@ -199,8 +199,8 @@ class BaseFactory(gym.Env): x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1 y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1 obs = self._state[:, x0:x1, y0:y1] - if obs.shape[1] != self.pomdp_radius ** 2 + 1 or obs.shape[2] != self.pomdp_radius ** 2 + 1: - obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1) + if obs.shape[1] != self.pomdp_radius * 2 + 1 or obs.shape[2] != self.pomdp_radius * 2 + 1: + obs_padded = np.full((obs.shape[0], self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), 1) a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0] obs_padded[:, abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1], diff --git a/main.py b/main.py index f05fd7f..ff88543 100644 --- a/main.py +++ b/main.py @@ -89,7 +89,7 @@ if __name__ == '__main__': # exit() from stable_baselines3 import PPO, DQN, A2C - from algorithms.dqn_reg import RegDQN + from algorithms.reg_dqn import RegDQN # from sb3_contrib import QRDQN dirt_props = DirtProperties() @@ -100,10 +100,10 @@ if __name__ == '__main__': out_path = None - for modeL_type in [PPO]: # , A2C, RegDQN, DQN]: + for modeL_type in [PPO, A2C, RegDQN, DQN]: for seed in range(3): - env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400, + env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400, movement_properties=move_props, level='rooms', omit_agent_slice_in_obs=True) @@ -112,7 +112,7 @@ if __name__ == '__main__': kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {} model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs) - out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}' + out_path = Path('debug_out') / f'{modeL_type.__class__.__name__}_{time_stamp}' # identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}' identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'