Fix POMDP observation window size: use pomdp_radius * 2 + 1 instead of pomdp_radius ** 2 + 1

This commit is contained in:
steffen-illium
2021-06-04 19:20:03 +02:00
parent d251410e0a
commit 4862407526
2 changed files with 6 additions and 6 deletions

View File

@@ -199,8 +199,8 @@ class BaseFactory(gym.Env):
x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1
y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1
obs = self._state[:, x0:x1, y0:y1]
if obs.shape[1] != self.pomdp_radius ** 2 + 1 or obs.shape[2] != self.pomdp_radius ** 2 + 1:
obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1)
if obs.shape[1] != self.pomdp_radius * 2 + 1 or obs.shape[2] != self.pomdp_radius * 2 + 1:
obs_padded = np.full((obs.shape[0], self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), 1)
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
obs_padded[:,
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],

View File

@@ -89,7 +89,7 @@ if __name__ == '__main__':
# exit()
from stable_baselines3 import PPO, DQN, A2C
from algorithms.dqn_reg import RegDQN
from algorithms.reg_dqn import RegDQN
# from sb3_contrib import QRDQN
dirt_props = DirtProperties()
@@ -100,10 +100,10 @@ if __name__ == '__main__':
out_path = None
for modeL_type in [PPO]: # , A2C, RegDQN, DQN]:
for modeL_type in [PPO, A2C, RegDQN, DQN]:
for seed in range(3):
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400,
movement_properties=move_props, level='rooms',
omit_agent_slice_in_obs=True)
@@ -112,7 +112,7 @@ if __name__ == '__main__':
kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
out_path = Path('debug_out') / f'{modeL_type.__class__.__name__}_{time_stamp}'
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'