Error Resolvement
This commit is contained in:
@ -278,22 +278,23 @@ class BaseFactory(gym.Env):
|
||||
|
||||
for key, array in state_array_dict.items():
|
||||
# Flush state array object representation to obs cube
|
||||
if self[key].is_per_agent:
|
||||
per_agent_idx = self[key].idx_by_entity(agent)
|
||||
z = 1
|
||||
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
|
||||
else:
|
||||
z = array.shape[0]
|
||||
self._obs_cube[running_idx: running_idx+z] = array
|
||||
# Define which OBS SLices cast a Shadow
|
||||
if self[key].is_blocking_light:
|
||||
for i in range(z):
|
||||
shadowing_idxs.append(running_idx + i)
|
||||
# Define which OBS SLices are effected by shadows
|
||||
if self[key].can_be_shadowed:
|
||||
for i in range(z):
|
||||
can_be_shadowed_idxs.append(running_idx + i)
|
||||
running_idx += z
|
||||
if not self[key].hide_from_obs_builder:
|
||||
if self[key].is_per_agent:
|
||||
per_agent_idx = self[key].idx_by_entity(agent)
|
||||
z = 1
|
||||
self._obs_cube[running_idx: running_idx+z] = array[per_agent_idx]
|
||||
else:
|
||||
z = array.shape[0]
|
||||
self._obs_cube[running_idx: running_idx+z] = array
|
||||
# Define which OBS SLices cast a Shadow
|
||||
if self[key].is_blocking_light:
|
||||
for i in range(z):
|
||||
shadowing_idxs.append(running_idx + i)
|
||||
# Define which OBS SLices are effected by shadows
|
||||
if self[key].can_be_shadowed:
|
||||
for i in range(z):
|
||||
can_be_shadowed_idxs.append(running_idx + i)
|
||||
running_idx += z
|
||||
|
||||
if agent_pos_is_omitted:
|
||||
state_array_dict[c.AGENT][0, agent.x, agent.y] += agent.encoding
|
||||
@ -341,10 +342,14 @@ class BaseFactory(gym.Env):
|
||||
agent.temp_light_map = light_block_map
|
||||
for obs_idx in can_be_shadowed_idxs:
|
||||
obs[obs_idx] = ((obs[obs_idx] * light_block_map) + 0.) - (1 - light_block_map) # * obs[0])
|
||||
|
||||
return obs
|
||||
else:
|
||||
return obs
|
||||
pass
|
||||
|
||||
for additional_obs in self.additional_obs_build():
|
||||
obs[running_idx:running_idx+additional_obs.shape[0]] = additional_obs
|
||||
running_idx += additional_obs.shape[0]
|
||||
|
||||
return obs
|
||||
|
||||
def get_all_tiles_with_collisions(self) -> List[Tile]:
|
||||
tiles_with_collisions = list()
|
||||
@ -479,7 +484,8 @@ class BaseFactory(gym.Env):
|
||||
summary = {f'{REC_TAC}step': self._steps}
|
||||
|
||||
if self._steps == 0:
|
||||
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()}})
|
||||
summary.update({f'{REC_TAC}{self[c.WALLS].name}': {self[c.WALLS].summarize_states()},
|
||||
'FactoryName': self.__class__.__name__})
|
||||
for entity_group in self._entities:
|
||||
if not isinstance(entity_group, WallTiles):
|
||||
summary.update({f'{REC_TAC}{entity_group.name}': entity_group.summarize_states()})
|
||||
@ -512,6 +518,10 @@ class BaseFactory(gym.Env):
|
||||
|
||||
# Functions which provide additions to functions of the base class
|
||||
# Always call super!!!!!!
|
||||
@abc.abstractmethod
|
||||
def additional_obs_build(self) -> List[np.ndarray]:
|
||||
return []
|
||||
|
||||
@abc.abstractmethod
|
||||
def do_additional_reset(self) -> None:
|
||||
pass
|
||||
|
@ -63,6 +63,9 @@ class Register:
|
||||
|
||||
|
||||
class ObjectRegister(Register):
|
||||
|
||||
hide_from_obs_builder = False
|
||||
|
||||
def __init__(self, level_shape: (int, int), *args, individual_slices=False, is_per_agent=False, **kwargs):
|
||||
super(ObjectRegister, self).__init__(*args, **kwargs)
|
||||
self.is_per_agent = is_per_agent
|
||||
@ -76,7 +79,7 @@ class ObjectRegister(Register):
|
||||
self._array = np.zeros((1, *self._level_shape))
|
||||
else:
|
||||
if self.individual_slices:
|
||||
self._array = np.concatenate((self._array, np.zeros((1, *self._level_shape))))
|
||||
self._array = np.concatenate((self._array, np.zeros((1, *self._array.shape[1:]))))
|
||||
|
||||
def summarize_states(self):
|
||||
return [val.summarize_state() for val in self.values()]
|
||||
|
Reference in New Issue
Block a user