Fixed Global Positions

This commit is contained in:
Steffen Illium 2022-01-11 18:00:24 +01:00
parent 2a2aafa988
commit d29ccbbb71
4 changed files with 11 additions and 10 deletions

View File

@ -697,7 +697,9 @@ class BaseFactory(gym.Env):
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
additional_raw_observations = {}
if self.obs_prop.show_global_position_info:
additional_raw_observations.update({c.GLOBAL_POSITION: self[c.GLOBAL_POSITION].by_entity(agent).as_array()})
global_pos_obs = np.zeros(self._obs_shape)
global_pos_obs[:2, 0] = self[c.GLOBAL_POSITION].by_entity(agent).encoding
additional_raw_observations.update({c.GLOBAL_POSITION: global_pos_obs})
return additional_raw_observations
@abc.abstractmethod

View File

@ -86,8 +86,8 @@ class BoundingMixin(Object):
def bound_entity(self):
return self._bound_entity
def __init__(self, entity_to_be_bound, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self,entity_to_be_bound, *args, **kwargs):
super(BoundingMixin, self).__init__(*args, **kwargs)
assert entity_to_be_bound is not None
self._bound_entity = entity_to_be_bound
@ -201,18 +201,17 @@ class PlaceHolder(Object):
return "PlaceHolder"
class GlobalPosition(EnvObject, BoundingMixin):
class GlobalPosition(BoundingMixin, EnvObject):
@property
def encoding(self):
if self._normalized:
return tuple(np.diff(self._bound_entity.pos, self._level_shape))
return tuple(np.divide(self._bound_entity.pos, self._level_shape))
else:
return self.bound_entity.pos
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
super(GlobalPosition, self).__init__(self, *args, **kwargs)
def __init__(self, level_shape: (int, int), *args, normalized: bool = True, **kwargs):
super(GlobalPosition, self).__init__(*args, **kwargs)
self._level_shape = level_shape
self._normalized = normalized

View File

@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister):
def spawn_global_position_objects(self, agents):
# Todo, change to 'from xy'-form
global_positions = [self._accepted_objects(self._shape, agent)
global_positions = [self._accepted_objects(self._shape, agent, self)
for _, agent in enumerate(agents)]
# noinspection PyTypeChecker
self.register_additional_items(global_positions)

View File

@ -19,7 +19,7 @@ if __name__ == '__main__':
seed = 67
n_agents = 1
# out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
out_path = Path('study_out/single_run_with_export/dirt')
out_path = Path('study_out/test/dirt')
model_path = out_path
with (out_path / f'env_params.json').open('r') as f: