pomdp=None and omit agent slice now working
This commit is contained in:
@ -110,12 +110,13 @@ class BaseFactory(gym.Env):
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0
|
||||
if self.pomdp_radius:
|
||||
agent_slice = self.n_agents if self.omit_agent_slice_in_obs else 0
|
||||
return spaces.Box(low=0, high=1, shape=(self._state.shape[0] - agent_slice, self.pomdp_radius * 2 + 1,
|
||||
self.pomdp_radius * 2 + 1), dtype=np.float32)
|
||||
else:
|
||||
space = spaces.Box(low=0, high=1, shape=self._state.shape, dtype=np.float32)
|
||||
shape = [x-agent_slice if idx == 0 else x for idx, x in enumerate(self._state.shape)]
|
||||
space = spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||
return space
|
||||
|
||||
@property
|
||||
@ -193,14 +194,9 @@ class BaseFactory(gym.Env):
|
||||
abs(a_pos[1]-self.pomdp_radius):abs(a_pos[1]-self.pomdp_radius)+obs.shape[2]] = obs
|
||||
obs = obs_padded
|
||||
else:
|
||||
assert not self.omit_agent_slice_in_obs
|
||||
obs = self._state
|
||||
if self.omit_agent_slice_in_obs:
|
||||
if obs.shape != (3, 5, 5):
|
||||
print('Shiiiiiit')
|
||||
obs_new = obs[[key for key, val in self._state_slices.items() if 'agent' not in val]]
|
||||
if obs_new.shape != self.observation_space.shape:
|
||||
print('Shiiiiiit')
|
||||
return obs_new
|
||||
else:
|
||||
return obs
|
||||
|
Reference in New Issue
Block a user