documentation obsbuilder, raycaster, logging, renderer

This commit is contained in:
Chanumask 2024-01-31 15:05:03 +01:00
parent 26a59b5c01
commit f62afefa20
6 changed files with 216 additions and 23 deletions

View File

@ -16,7 +16,10 @@ class LevelParser(object):
@property
def pomdp_d(self):
"""
Internal Usage
Calculates the effective diameter of the POMDP observation space.
:return: The calculated effective diameter.
:rtype: int
"""
return self.pomdp_r * 2 + 1

View File

@ -17,6 +17,9 @@ class EnvMonitor(Wrapper):
ext = 'png'
def __init__(self, env, filepath: Union[str, PathLike] = None):
"""
EnvMonitor is a wrapper for Gymnasium environments that monitors and logs key information during interactions.
"""
super(EnvMonitor, self).__init__(env)
self._filepath = filepath
self._monitor_df = pd.DataFrame()
@ -52,6 +55,14 @@ class EnvMonitor(Wrapper):
return
def save_monitor(self, filepath: Union[Path, str, None] = None, auto_plotting_keys=None):
"""
Saves the monitoring data to a file and optionally generates plots.
:param filepath: The path to save the monitoring data file.
:type filepath: Union[Path, str, None]
:param auto_plotting_keys: Keys to use for automatic plot generation.
:type auto_plotting_keys: Any
"""
filepath = Path(filepath or self._filepath)
filepath.parent.mkdir(exist_ok=True, parents=True)
with filepath.open('wb') as f:

View File

@ -12,11 +12,14 @@ class EnvRecorder(Wrapper):
def __init__(self, env, filepath: Union[str, PathLike] = None,
episodes: Union[List[int], None] = None):
"""
EnvRecorder is a wrapper for Gymnasium environments that records state summaries during interactions.
Todo
:param env:
:param filepath:
:param env: The environment to record.
:type env: gym.Env
:param filepath: The path to save the recording data file.
:type filepath: Union[str, PathLike]
:param episodes: A list of episode numbers to record. If None, records all episodes.
:type episodes: Union[List[int], None]
"""
super(EnvRecorder, self).__init__(env)
self.filepath = filepath
@ -26,6 +29,9 @@ class EnvRecorder(Wrapper):
self._recorder_out_list = list()
def reset(self):
"""
Overrides the reset method to reset the environment and recording lists.
"""
self._curr_ep_recorder = list()
self._recorder_out_list = list()
self._curr_episode += 1
@ -33,10 +39,12 @@ class EnvRecorder(Wrapper):
def step(self, actions):
"""
Todo
Overrides the step method to record state summaries during each step.
:param actions:
:return:
:param actions: The actions taken in the environment.
:type actions: Any
:return: The observation, reward, done flag, and additional information.
:rtype: Tuple
"""
obs_type, obs, reward, done, info = self.env.step(actions)
if not self.episodes or self._curr_episode in self.episodes:
@ -62,6 +70,18 @@ class EnvRecorder(Wrapper):
save_occupation_map=False,
save_trajectory_map=False,
):
"""
Saves the recorded data to a file.
:param filepath: The path to save the recording data file.
:type filepath: Union[Path, str, None]
:param only_deltas: If True, saves only the differences between consecutive episodes.
:type only_deltas: bool
:param save_occupation_map: If True, saves an occupation map as a heatmap.
:type save_occupation_map: bool
:param save_trajectory_map: If True, saves a trajectory map.
:type save_trajectory_map: bool
"""
self._finalize()
filepath = Path(filepath or self.filepath)
filepath.parent.mkdir(exist_ok=True, parents=True)

View File

