mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-04 08:31:35 +02:00
Documentation
This commit is contained in:

committed by
Steffen Illium

parent
604c0c6f57
commit
855f53b406
@ -1,8 +1,7 @@
|
||||
import importlib
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import PurePath, Path
|
||||
from typing import Union, Dict, List, Iterable, Callable
|
||||
from typing import Union, Dict, List, Iterable, Callable, Any
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import ArrayLike
|
||||
@ -21,23 +20,22 @@ This file is used for:
|
||||
In this file they are defined to be used across the entire package.
|
||||
"""
|
||||
|
||||
LEVELS_DIR = 'levels' # for use in studies and experiments
|
||||
STEPS_START = 1 # Define where to the stepcount; which is the first step
|
||||
|
||||
LEVELS_DIR = 'levels' # for use in studies and experiments
|
||||
STEPS_START = 1 # Define where to the stepcount; which is the first step
|
||||
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
||||
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
|
||||
'episode']
|
||||
|
||||
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
|
||||
[[-1, 0], [0, 0], [1, 0]],
|
||||
[[-1, 1], [0, 1], [1, 1]]])
|
||||
[[-1, 0], [0, 0], [1, 0]],
|
||||
[[-1, 1], [0, 1], [1, 1]]])
|
||||
|
||||
MOVEMAP = defaultdict(lambda: (0, 0),
|
||||
MOVEMAP = defaultdict(lambda: (0, 0),
|
||||
{c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
|
||||
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
|
||||
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
|
||||
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
|
||||
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
|
||||
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
|
||||
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
|
||||
}
|
||||
)
|
||||
|
||||
@ -80,7 +78,19 @@ class ObservationTranslator:
|
||||
self._this_named_obs_space = this_named_observation_space
|
||||
self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
|
||||
|
||||
def translate_observation(self, agent_idx: int, obs):
|
||||
def translate_observation(self, agent_idx, obs) -> ArrayLike:
|
||||
"""
|
||||
Translates the observation of the given agent.
|
||||
|
||||
:param agent_idx: Agent identifier.
|
||||
:type agent_idx: int
|
||||
|
||||
:param obs: The observation to be translated.
|
||||
:type obs: ArrayLike
|
||||
|
||||
:return: The translated observation.
|
||||
:rtype: ArrayLike
|
||||
"""
|
||||
target_obs_space = self._per_agent_named_obs_space[agent_idx]
|
||||
translation = dict()
|
||||
for name, idxs in target_obs_space.items():
|
||||
@ -98,7 +108,10 @@ class ObservationTranslator:
|
||||
translation = dict(sorted(translation.items()))
|
||||
return np.concatenate(list(translation.values()), axis=-3)
|
||||
|
||||
def translate_observations(self, observations: List[ArrayLike]):
|
||||
def translate_observations(self, observations) -> List[ArrayLike]:
|
||||
"""
|
||||
Internal Usage
|
||||
"""
|
||||
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
|
||||
|
||||
def __call__(self, observations):
|
||||
@ -129,11 +142,26 @@ class ActionTranslator:
|
||||
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
|
||||
|
||||
def translate_action(self, agent_idx: int, action: int):
|
||||
"""
|
||||
Translates the observation of the given agent.
|
||||
|
||||
:param agent_idx: Agent identifier.
|
||||
:type agent_idx: int
|
||||
|
||||
:param action: The action to be translated.
|
||||
:type action: int
|
||||
|
||||
:return: The translated action.
|
||||
:rtype: ArrayLike
|
||||
"""
|
||||
named_action = self._per_agent_idx_actions[agent_idx][action]
|
||||
translated_action = self._target_named_action_space[named_action]
|
||||
return translated_action
|
||||
|
||||
def translate_actions(self, actions: List[int]):
|
||||
"""
|
||||
Intenal Usage
|
||||
"""
|
||||
return [self.translate_action(idx, action) for idx, action in enumerate(actions)]
|
||||
|
||||
def __call__(self, actions):
|
||||
@ -179,6 +207,13 @@ def one_hot_level(level, symbol: str):
|
||||
|
||||
|
||||
def is_move(action_name: str):
|
||||
"""
|
||||
Check if the given action name corresponds to a movement action.
|
||||
|
||||
:param action_name: The name of the action to check.
|
||||
:type action_name: str
|
||||
:return: True if the action is a movement action, False otherwise.
|
||||
"""
|
||||
return action_name in MOVEMAP.keys()
|
||||
|
||||
|
||||
@ -208,7 +243,18 @@ def asset_str(agent):
|
||||
|
||||
|
||||
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
"""Locate an object by name or dotted path, importing as necessary."""
|
||||
"""
|
||||
Locate an object by name or dotted path.
|
||||
|
||||
:param class_name: The class name to be imported
|
||||
:type class_name: str
|
||||
|
||||
:param folder_path: The path to the module containing the class.
|
||||
:type folder_path: Union[str, PurePath]
|
||||
|
||||
:return: The imported module class.
|
||||
:raises AttributeError: If the specified class is not found in the provided folder path.
|
||||
"""
|
||||
import sys
|
||||
sys.path.append("../../environment")
|
||||
folder_path = Path(folder_path).resolve()
|
||||
@ -220,15 +266,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
for module_path in module_paths:
|
||||
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
||||
mod = importlib.import_module('.'.join(module_parts))
|
||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
||||
all_found_modules.extend([x for x in dir(mod) if (not (x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent',
|
||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any', 'Factory',
|
||||
'Move8']])
|
||||
try:
|
||||
model_class = mod.__getattribute__(class_name)
|
||||
return model_class
|
||||
module_class = mod.__getattribute__(class_name)
|
||||
return module_class
|
||||
except AttributeError:
|
||||
continue
|
||||
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
|
||||
@ -244,16 +290,33 @@ def add_pos_name(name_str, bound_e):
|
||||
return name_str
|
||||
|
||||
|
||||
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> Any | None:
|
||||
"""
|
||||
Get the first element from an iterable that satisfies the specified condition.
|
||||
|
||||
:param iterable: The iterable to search.
|
||||
:type iterable: Iterable
|
||||
|
||||
:param filter_by: A function that filters elements, defaults to lambda _: True.
|
||||
:type filter_by: Callable[[Any], bool]
|
||||
|
||||
:return: The first element that satisfies the condition, or None if none is found.
|
||||
:rtype: Any
|
||||
"""
|
||||
return next((x for x in iterable if filter_by(x)), None)
|
||||
|
||||
|
||||
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> int | None:
|
||||
"""
|
||||
todo
|
||||
Get the index of the first element from an iterable that satisfies the specified condition.
|
||||
|
||||
:param iterable:
|
||||
:param filter_by:
|
||||
:return:
|
||||
:param iterable: The iterable to search.
|
||||
:type iterable: Iterable
|
||||
|
||||
:param filter_by: A function that filters elements, defaults to lambda _: True.
|
||||
:type filter_by: Callable[[Any], bool]
|
||||
|
||||
:return: The index of the first element that satisfies the condition, or None if none is found.
|
||||
:rtype: Optional[int]
|
||||
"""
|
||||
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)
|
||||
|
Reference in New Issue
Block a user