mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-09-13 22:44:00 +02:00
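Corrected an error in the POMDP observation padding: the window side length is pomdp_radius * 2 + 1, not pomdp_radius ** 2 + 1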
This commit is contained in:
@@ -199,8 +199,8 @@ class BaseFactory(gym.Env):
|
||||
x0, x1 = max(0, pos[0] - self.pomdp_radius), pos[0] + self.pomdp_radius + 1
|
||||
y0, y1 = max(0, pos[1] - self.pomdp_radius), pos[1] + self.pomdp_radius + 1
|
||||
obs = self._state[:, x0:x1, y0:y1]
|
||||
if obs.shape[1] != self.pomdp_radius ** 2 + 1 or obs.shape[2] != self.pomdp_radius ** 2 + 1:
|
||||
obs_padded = np.full((obs.shape[0], self.pomdp_radius ** 2 + 1, self.pomdp_radius ** 2 + 1), 1)
|
||||
if obs.shape[1] != self.pomdp_radius * 2 + 1 or obs.shape[2] != self.pomdp_radius * 2 + 1:
|
||||
obs_padded = np.full((obs.shape[0], self.pomdp_radius * 2 + 1, self.pomdp_radius * 2 + 1), 1)
|
||||
a_pos = np.argwhere(obs[h.AGENT_START_IDX] == h.IS_OCCUPIED_CELL)[0]
|
||||
obs_padded[:,
|
||||
abs(a_pos[0]-self.pomdp_radius):abs(a_pos[0]-self.pomdp_radius)+obs.shape[1],
|
||||
|
8
main.py
8
main.py
@@ -89,7 +89,7 @@ if __name__ == '__main__':
|
||||
# exit()
|
||||
|
||||
from stable_baselines3 import PPO, DQN, A2C
|
||||
from algorithms.dqn_reg import RegDQN
|
||||
from algorithms.reg_dqn import RegDQN
|
||||
# from sb3_contrib import QRDQN
|
||||
|
||||
dirt_props = DirtProperties()
|
||||
@@ -100,10 +100,10 @@ if __name__ == '__main__':
|
||||
|
||||
out_path = None
|
||||
|
||||
for modeL_type in [PPO]: # , A2C, RegDQN, DQN]:
|
||||
for modeL_type in [PPO, A2C, RegDQN, DQN]:
|
||||
for seed in range(3):
|
||||
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
|
||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=3, max_steps=400,
|
||||
movement_properties=move_props, level='rooms',
|
||||
omit_agent_slice_in_obs=True)
|
||||
|
||||
@@ -112,7 +112,7 @@ if __name__ == '__main__':
|
||||
kwargs = dict(ent_coef=0.01) if isinstance(modeL_type, (PPO, A2C)) else {}
|
||||
model = modeL_type("MlpPolicy", env, verbose=1, seed=seed, device='cpu', **kwargs)
|
||||
|
||||
out_path = Path('debug_out') / f'{model.__class__.__name__}_{time_stamp}'
|
||||
out_path = Path('debug_out') / f'{modeL_type.__class__.__name__}_{time_stamp}'
|
||||
|
||||
# identifier = f'{seed}_{model.__class__.__name__}_{time_stamp}'
|
||||
identifier = f'{seed}_{modeL_type.__class__.__name__}_{time_stamp}'
|
||||
|
Reference in New Issue
Block a user