From d29ccbbb7153ad1a29361b13642601f79d071c11 Mon Sep 17 00:00:00 2001 From: Steffen Illium Date: Tue, 11 Jan 2022 18:00:24 +0100 Subject: [PATCH] Fixed Global Positions --- environments/factory/base/base_factory.py | 4 +++- environments/factory/base/objects.py | 13 ++++++------- environments/factory/base/registers.py | 2 +- reload_agent.py | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/environments/factory/base/base_factory.py b/environments/factory/base/base_factory.py index f2ec6e0..ce051ec 100644 --- a/environments/factory/base/base_factory.py +++ b/environments/factory/base/base_factory.py @@ -697,7 +697,9 @@ class BaseFactory(gym.Env): def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]: additional_raw_observations = {} if self.obs_prop.show_global_position_info: - additional_raw_observations.update({c.GLOBAL_POSITION: self[c.GLOBAL_POSITION].by_entity(agent).as_array()}) + global_pos_obs = np.zeros(self._obs_shape) + global_pos_obs[:2, 0] = self[c.GLOBAL_POSITION].by_entity(agent).encoding + additional_raw_observations.update({c.GLOBAL_POSITION: global_pos_obs}) return additional_raw_observations @abc.abstractmethod diff --git a/environments/factory/base/objects.py b/environments/factory/base/objects.py index f3d4b5e..ec72c14 100644 --- a/environments/factory/base/objects.py +++ b/environments/factory/base/objects.py @@ -86,8 +86,8 @@ class BoundingMixin(Object): def bound_entity(self): return self._bound_entity - def __init__(self, entity_to_be_bound, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self,entity_to_be_bound, *args, **kwargs): + super(BoundingMixin, self).__init__(*args, **kwargs) assert entity_to_be_bound is not None self._bound_entity = entity_to_be_bound @@ -201,18 +201,17 @@ class PlaceHolder(Object): return "PlaceHolder" -class GlobalPosition(EnvObject, BoundingMixin): +class GlobalPosition(BoundingMixin, EnvObject): @property def encoding(self): if self._normalized: - return tuple(np.diff(self._bound_entity.pos, self._level_shape)) + return tuple(np.divide(self._bound_entity.pos, self._level_shape)) else: return self.bound_entity.pos - def __init__(self, level_shape, *args, normalized: bool = True, **kwargs): - super(GlobalPosition, self).__init__(self, *args, **kwargs) - + def __init__(self, level_shape: (int, int), *args, normalized: bool = True, **kwargs): + super(GlobalPosition, self).__init__(*args, **kwargs) self._level_shape = level_shape self._normalized = normalized diff --git a/environments/factory/base/registers.py b/environments/factory/base/registers.py index 5cf41b1..0985162 100644 --- a/environments/factory/base/registers.py +++ b/environments/factory/base/registers.py @@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister): def spawn_global_position_objects(self, agents): # Todo, change to 'from xy'-form - global_positions = [self._accepted_objects(self._shape, agent) + global_positions = [self._accepted_objects(self._shape, agent, self) for _, agent in enumerate(agents)] # noinspection PyTypeChecker self.register_additional_items(global_positions) diff --git a/reload_agent.py b/reload_agent.py index 9a2e7e7..eedd864 100644 --- a/reload_agent.py +++ b/reload_agent.py @@ -19,7 +19,7 @@ if __name__ == '__main__': seed = 67 n_agents = 1 # out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward') - out_path = Path('study_out/single_run_with_export/dirt') + out_path = Path('study_out/test/dirt') model_path = out_path with (out_path / f'env_params.json').open('r') as f: