diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index 2b1c360..3331405 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -281,7 +281,7 @@ class BaseFactory(gym.Env): else: obs = self._obs_cube - if self.combin_agent_slices_in_obs and self.n_agents >= 1: + if self.combin_agent_slices_in_obs and self.n_agents > 1: agent_obs = np.sum(obs[[key for key, slice in self._slices.items() if c.AGENT.name in slice.name and (not self.omit_agent_slice_in_obs and slice.name != agent.name)]], axis=0, keepdims=True)