mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-13 22:44:00 +02:00
Fix POMDP observation padding: side length is `pomdp_radius * 2 + 1`, not `pomdp_radius ** 2 + 1`
This commit is contained in:
@@ -199,8 +199,8 @@ class BaseFactory(gym.Env):
|
|||||||
x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1
|
x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1
|
||||||
y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1
|
y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1
|
||||||
obs = self._state[:, x0:x1, y0:y1]
|
obs = self._state[:, x0:x1, y0:y1]
|
||||||
if obs.shape[1] != self.pomdp_radius ** 2 + 1 or obs.shape[2] != self.pomdp_radius ** 2 + 1:
|
if obs.shape[1] != self.pomdp_radius * 2 + 1 or obs.shape[2] != self.pomdp_radius * 2 + 1:
|
||||||
obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1)
|
obs_padded = np.full((obs.shape[0], self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), 1)
|
||||||
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
||||||
obs_padded[:,
|
obs_padded[:,
|
||||||
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
||||||
|
8
main.py
8
main.py
@@ -89,7 +89,7 @@ if __name__ == '__main__':
|
|||||||
# exit()
|
# exit()
|
||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
from algorithms.dqn_reg import RegDQN
|
from algorithms.reg_dqn import RegDQN
|
||||||
# from sb3_contrib import QRDQN
|
# from sb3_contrib import QRDQN
|
||||||
|
|
||||||
dirt_props = DirtProperties()
|
dirt_props = DirtProperties()
|
||||||
@@ -100,10 +100,10 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
out_path = None
|
out_path = None
|
||||||
|
|
||||||
for modeL_type in [PPO]: # , A2C, RegDQN, DQN]:
|
for modeL_type in [PPO, A2C, RegDQN, DQN]:
|
||||||
for seed in range(3):
|
for seed in range(3):
|
||||||
|
|
||||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
|
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400,
|
||||||
movement_properties=move_props, level='rooms',
|
movement_properties=move_props, level='rooms',
|
||||||
omit_agent_slice_in_obs=True)
|
omit_agent_slice_in_obs=True)
|
||||||
|
|
||||||
@@ -112,7 +112,7 @@ if __name__ == '__main__':
|
|||||||
kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
|
kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
|
||||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
|
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
|
||||||
|
|
||||||
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
out_path = Path('debug_out') / f'{modeL_type.__class__.__name__}_{time_stamp}'
|
||||||
|
|
||||||
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
|
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
|
||||||
identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'
|
identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'
|
||||||
|
Reference in New Issue
Block a user