@ -19,10 +19,10 @@ class OBSBuilder(object):
@property
def pomdp_d(self):
"""
TODO
Calculates the effective diameter of the POMDP observation space.
:return:
:return: The calculated effective diameter.
:rtype: int
"""
if self.pomdp_r:
return (self.pomdp_r * 2) + 1
@ -34,10 +34,14 @@ class OBSBuilder(object):
OBSBuilder
==========
TODO
The OBSBuilder class is responsible for constructing observations in the environment.
:return:
:param level_shape: The shape of the level or environment.
:type level_shape: Tuple[int, int]
:param state: The current game state.
:type state: marl_factory_grid.environment.state.Gamestate
:param pomdp_r: The POMDP radius, influencing the size of the observation space.
:type pomdp_r: int
"""
self.all_obs = dict()
self.ray_caster = dict()
@ -55,6 +59,9 @@ class OBSBuilder(object):
self.reset(state)
def reset(self, state):
"""
Resets temporary information and constructs an empty observation array with possible placeholders.
"""
# Reset temporary information
self.curr_lightmaps = dict()
# Construct an empty obs (array) for possible placeholders
@ -64,6 +71,11 @@ class OBSBuilder(object):
return True
def observation_space(self, state):
"""
Returns the observation space for a single agent or a tuple of spaces for multiple agents.
:returns: The observation space for the agent(s).
:rtype: gym.Space|Tuple
"""
from gymnasium.spaces import Tuple, Box
self.reset(state)
obsn = self.build_for_all(state)
@ -74,13 +86,29 @@ class OBSBuilder(object):
return space
def named_observation_space(self, state):
"""
:returns: A dictionary of named observation spaces for all agents.
:rtype: dict
"""
self.reset(state)
return self.build_for_all(state)
def build_for_all(self, state) -> (dict, dict):
"""
Builds observations for all agents in the environment.
:returns: A dictionary of observations for all agents.
:rtype: dict
"""
return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}
def build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
"""
Builds named observations for all agents in the environment.
:returns: A dictionary containing named observations for all agents.
:rtype: dict
"""
named_obs_dict = {}
for agent in state[c.AGENT]:
obs, names = self.build_for_agent(agent, state)
@ -88,6 +116,16 @@ class OBSBuilder(object):
return named_obs_dict
def place_entity_in_observation(self, obs_array, agent, e):
"""
Places the encoding of an entity in the observation array relative to the agent's position.
:param obs_array: The observation array.
:type obs_array: np.ndarray
:param agent: the associated agent
:type agent: Agent
:param e: The entity to be placed in the observation.
:type e: Entity
"""
x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
if not min([y, x]) < 0:
try:
@ -98,6 +136,12 @@ class OBSBuilder(object):
pass
def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
"""
Builds observations for a specific agent.
:returns: A tuple containing the observation array and the corresponding list of observation names
:rtype: Tuple[np.ndarray, List[str]]
"""
try:
agent_want_obs = self.obs_layers[agent.name]
except KeyError:
@ -193,8 +237,8 @@ class OBSBuilder(object):
def _sort_and_name_observation_conf(self, agent):
"""
Builds the useable observation scheme per agent from conf.yaml.
:param agent:
:return:
:param agent: The agent for whom the observation scheme is built.
"""
# Fixme: no asymmetric shapes possible.
self.ray_caster[agent.name] = RayCaster(agent, min(self.obs_shape))

View File

@ -8,10 +8,17 @@ from numba import njit
class RayCaster:
def __init__(self, agent, pomdp_r, degs=360):
"""
TODO
The RayCaster class enables agents in the environment to simulate field-of-view visibility,
providing methods for calculating visible entities and outlining the field of view based on
Bresenham's algorithm.
:return:
:param agent: The agent for which the RayCaster is initialized.
:type agent: Agent
:param pomdp_r: The view range of the agent under partial observability (POMDP); ray targets are placed at this distance.
:type pomdp_r: int
:param degs: The degrees of the field of view (FOV). Defaults to 360.
:type degs: int
:return: None
"""
self.agent = agent
self.pomdp_r = pomdp_r
@ -25,6 +32,12 @@ class RayCaster:
return f'{self.__class__.__name__}({self.agent.name})'
def build_ray_targets(self):
"""
Builds the targets for the rays based on the field of view (FOV).
:return: The targets for the rays.
:rtype: np.ndarray
"""
north = np.array([0, -1]) * self.pomdp_r
thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
rot_M = [
@ -36,11 +49,31 @@ class RayCaster:
return rot_M.astype(int)
def ray_block_cache(self, key, callback):
"""
Retrieves or caches a value in the cache dictionary.
:param key: The key for the cache dictionary.
:type key: any
:param callback: The callback function to obtain the value if not present in the cache.
:type callback: callable
:return: The cached or newly computed value.
:rtype: any
"""
if key not in self._cache_dict:
self._cache_dict[key] = callback()
return self._cache_dict[key]
def visible_entities(self, pos_dict, reset_cache=True):
"""
Returns a list of visible entities based on the agent's field of view.
:param pos_dict: The dictionary containing positions of entities.
:type pos_dict: dict
:param reset_cache: Flag to reset the cache. Defaults to True.
:type reset_cache: bool
:return: A list of visible entities.
:rtype: list
"""
visible = list()
if reset_cache:
self._cache_dict = dict()
@ -71,15 +104,33 @@ class RayCaster:
return visible
def get_rays(self):
"""
Gets the rays for the agent.
:return: The rays for the agent.
:rtype: list
"""
a_pos = self.agent.pos
outline = self.ray_targets + a_pos
return self.bresenham_loop(a_pos, outline)
# todo do this once and cache the points!
def get_fov_outline(self) -> np.ndarray:
"""
Gets the field of view (FOV) outline.
:return: The FOV outline.
:rtype: np.ndarray
"""
return self.ray_targets + self.agent.pos
def get_square_outline(self):
"""
Gets the square outline for the agent.
:return: The square outline.
:rtype: list
"""
agent = self.agent
x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
@ -90,6 +141,16 @@ class RayCaster:
@staticmethod
@njit
def bresenham_loop(a_pos, points):
"""
Applies Bresenham's algorithm to calculate the points between two positions.
:param a_pos: The starting position.
:type a_pos: list
:param points: The ending positions.
:type points: list
:return: The list of points between the starting and ending positions.
:rtype: list
"""
results = []
for end in points:
x1, y1 = a_pos

View File

@ -34,12 +34,26 @@ class Renderer:
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2):
"""
TODO
The Renderer class initializes and manages the rendering environment for the simulation,
providing methods for preparing entities for display, loading assets, calculating visibility rectangles and
rendering the entities on the screen with specified parameters.
:return:
:param lvl_shape: Tuple representing the shape of the level.
:type lvl_shape: Tuple[int, int]
:param lvl_padded_shape: Optional Tuple representing the padded shape of the level.
:type lvl_padded_shape: Union[Tuple[int, int], None]
:param cell_size: Size of each cell in pixels.
:type cell_size: int
:param fps: Frames per second for rendering.
:type fps: int
:param factor: Factor for resizing assets.
:type factor: float
:param grid_lines: Boolean indicating whether to display grid lines.
:type grid_lines: bool
:param view_radius: Radius for agent's field of view.
:type view_radius: int
"""
# TODO: Customn_assets paths
# TODO: Custom_assets paths
self.grid_h, self.grid_w = lvl_shape
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
self.cell_size = cell_size
@ -60,6 +74,9 @@ class Renderer:
print('Loading System font with pygame.font.Font took', time.time() - now)
def fill_bg(self):
"""
Fills the background of the screen with the specified BG color.
"""
self.screen.fill(Renderer.BG_COLOR)
if self.grid_lines:
w, h = self.screen_size
@ -69,6 +86,16 @@ class Renderer:
pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)
def blit_params(self, entity):
"""
Prepares parameters for blitting an entity on the screen. Blitting refers to the process of combining or copying
rectangular blocks of pixels from one part of a graphical buffer to another and is often used to efficiently
update the display by copying pre-drawn or cached images onto the screen.
:param entity: The entity to be blitted.
:type entity: Entity
:return: Dictionary containing source and destination information for blitting.
:rtype: dict
"""
offset_r, offset_c = (self.lvl_padded_shape[0] - self.grid_h) // 2, \
(self.lvl_padded_shape[1] - self.grid_w) // 2
@ -90,12 +117,31 @@ class Renderer:
return dict(source=img, dest=rect)
def load_asset(self, path, factor=1.0):
"""
Loads and resizes an asset from the specified path.
:param path: Path to the asset.
:type path: str
:param factor: Resizing factor for the asset.
:type factor: float
:return: Resized asset.
"""
s = int(factor*self.cell_size)
asset = pygame.image.load(path).convert_alpha()
asset = pygame.transform.smoothscale(asset, (s, s))
return asset
def visibility_rects(self, bp, view):
"""
Calculates the visibility rectangles for an agent.
:param bp: Blit parameters for the agent.
:type bp: dict
:param view: Agent's field of view.
:type view: np.ndarray
:return: List of visibility rectangles.
:rtype: List[dict]
"""
rects = []
for i, j in product(range(-self.view_radius, self.view_radius+1),
range(-self.view_radius, self.view_radius+1)):
@ -111,6 +157,14 @@ class Renderer:
return rects
def render(self, entities):
"""
Renders the entities on the screen.
:param entities: List of entities to be rendered.
:type entities: List[Entity]
:return: Transposed RGB observation array.
:rtype: np.ndarray
"""
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()