Documentation

This commit is contained in:
Joel Friedrich
2023-11-22 12:12:04 +01:00
committed by Steffen Illium
parent 604c0c6f57
commit 855f53b406
35 changed files with 655 additions and 198 deletions

View File

@ -1,8 +1,7 @@
import importlib
from collections import defaultdict
from pathlib import PurePath, Path
from typing import Union, Dict, List, Iterable, Callable
from typing import Union, Dict, List, Iterable, Callable, Any
import numpy as np
from numpy.typing import ArrayLike
@ -21,23 +20,22 @@ This file is used for:
In this file they are defined to be used across the entire package.
"""
LEVELS_DIR = 'levels' # for use in studies and experiments
STEPS_START = 1 # Define where to the stepcount; which is the first step
LEVELS_DIR = 'levels' # for use in studies and experiments
STEPS_START = 1 # Define where to the stepcount; which is the first step
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
'episode']
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
[[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]])
[[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]])
MOVEMAP = defaultdict(lambda: (0, 0),
MOVEMAP = defaultdict(lambda: (0, 0),
{c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
}
)
@ -80,7 +78,19 @@ class ObservationTranslator:
self._this_named_obs_space = this_named_observation_space
self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
def translate_observation(self, agent_idx: int, obs):
def translate_observation(self, agent_idx, obs) -> ArrayLike:
"""
Translates the observation of the given agent.
:param agent_idx: Agent identifier.
:type agent_idx: int
:param obs: The observation to be translated.
:type obs: ArrayLike
:return: The translated observation.
:rtype: ArrayLike
"""
target_obs_space = self._per_agent_named_obs_space[agent_idx]
translation = dict()
for name, idxs in target_obs_space.items():
@ -98,7 +108,10 @@ class ObservationTranslator:
translation = dict(sorted(translation.items()))
return np.concatenate(list(translation.values()), axis=-3)
def translate_observations(self, observations: List[ArrayLike]):
def translate_observations(self, observations) -> List[ArrayLike]:
"""
Internal Usage
"""
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
def __call__(self, observations):
@ -129,11 +142,26 @@ class ActionTranslator:
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
def translate_action(self, agent_idx: int, action: int):
"""
Translates the observation of the given agent.
:param agent_idx: Agent identifier.
:type agent_idx: int
:param action: The action to be translated.
:type action: int
:return: The translated action.
:rtype: ArrayLike
"""
named_action = self._per_agent_idx_actions[agent_idx][action]
translated_action = self._target_named_action_space[named_action]
return translated_action
def translate_actions(self, actions: List[int]):
"""
Intenal Usage
"""
return [self.translate_action(idx, action) for idx, action in enumerate(actions)]
def __call__(self, actions):
@ -179,6 +207,13 @@ def one_hot_level(level, symbol: str):
def is_move(action_name: str):
"""
Check if the given action name corresponds to a movement action.
:param action_name: The name of the action to check.
:type action_name: str
:return: True if the action is a movement action, False otherwise.
"""
return action_name in MOVEMAP.keys()
@ -208,7 +243,18 @@ def asset_str(agent):
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
"""Locate an object by name or dotted path, importing as necessary."""
"""
Locate an object by name or dotted path.
:param class_name: The class name to be imported
:type class_name: str
:param folder_path: The path to the module containing the class.
:type folder_path: Union[str, PurePath]
:return: The imported module class.
:raises AttributeError: If the specified class is not found in the provided folder path.
"""
import sys
sys.path.append("../../environment")
folder_path = Path(folder_path).resolve()
@ -220,15 +266,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
for module_path in module_paths:
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
all_found_modules.extend([x for x in dir(mod) if (not (x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any', 'Factory',
'Move8']])
try:
model_class = mod.__getattribute__(class_name)
return model_class
module_class = mod.__getattribute__(class_name)
return module_class
except AttributeError:
continue
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
@ -244,16 +290,33 @@ def add_pos_name(name_str, bound_e):
return name_str
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> Any | None:
"""
Get the first element from an iterable that satisfies the specified condition.
:param iterable: The iterable to search.
:type iterable: Iterable
:param filter_by: A function that filters elements, defaults to lambda _: True.
:type filter_by: Callable[[Any], bool]
:return: The first element that satisfies the condition, or None if none is found.
:rtype: Any
"""
return next((x for x in iterable if filter_by(x)), None)
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> int | None:
"""
todo
Get the index of the first element from an iterable that satisfies the specified condition.
:param iterable:
:param filter_by:
:return:
:param iterable: The iterable to search.
:type iterable: Iterable
:param filter_by: A function that filters elements, defaults to lambda _: True.
:type filter_by: Callable[[Any], bool]
:return: The index of the first element that satisfies the condition, or None if none is found.
:rtype: Optional[int]
"""
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)