Fixed Global Positions
This commit is contained in:
parent
2a2aafa988
commit
d29ccbbb71
@ -697,7 +697,9 @@ class BaseFactory(gym.Env):
|
|||||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||||
additional_raw_observations = {}
|
additional_raw_observations = {}
|
||||||
if self.obs_prop.show_global_position_info:
|
if self.obs_prop.show_global_position_info:
|
||||||
additional_raw_observations.update({c.GLOBAL_POSITION: self[c.GLOBAL_POSITION].by_entity(agent).as_array()})
|
global_pos_obs = np.zeros(self._obs_shape)
|
||||||
|
global_pos_obs[:2, 0] = self[c.GLOBAL_POSITION].by_entity(agent).encoding
|
||||||
|
additional_raw_observations.update({c.GLOBAL_POSITION: global_pos_obs})
|
||||||
return additional_raw_observations
|
return additional_raw_observations
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
@ -86,8 +86,8 @@ class BoundingMixin(Object):
|
|||||||
def bound_entity(self):
|
def bound_entity(self):
|
||||||
return self._bound_entity
|
return self._bound_entity
|
||||||
|
|
||||||
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
def __init__(self,entity_to_be_bound, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super(BoundingMixin, self).__init__(*args, **kwargs)
|
||||||
assert entity_to_be_bound is not None
|
assert entity_to_be_bound is not None
|
||||||
self._bound_entity = entity_to_be_bound
|
self._bound_entity = entity_to_be_bound
|
||||||
|
|
||||||
@ -201,18 +201,17 @@ class PlaceHolder(Object):
|
|||||||
return "PlaceHolder"
|
return "PlaceHolder"
|
||||||
|
|
||||||
|
|
||||||
class GlobalPosition(EnvObject, BoundingMixin):
|
class GlobalPosition(BoundingMixin, EnvObject):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
if self._normalized:
|
if self._normalized:
|
||||||
return tuple(np.diff(self._bound_entity.pos, self._level_shape))
|
return tuple(np.divide(self._bound_entity.pos, self._level_shape))
|
||||||
else:
|
else:
|
||||||
return self.bound_entity.pos
|
return self.bound_entity.pos
|
||||||
|
|
||||||
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
|
def __init__(self, level_shape: (int, int), *args, normalized: bool = True, **kwargs):
|
||||||
super(GlobalPosition, self).__init__(self, *args, **kwargs)
|
super(GlobalPosition, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
self._level_shape = level_shape
|
self._level_shape = level_shape
|
||||||
self._normalized = normalized
|
self._normalized = normalized
|
||||||
|
|
||||||
|
@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister):
|
|||||||
|
|
||||||
def spawn_global_position_objects(self, agents):
|
def spawn_global_position_objects(self, agents):
|
||||||
# Todo, change to 'from xy'-form
|
# Todo, change to 'from xy'-form
|
||||||
global_positions = [self._accepted_objects(self._shape, agent)
|
global_positions = [self._accepted_objects(self._shape, agent, self)
|
||||||
for _, agent in enumerate(agents)]
|
for _, agent in enumerate(agents)]
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
self.register_additional_items(global_positions)
|
self.register_additional_items(global_positions)
|
||||||
|
@ -19,7 +19,7 @@ if __name__ == '__main__':
|
|||||||
seed = 67
|
seed = 67
|
||||||
n_agents = 1
|
n_agents = 1
|
||||||
# out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
# out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
||||||
out_path = Path('study_out/single_run_with_export/dirt')
|
out_path = Path('study_out/test/dirt')
|
||||||
model_path = out_path
|
model_path = out_path
|
||||||
|
|
||||||
with (out_path / f'env_params.json').open('r') as f:
|
with (out_path / f'env_params.json').open('r') as f:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user