mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
documentation obsbuilder, raycaster, logging, renderer
This commit is contained in:
parent
26a59b5c01
commit
f62afefa20
@ -16,7 +16,10 @@ class LevelParser(object):
|
||||
@property
def pomdp_d(self):
    """
    Internal Usage

    Calculates the effective diameter of the POMDP observation space.

    The diameter is derived from the radius ``self.pomdp_r`` as ``pomdp_r * 2 + 1``
    (the radius on both sides plus the centre cell).

    :return: The calculated effective diameter.
    :rtype: int
    """
    return self.pomdp_r * 2 + 1
|
||||
|
||||
|
@ -17,6 +17,9 @@ class EnvMonitor(Wrapper):
|
||||
ext = 'png'
|
||||
|
||||
def __init__(self, env, filepath: Union[str, PathLike] = None):
|
||||
"""
|
||||
EnvMonitor is a wrapper for Gymnasium environments that monitors and logs key information during interactions.
|
||||
"""
|
||||
super(EnvMonitor, self).__init__(env)
|
||||
self._filepath = filepath
|
||||
self._monitor_df = pd.DataFrame()
|
||||
@ -52,6 +55,14 @@ class EnvMonitor(Wrapper):
|
||||
return
|
||||
|
||||
def save_monitor(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
|
||||
"""
|
||||
Saves the monitoring data to a file and optionally generates plots.
|
||||
|
||||
:param filepath: The path to save the monitoring data file.
|
||||
:type filepath: Union[Path, str, None]
|
||||
:param auto_plotting_keys: Keys to use for automatic plot generation.
|
||||
:type auto_plotting_keys: Any
|
||||
"""
|
||||
filepath = Path(filepath or self._filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with filepath.open('wb') as f:
|
||||
|
@ -12,11 +12,14 @@ class EnvRecorder(Wrapper):
|
||||
def __init__(self, env, filepath: Union[str, PathLike] = None,
|
||||
episodes: Union[List[int], None] = None):
|
||||
"""
|
||||
EnvRecorder is a wrapper for OpenAI Gym environments that records state summaries during interactions.
|
||||
|
||||
Todo
|
||||
|
||||
:param env:
|
||||
:param filepath:
|
||||
:param env: The environment to record.
|
||||
:type env: gym.Env
|
||||
:param filepath: The path to save the recording data file.
|
||||
:type filepath: Union[str, PathLike]
|
||||
:param episodes: A list of episode numbers to record. If None, records all episodes.
|
||||
:type episodes: Union[List[int], None]
|
||||
"""
|
||||
super(EnvRecorder, self).__init__(env)
|
||||
self.filepath = filepath
|
||||
@ -26,6 +29,9 @@ class EnvRecorder(Wrapper):
|
||||
self._recorder_out_list = list()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Overrides the reset method to reset the environment and recording lists.
|
||||
"""
|
||||
self._curr_ep_recorder = list()
|
||||
self._recorder_out_list = list()
|
||||
self._curr_episode += 1
|
||||
@ -33,10 +39,12 @@ class EnvRecorder(Wrapper):
|
||||
|
||||
def step(self, actions):
|
||||
"""
|
||||
Todo
|
||||
Overrides the step method to record state summaries during each step.
|
||||
|
||||
:param actions:
|
||||
:return:
|
||||
:param actions: The actions taken in the environment.
|
||||
:type actions: Any
|
||||
:return: The observation, reward, done flag, and additional information.
|
||||
:rtype: Tuple
|
||||
"""
|
||||
obs_type, obs, reward, done, info = self.env.step(actions)
|
||||
if not self.episodes or self._curr_episode in self.episodes:
|
||||
@ -62,6 +70,18 @@ class EnvRecorder(Wrapper):
|
||||
save_occupation_map=False,
|
||||
save_trajectory_map=False,
|
||||
):
|
||||
"""
|
||||
Saves the recorded data to a file.
|
||||
|
||||
:param filepath: The path to save the recording data file.
|
||||
:type filepath: Union[Path, str, None]
|
||||
:param only_deltas: If True, saves only the differences between consecutive episodes.
|
||||
:type only_deltas: bool
|
||||
:param save_occupation_map: If True, saves an occupation map as a heatmap.
|
||||
:type save_occupation_map: bool
|
||||
:param save_trajectory_map: If True, saves a trajectory map.
|
||||
:type save_trajectory_map: bool
|
||||
"""
|
||||
self._finalize()
|
||||
filepath = Path(filepath or self.filepath)
|
||||
filepath.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
@ -19,10 +19,10 @@ class OBSBuilder(object):
|
||||
@property
|
||||
def pomdp_d(self):
|
||||
"""
|
||||
TODO
|
||||
Calculates the effective diameter of the POMDP observation space.
|
||||
|
||||
|
||||
:return:
|
||||
:return: The calculated effective diameter.
|
||||
:rtype: int
|
||||
"""
|
||||
if self.pomdp_r:
|
||||
return (self.pomdp_r * 2) + 1
|
||||
@ -34,10 +34,14 @@ class OBSBuilder(object):
|
||||
OBSBuilder
|
||||
==========
|
||||
|
||||
TODO
|
||||
The OBSBuilder class is responsible for constructing observations in the environment.
|
||||
|
||||
|
||||
:return:
|
||||
:param level_shape: The shape of the level or environment.
|
||||
:type level_shape: np.size
|
||||
:param state: The current game state.
|
||||
:type state: marl_factory_grid.environment.state.Gamestate
|
||||
:param pomdp_r: The POMDP radius, influencing the size of the observation space.
|
||||
:type pomdp_r: int
|
||||
"""
|
||||
self.all_obs = dict()
|
||||
self.ray_caster = dict()
|
||||
@ -55,6 +59,9 @@ class OBSBuilder(object):
|
||||
self.reset(state)
|
||||
|
||||
def reset(self, state):
|
||||
"""
|
||||
Resets temporary information and constructs an empty observation array with possible placeholders.
|
||||
"""
|
||||
# Reset temporary information
|
||||
self.curr_lightmaps = dict()
|
||||
# Construct an empty obs (array) for possible placeholders
|
||||
@ -64,6 +71,11 @@ class OBSBuilder(object):
|
||||
return True
|
||||
|
||||
def observation_space(self, state):
|
||||
"""
|
||||
Returns the observation space for a single agent or a tuple of spaces for multiple agents.
|
||||
:returns: The observation space for the agent(s).
|
||||
:rtype: gym.Space|Tuple
|
||||
"""
|
||||
from gymnasium.spaces import Tuple, Box
|
||||
self.reset(state)
|
||||
obsn = self.build_for_all(state)
|
||||
@ -74,13 +86,29 @@ class OBSBuilder(object):
|
||||
return space
|
||||
|
||||
def named_observation_space(self, state):
    """
    Resets the builder for the given state and returns the result of ``build_for_all``.

    :param state: The state that is handed to ``reset`` and then to ``build_for_all``.
    :returns: The dictionary produced by ``build_for_all`` for all agents.
    :rtype: dict
    """
    # Reset first so that build_for_all works on fresh temporary information.
    self.reset(state)
    return self.build_for_all(state)
|
||||
|
||||
def build_for_all(self, state) -> dict:
    """
    Builds observations for all agents in the environment.

    :param state: The current state; ``state[c.AGENT]`` is iterated to obtain the agents.
    :returns: A dictionary that maps each agent's name to the first element of the
              result of ``build_for_agent`` for that agent.
    :rtype: dict
    """
    # build_for_agent returns more than one value; only the first element is kept here.
    return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}
|
||||
|
||||
def build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
|
||||
"""
|
||||
Builds named observations for all agents in the environment.
|
||||
|
||||
:returns: A dictionary containing named observations for all agents.
|
||||
:rtype: dict
|
||||
"""
|
||||
named_obs_dict = {}
|
||||
for agent in state[c.AGENT]:
|
||||
obs, names = self.build_for_agent(agent, state)
|
||||
@ -88,6 +116,16 @@ class OBSBuilder(object):
|
||||
return named_obs_dict
|
||||
|
||||
def place_entity_in_observation(self, obs_array, agent, e):
|
||||
"""
|
||||
Places the encoding of an entity in the observation array relative to the agent's position.
|
||||
|
||||
:param obs_array: The observation array.
|
||||
:type obs_array: np.ndarray
|
||||
:param agent: the associated agent
|
||||
:type agent: Agent
|
||||
:param e: The entity to be placed in the observation.
|
||||
:type e: Entity
|
||||
"""
|
||||
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
|
||||
if not min([y, x]) < 0:
|
||||
try:
|
||||
@ -98,6 +136,12 @@ class OBSBuilder(object):
|
||||
pass
|
||||
|
||||
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
|
||||
"""
|
||||
Builds observations for a specific agent.
|
||||
|
||||
:returns: A tuple containing the observation array and the corresponding list of observation names
|
||||
:rtype: Tuple[np.ndarray, List[str]]
|
||||
"""
|
||||
try:
|
||||
agent_want_obs = self.obs_layers[agent.name]
|
||||
except KeyError:
|
||||
@ -193,8 +237,8 @@ class OBSBuilder(object):
|
||||
def _sort_and_name_observation_conf(self, agent):
|
||||
"""
|
||||
Builds the useable observation scheme per agent from conf.yaml.
|
||||
:param agent:
|
||||
:return:
|
||||
|
||||
:param agent: The agent for whom the observation scheme is built.
|
||||
"""
|
||||
# Fixme: no asymmetric shapes possible.
|
||||
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))
|
||||
|
@ -8,10 +8,17 @@ from numba import njit
|
||||
class RayCaster:
|
||||
def __init__(self, agent, pomdp_r, degs=360):
|
||||
"""
|
||||
TODO
|
||||
The RayCaster class enables agents in the environment to simulate field-of-view visibility,
|
||||
providing methods for calculating visible entities and outlining the field of view based on
|
||||
Bresenham's algorithm.
|
||||
|
||||
|
||||
:return:
|
||||
:param agent: The agent for which the RayCaster is initialized.
|
||||
:type agent: Agent
|
||||
:param pomdp_r: The range of the partially observable Markov decision process (POMDP).
|
||||
:type pomdp_r: int
|
||||
:param degs: The degrees of the field of view (FOV). Defaults to 360.
|
||||
:type degs: int
|
||||
:return: None
|
||||
"""
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
@ -25,6 +32,12 @@ class RayCaster:
|
||||
return f'{self.__class__.__name__}({self.agent.name})'
|
||||
|
||||
def build_ray_targets(self):
|
||||
"""
|
||||
Builds the targets for the rays based on the field of view (FOV).
|
||||
|
||||
:return: The targets for the rays.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
north = np.array([0, -1]) * self.pomdp_r
|
||||
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
|
||||
rot_M = [
|
||||
@ -36,11 +49,31 @@ class RayCaster:
|
||||
return rot_M.astype(int)
|
||||
|
||||
def ray_block_cache(self, key, callback):
    """
    Retrieves a value from the cache dictionary, computing and storing it on a miss.

    ``callback`` is only invoked when ``key`` is not yet present in ``self._cache_dict``;
    its result is stored under ``key`` and returned on this and every later call.

    :param key: The key for the cache dictionary (must be hashable).
    :type key: Hashable
    :param callback: Zero-argument callable that produces the value if it is not cached yet.
    :type callback: Callable[[], Any]
    :return: The cached or newly computed value.
    :rtype: Any
    """
    if key not in self._cache_dict:
        self._cache_dict[key] = callback()
    return self._cache_dict[key]
|
||||
|
||||
def visible_entities(self, pos_dict, reset_cache=True):
|
||||
"""
|
||||
Returns a list of visible entities based on the agent's field of view.
|
||||
|
||||
:param pos_dict: The dictionary containing positions of entities.
|
||||
:type pos_dict: dict
|
||||
:param reset_cache: Flag to reset the cache. Defaults to True.
|
||||
:type reset_cache: bool
|
||||
:return: A list of visible entities.
|
||||
:rtype: list
|
||||
"""
|
||||
visible = list()
|
||||
if reset_cache:
|
||||
self._cache_dict = dict()
|
||||
@ -71,15 +104,33 @@ class RayCaster:
|
||||
return visible
|
||||
|
||||
def get_rays(self):
    """
    Gets the rays for the agent.

    The ray targets (``self.ray_targets``) are shifted by the agent's current position and
    handed to ``bresenham_loop`` together with that position.

    :return: The result of ``bresenham_loop`` for the agent position and the shifted targets.
    :rtype: list
    """
    a_pos = self.agent.pos
    # ray_targets are offsets; adding the agent position yields the ray end points.
    outline = self.ray_targets + a_pos
    return self.bresenham_loop(a_pos, outline)
|
||||
|
||||
# todo do this once and cache the points!
|
||||
def get_fov_outline(self) -> np.ndarray:
    """
    Gets the field of view (FOV) outline.

    The outline is ``self.ray_targets`` shifted by the agent's current position.

    :return: The FOV outline.
    :rtype: np.ndarray
    """
    return self.ray_targets + self.agent.pos
|
||||
|
||||
def get_square_outline(self):
|
||||
"""
|
||||
Gets the square outline for the agent.
|
||||
|
||||
:return: The square outline.
|
||||
:rtype: list
|
||||
"""
|
||||
agent = self.agent
|
||||
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
|
||||
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
|
||||
@ -90,6 +141,16 @@ class RayCaster:
|
||||
@staticmethod
|
||||
@njit
|
||||
def bresenham_loop(a_pos, points):
|
||||
"""
|
||||
Applies Bresenham's algorithm to calculate the points between two positions.
|
||||
|
||||
:param a_pos: The starting position.
|
||||
:type a_pos: list
|
||||
:param points: The ending positions.
|
||||
:type points: list
|
||||
:return: The list of points between the starting and ending positions.
|
||||
:rtype: list
|
||||
"""
|
||||
results = []
|
||||
for end in points:
|
||||
x1, y1 = a_pos
|
||||
|
@ -34,12 +34,26 @@ class Renderer:
|
||||
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
|
||||
grid_lines: bool = True, view_radius: int = 2):
|
||||
"""
|
||||
TODO
|
||||
The Renderer class initializes and manages the rendering environment for the simulation,
|
||||
providing methods for preparing entities for display, loading assets, calculating visibility rectangles and
|
||||
rendering the entities on the screen with specified parameters.
|
||||
|
||||
|
||||
:return:
|
||||
:param lvl_shape: Tuple representing the shape of the level.
|
||||
:type lvl_shape: Tuple[int, int]
|
||||
:param lvl_padded_shape: Optional Tuple representing the padded shape of the level.
|
||||
:type lvl_padded_shape: Union[Tuple[int, int], None]
|
||||
:param cell_size: Size of each cell in pixels.
|
||||
:type cell_size: int
|
||||
:param fps: Frames per second for rendering.
|
||||
:type fps: int
|
||||
:param factor: Factor for resizing assets.
|
||||
:type factor: float
|
||||
:param grid_lines: Boolean indicating whether to display grid lines.
|
||||
:type grid_lines: bool
|
||||
:param view_radius: Radius for agent's field of view.
|
||||
:type view_radius: int
|
||||
"""
|
||||
# TODO: Customn_assets paths
|
||||
# TODO: Custom_assets paths
|
||||
self.grid_h, self.grid_w = lvl_shape
|
||||
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
|
||||
self.cell_size = cell_size
|
||||
@ -60,6 +74,9 @@ class Renderer:
|
||||
print('Loading System font with pygame.font.Font took', time.time() - now)
|
||||
|
||||
def fill_bg(self):
|
||||
"""
|
||||
Fills the background of the screen with the specified BG color.
|
||||
"""
|
||||
self.screen.fill(Renderer.BG_COLOR)
|
||||
if self.grid_lines:
|
||||
w, h = self.screen_size
|
||||
@ -69,6 +86,16 @@ class Renderer:
|
||||
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
|
||||
|
||||
def blit_params(self, entity):
|
||||
"""
|
||||
Prepares parameters for blitting an entity on the screen. Blitting refers to the process of combining or copying
|
||||
rectangular blocks of pixels from one part of a graphical buffer to another and is often used to efficiently
|
||||
update the display by copying pre-drawn or cached images onto the screen.
|
||||
|
||||
:param entity: The entity to be blitted.
|
||||
:type entity: Entity
|
||||
:return: Dictionary containing source and destination information for blitting.
|
||||
:rtype: dict
|
||||
"""
|
||||
offset_r, offset_c = (self.lvl_padded_shape[0] - self.grid_h) // 2, \
|
||||
(self.lvl_padded_shape[1] - self.grid_w) // 2
|
||||
|
||||
@ -90,12 +117,31 @@ class Renderer:
|
||||
return dict(source=img, dest=rect)
|
||||
|
||||
def load_asset(self, path, factor=1.0):
    """
    Loads and resizes an asset from the specified path.

    The image is loaded with ``pygame.image.load``, converted via ``convert_alpha`` and
    smooth-scaled to a square with an edge length of ``int(factor * self.cell_size)``.

    :param path: Path to the asset.
    :type path: str
    :param factor: Resizing factor for the asset, relative to ``self.cell_size``.
    :type factor: float
    :return: Resized asset.
    :rtype: pygame.Surface
    """
    # Edge length in pixels; int() truncates towards zero.
    s = int(factor * self.cell_size)
    asset = pygame.image.load(path).convert_alpha()
    asset = pygame.transform.smoothscale(asset, (s, s))
    return asset
|
||||
|
||||
def visibility_rects(self, bp, view):
|
||||
"""
|
||||
Calculates the visibility rectangles for an agent.
|
||||
|
||||
:param bp: Blit parameters for the agent.
|
||||
:type bp: dict
|
||||
:param view: Agent's field of view.
|
||||
:type view: np.ndarray
|
||||
:return: List of visibility rectangles.
|
||||
:rtype: List[dict]
|
||||
"""
|
||||
rects = []
|
||||
for i, j in product(range(-self.view_radius, self.view_radius+1),
|
||||
range(-self.view_radius, self.view_radius+1)):
|
||||
@ -111,6 +157,14 @@ class Renderer:
|
||||
return rects
|
||||
|
||||
def render(self, entities):
|
||||
"""
|
||||
Renders the entities on the screen.
|
||||
|
||||
:param entities: List of entities to be rendered.
|
||||
:type entities: List[Entity]
|
||||
:return: Transposed RGB observation array.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
pygame.quit()
|
||||
|
Loading…
x
Reference in New Issue
Block a user