Debugged Item Factory

This commit is contained in:
Steffen Illium
2021-09-08 11:06:47 +02:00
parent 50c0d90c77
commit b09055d95d
6 changed files with 91 additions and 29 deletions

View File

@ -128,7 +128,7 @@ class BaseFactory(gym.Env):
parsed_doors = h.one_hot_level(parsed_level, c.DOOR)
if np.any(parsed_doors):
door_tiles = [floor.by_pos(pos) for pos in np.argwhere(parsed_doors == c.OCCUPIED_CELL.value)]
doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor, is_blocking_light=True)
doors = Doors.from_tiles(door_tiles, self._level_shape, context=floor)
entities.update({c.DOORS: doors})
# Actions
@ -137,7 +137,8 @@ class BaseFactory(gym.Env):
self._actions.register_additional_items(additional_actions)
# Agents
agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape)
agents = Agents.from_tiles(floor.empty_tiles[:self.n_agents], self._level_shape,
individual_slices=not self.combin_agent_obs)
entities.update({c.AGENT: agents})
# All entities
@ -152,10 +153,12 @@ class BaseFactory(gym.Env):
return self._entities
def _init_obs_cube(self):
arrays = self._entities.arrays
arrays = self._entities.observable_arrays
if self.omit_agent_in_obs and self.n_agents == 1:
del arrays[c.AGENT]
elif self.omit_agent_in_obs:
arrays[c.AGENT] = np.delete(arrays[c.AGENT], 0, axis=0)
obs_cube_z = sum([a.shape[0] if not self[key].is_per_agent else 1 for key, a in arrays.items()])
self._obs_cube = np.zeros((obs_cube_z, *self._level_shape), dtype=np.float32)
@ -257,7 +260,7 @@ class BaseFactory(gym.Env):
return c.NOT_VALID
def _get_observations(self) -> np.ndarray:
state_array_dict = self._entities.arrays
state_array_dict = self._entities.obs_arrays
if self.n_agents == 1:
obs = self._build_per_agent_obs(self[c.AGENT][0], state_array_dict)
elif self.n_agents >= 2:
@ -268,11 +271,14 @@ class BaseFactory(gym.Env):
def _build_per_agent_obs(self, agent: Agent, state_array_dict) -> np.ndarray:
agent_pos_is_omitted = False
agent_omit_idx = None
if self.omit_agent_in_obs and self.n_agents == 1:
del state_array_dict[c.AGENT]
elif self.omit_agent_in_obs and self.combin_agent_obs and self.n_agents > 1:
state_array_dict[c.AGENT][0, agent.x, agent.y] -= agent.encoding
agent_pos_is_omitted = True
elif self.omit_agent_in_obs and not self.combin_agent_obs and self.n_agents > 1:
agent_omit_idx = next((i for i, a in enumerate(self[c.AGENT]) if a == agent))
running_idx, shadowing_idxs, can_be_shadowed_idxs = 0, [], []
@ -284,8 +290,14 @@ class BaseFactory(gym.Env):
z = 1
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
else:
z = array.shape[0]
self._obs_cube[running_idx: running_idx+z] = array
if key == c.AGENT and agent_omit_idx is not None:
z = array.shape[0] - 1
for array_idx in range(array.shape[0]):
self._obs_cube[running_idx: running_idx+z] = array[[x for x in range(array.shape[0])
if x != agent_omit_idx]]
else:
z = array.shape[0]
self._obs_cube[running_idx: running_idx+z] = array
# Define which OBS SLices cast a Shadow
if self[key].is_blocking_light:
for i in range(z):
@ -345,9 +357,13 @@ class BaseFactory(gym.Env):
else:
pass
# Additional Observation:
for additional_obs in self.additional_obs_build():
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
running_idx += additional_obs.shape[0]
for additional_per_agent_obs in self.additional_per_agent_obs_build(agent):
obs[running_idx:running_idx + additional_per_agent_obs.shape[0]] = additional_per_agent_obs
running_idx += additional_per_agent_obs.shape[0]
return obs
@ -522,6 +538,10 @@ class BaseFactory(gym.Env):
def additional_obs_build(self) -> List[np.ndarray]:
return []
@abc.abstractmethod
def additional_per_agent_obs_build(self, agent) -> List[np.ndarray]:
return []
@abc.abstractmethod
def do_additional_reset(self) -> None:
pass