mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
pomdp=None and omit agent slice now working
This commit is contained in:
parent
5668f5cb82
commit
1e87c4807f
@ -110,12 +110,13 @@ class BaseFactory(gym.Env):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_space(self):
|
def observation_space(self):
|
||||||
|
agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0
|
||||||
if self.pomdp_radius:
|
if self.pomdp_radius:
|
||||||
agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0
|
|
||||||
return spaces.Box(low=0, high=1, shape=(self._state.shape[0] - agent_slice, self.pomdp_radius * 2 + 1,
|
return spaces.Box(low=0, high=1, shape=(self._state.shape[0] - agent_slice, self.pomdp_radius * 2 + 1,
|
||||||
self.pomdp_radius * 2 + 1), dtype=np.float32)
|
self.pomdp_radius * 2 + 1), dtype=np.float32)
|
||||||
else:
|
else:
|
||||||
space = spaces.Box(low=0, high=1, shape=self._state.shape, dtype=np.float32)
|
shape = [x-agent_slice if idx == 0 else x for idx, x in enumerate(self._state.shape)]
|
||||||
|
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||||
return space
|
return space
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -193,14 +194,9 @@ class BaseFactory(gym.Env):
|
|||||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||||
obs = obs_padded
|
obs = obs_padded
|
||||||
else:
|
else:
|
||||||
assert not self.omit_agent_slice_in_obs
|
|
||||||
obs = self._state
|
obs = self._state
|
||||||
if self.omit_agent_slice_in_obs:
|
if self.omit_agent_slice_in_obs:
|
||||||
if obs.shape != (3, 5, 5):
|
|
||||||
print('Shiiiiiit')
|
|
||||||
obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]]
|
obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]]
|
||||||
if obs_new.shape != self.observation_space.shape:
|
|
||||||
print('Shiiiiiit')
|
|
||||||
return obs_new
|
return obs_new
|
||||||
else:
|
else:
|
||||||
return obs
|
return obs
|
||||||
|
6
main.py
6
main.py
@ -61,7 +61,7 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
|||||||
monitor_df = pickle.load(f)
|
monitor_df = pickle.load(f)
|
||||||
|
|
||||||
monitor_df['run'] = run
|
monitor_df['run'] = run
|
||||||
monitor_df['model'] = path.name.split('_')[0]
|
monitor_df['model'] = path.name.split('_')[1]
|
||||||
monitor_df = monitor_df.fillna(0)
|
monitor_df = monitor_df.fillna(0)
|
||||||
df_list.append(monitor_df)
|
df_list.append(monitor_df)
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ def compare_runs(run_path: Path, run_identifier: int, parameter: Union[str, List
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
# compare_runs(Path('debug_out'), 1622650432, 'step_reward')
|
# compare_runs(Path('debug_out') / 'PPO_1622800949', 1622800949, 'step_reward')
|
||||||
# exit()
|
# exit()
|
||||||
|
|
||||||
from stable_baselines3 import PPO, DQN, A2C
|
from stable_baselines3 import PPO, DQN, A2C
|
||||||
@ -103,7 +103,7 @@ if __name__ == '__main__':
|
|||||||
for coef in [0.01, 0.1, 0.25]:
|
for coef in [0.01, 0.1, 0.25]:
|
||||||
for seed in range(3):
|
for seed in range(3):
|
||||||
|
|
||||||
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=2, max_steps=400,
|
env = SimpleFactory(n_agents=1, dirt_properties=dirt_props, pomdp_radius=None, max_steps=400,
|
||||||
allow_diagonal_movement=True, allow_no_op=False, verbose=False,
|
allow_diagonal_movement=True, allow_no_op=False, verbose=False,
|
||||||
omit_agent_slice_in_obs=True)
|
omit_agent_slice_in_obs=True)
|
||||||
env.save_params(Path('debug_out', 'yaml.txt'))
|
env.save_params(Path('debug_out', 'yaml.txt'))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user