mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
Fixed Global Positions
This commit is contained in:
parent
2a2aafa988
commit
d29ccbbb71
@ -697,7 +697,9 @@ class BaseFactory(gym.Env):
|
||||
def _additional_per_agent_raw_observations(self, agent) -> Dict[str, np.typing.ArrayLike]:
|
||||
additional_raw_observations = {}
|
||||
if self.obs_prop.show_global_position_info:
|
||||
additional_raw_observations.update({c.GLOBAL_POSITION: self[c.GLOBAL_POSITION].by_entity(agent).as_array()})
|
||||
global_pos_obs = np.zeros(self._obs_shape)
|
||||
global_pos_obs[:2, 0] = self[c.GLOBAL_POSITION].by_entity(agent).encoding
|
||||
additional_raw_observations.update({c.GLOBAL_POSITION: global_pos_obs})
|
||||
return additional_raw_observations
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -86,8 +86,8 @@ class BoundingMixin(Object):
|
||||
def bound_entity(self):
|
||||
return self._bound_entity
|
||||
|
||||
def __init__(self, entity_to_be_bound, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self,entity_to_be_bound, *args, **kwargs):
|
||||
super(BoundingMixin, self).__init__(*args, **kwargs)
|
||||
assert entity_to_be_bound is not None
|
||||
self._bound_entity = entity_to_be_bound
|
||||
|
||||
@ -201,18 +201,17 @@ class PlaceHolder(Object):
|
||||
return "PlaceHolder"
|
||||
|
||||
|
||||
class GlobalPosition(EnvObject, BoundingMixin):
|
||||
class GlobalPosition(BoundingMixin, EnvObject):
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
if self._normalized:
|
||||
return tuple(np.diff(self._bound_entity.pos, self._level_shape))
|
||||
return tuple(np.divide(self._bound_entity.pos, self._level_shape))
|
||||
else:
|
||||
return self.bound_entity.pos
|
||||
|
||||
def __init__(self, level_shape, *args, normalized: bool = True, **kwargs):
|
||||
super(GlobalPosition, self).__init__(self, *args, **kwargs)
|
||||
|
||||
def __init__(self, level_shape: (int, int), *args, normalized: bool = True, **kwargs):
|
||||
super(GlobalPosition, self).__init__(*args, **kwargs)
|
||||
self._level_shape = level_shape
|
||||
self._normalized = normalized
|
||||
|
||||
|
@ -288,7 +288,7 @@ class GlobalPositions(EnvObjectRegister):
|
||||
|
||||
def spawn_global_position_objects(self, agents):
|
||||
# Todo, change to 'from xy'-form
|
||||
global_positions = [self._accepted_objects(self._shape, agent)
|
||||
global_positions = [self._accepted_objects(self._shape, agent, self)
|
||||
for _, agent in enumerate(agents)]
|
||||
# noinspection PyTypeChecker
|
||||
self.register_additional_items(global_positions)
|
||||
|
@ -19,7 +19,7 @@ if __name__ == '__main__':
|
||||
seed = 67
|
||||
n_agents = 1
|
||||
# out_path = Path('study_out/e_1_new_reward/no_obs/dirt/A2C_new_reward/0_A2C_new_reward')
|
||||
out_path = Path('study_out/single_run_with_export/dirt')
|
||||
out_path = Path('study_out/test/dirt')
|
||||
model_path = out_path
|
||||
|
||||
with (out_path / f'env_params.json').open('r') as f:
|
||||
|
Loading…
x
Reference in New Issue
Block a user