Error Resolvement

This commit is contained in:
Steffen Illium
2021-09-07 17:41:15 +02:00
parent 444ffe3f37
commit 50c0d90c77
4 changed files with 55 additions and 40 deletions

View File

@ -278,22 +278,23 @@ class BaseFactory(gym.Env):
for key, array in state_array_dict.items():
# Flush state array object representation to obs cube
if self[key].is_per_agent:
per_agent_idx = self[key].idx_by_entity(agent)
z = 1
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
else:
z = array.shape[0]
self._obs_cube[running_idx: running_idx+z] = array
# Define which OBS SLices cast a Shadow
if self[key].is_blocking_light:
for i in range(z):
shadowing_idxs.append(running_idx + i)
# Define which OBS SLices are effected by shadows
if self[key].can_be_shadowed:
for i in range(z):
can_be_shadowed_idxs.append(running_idx + i)
running_idx += z
if not self[key].hide_from_obs_builder:
if self[key].is_per_agent:
per_agent_idx = self[key].idx_by_entity(agent)
z = 1
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
else:
z = array.shape[0]
self._obs_cube[running_idx: running_idx+z] = array
# Define which OBS SLices cast a Shadow
if self[key].is_blocking_light:
for i in range(z):
shadowing_idxs.append(running_idx + i)
# Define which OBS SLices are effected by shadows
if self[key].can_be_shadowed:
for i in range(z):
can_be_shadowed_idxs.append(running_idx + i)
running_idx += z
if agent_pos_is_omitted:
state_array_dict[c.AGENT][0, agent.x, agent.y] += agent.encoding
@ -341,10 +342,14 @@ class BaseFactory(gym.Env):
agent.temp_light_map = light_block_map
for obs_idx in can_be_shadowed_idxs:
obs[obs_idx] = ((obs[obs_idx] * light_block_map) + 0.) - (1 - light_block_map) # * obs[0])
return obs
else:
return obs
pass
for additional_obs in self.additional_obs_build():
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
running_idx += additional_obs.shape[0]
return obs
def get_all_tiles_with_collisions(self) -> List[Tile]:
tiles_with_collisions = list()
@ -479,7 +484,8 @@ class BaseFactory(gym.Env):
summary = {f'{REC_TAC}step': self._steps}
if self._steps == 0:
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()}})
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()},
'FactoryName': self.__class__.__name__})
for entity_group in self._entities:
if not isinstance(entity_group, WallTiles):
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
@ -512,6 +518,10 @@ class BaseFactory(gym.Env):
# Functions which provide additions to functions of the base class
# Always call super!!!!!!
@abc.abstractmethod
def additional_obs_build(self) -> List[np.ndarray]:
return []
@abc.abstractmethod
def do_additional_reset(self) -> None:
pass

View File

@ -63,6 +63,9 @@ class Register:
class ObjectRegister(Register):
hide_from_obs_builder = False
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
super(ObjectRegister, self).__init__(*args, **kwargs)
self.is_per_agent = is_per_agent
@ -76,7 +79,7 @@ class ObjectRegister(Register):
self._array = np.zeros((1, *self._level_shape))
else:
if self.individual_slices:
self._array = np.concatenate((self._array, np.zeros((1, *self._level_shape))))
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
def summarize_states(self):
return [val.summarize_state() for val in self.values()]