mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 17:11:35 +02:00
documentation obsbuilder, raycaster, logging, renderer
This commit is contained in:
@ -19,10 +19,10 @@ class OBSBuilder(object):
|
||||
@property
|
||||
def pomdp_d(self):
|
||||
"""
|
||||
TODO
|
||||
Calculates the effective diameter of the POMDP observation space.
|
||||
|
||||
|
||||
:return:
|
||||
:return: The calculated effective diameter.
|
||||
:rtype: int
|
||||
"""
|
||||
if self.pomdp_r:
|
||||
return (self.pomdp_r * 2) + 1
|
||||
@ -34,10 +34,14 @@ class OBSBuilder(object):
|
||||
OBSBuilder
|
||||
==========
|
||||
|
||||
TODO
|
||||
The OBSBuilder class is responsible for constructing observations in the environment.
|
||||
|
||||
|
||||
:return:
|
||||
:param level_shape: The shape of the level or environment.
|
||||
:type level_shape: np.size
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.environment.state.Gamestate
|
||||
:param pomdp_r: The POMDP radius, influencing the size of the observation space.
|
||||
:type pomdp_r: int
|
||||
"""
|
||||
self.all_obs = dict()
|
||||
self.ray_caster = dict()
|
||||
@ -55,6 +59,9 @@ class OBSBuilder(object):
|
||||
self.reset(state)
|
||||
|
||||
def reset(self, state):
|
||||
"""
|
||||
Resets temporary information and constructs an empty observation array with possible placeholders.
|
||||
"""
|
||||
# Reset temporary information
|
||||
self.curr_lightmaps = dict()
|
||||
# Construct an empty obs (array) for possible placeholders
|
||||
@ -64,6 +71,11 @@ class OBSBuilder(object):
|
||||
return True
|
||||
|
||||
def observation_space(self, state):
|
||||
"""
|
||||
Returns the observation space for a single agent or a tuple of spaces for multiple agents.
|
||||
:returns: The observation space for the agent(s).
|
||||
:rtype: gym.Space|Tuple
|
||||
"""
|
||||
from gymnasium.spaces import Tuple, Box
|
||||
self.reset(state)
|
||||
obsn = self.build_for_all(state)
|
||||
@ -74,13 +86,29 @@ class OBSBuilder(object):
|
||||
return space
|
||||
|
||||
def named_observation_space(self, state):
|
||||
"""
|
||||
:returns: A dictionary of named observation spaces for all agents.
|
||||
:rtype: dict
|
||||
"""
|
||||
self.reset(state)
|
||||
return self.build_for_all(state)
|
||||
|
||||
def build_for_all(self, state) -> (dict, dict):
|
||||
"""
|
||||
Builds observations for all agents in the environment.
|
||||
|
||||
:returns: A dictionary of observations for all agents.
|
||||
:rtype: dict
|
||||
"""
|
||||
return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}
|
||||
|
||||
def build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
|
||||
"""
|
||||
Builds named observations for all agents in the environment.
|
||||
|
||||
:returns: A dictionary containing named observations for all agents.
|
||||
:rtype: dict
|
||||
"""
|
||||
named_obs_dict = {}
|
||||
for agent in state[c.AGENT]:
|
||||
obs, names = self.build_for_agent(agent, state)
|
||||
@ -88,6 +116,16 @@ class OBSBuilder(object):
|
||||
return named_obs_dict
|
||||
|
||||
def place_entity_in_observation(self, obs_array, agent, e):
|
||||
"""
|
||||
Places the encoding of an entity in the observation array relative to the agent's position.
|
||||
|
||||
:param obs_array: The observation array.
|
||||
:type obs_array: np.ndarray
|
||||
:param agent: the associated agent
|
||||
:type agent: Agent
|
||||
:param e: The entity to be placed in the observation.
|
||||
:type e: Entity
|
||||
"""
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
if not min([y, x]) < 0:
|
||||
try:
|
||||
@ -98,6 +136,12 @@ class OBSBuilder(object):
|
||||
pass
|
||||
|
||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||
"""
|
||||
Builds observations for a specific agent.
|
||||
|
||||
:returns: A tuple containing a list of observation names and the corresponding observation array
|
||||
:rtype: Tuple[List[str], np.ndarray]
|
||||
"""
|
||||
try:
|
||||
agent_want_obs = self.obs_layers[agent.name]
|
||||
except KeyError:
|
||||
@ -193,8 +237,8 @@ class OBSBuilder(object):
|
||||
def _sort_and_name_observation_conf(self, agent):
|
||||
"""
|
||||
Builds the useable observation scheme per agent from conf.yaml.
|
||||
:param agent:
|
||||
:return:
|
||||
|
||||
:param agent: The agent for whom the observation scheme is built.
|
||||
"""
|
||||
# Fixme: no asymetric shapes possible.
|
||||
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
|
||||
|
Reference in New Issue
Block a user