Documentation

This commit is contained in:
Joel Friedrich
2023-11-22 12:12:04 +01:00
committed by Steffen Illium
parent 604c0c6f57
commit 855f53b406
35 changed files with 655 additions and 198 deletions

View File

@@ -1,8 +1,7 @@
import importlib
from collections import defaultdict
from pathlib import PurePath, Path
from typing import Union, Dict, List, Iterable, Callable
from typing import Union, Dict, List, Iterable, Callable, Any
import numpy as np
from numpy.typing import ArrayLike
@@ -21,23 +20,22 @@ This file is used for:
In this file they are defined to be used across the entire package.
"""
LEVELS_DIR = 'levels' # for use in studies and experiments
STEPS_START = 1 # Define where to the stepcount; which is the first step
LEVELS_DIR = 'levels' # for use in studies and experiments
STEPS_START = 1 # Define where to the stepcount; which is the first step
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
'episode']
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
[[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]])
[[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]])
MOVEMAP = defaultdict(lambda: (0, 0),
MOVEMAP = defaultdict(lambda: (0, 0),
{c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
}
)
@@ -80,7 +78,19 @@ class ObservationTranslator:
self._this_named_obs_space = this_named_observation_space
self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
def translate_observation(self, agent_idx: int, obs):
def translate_observation(self, agent_idx, obs) -> ArrayLike:
"""
Translates the observation of the given agent.
:param agent_idx: Agent identifier.
:type agent_idx: int
:param obs: The observation to be translated.
:type obs: ArrayLike
:return: The translated observation.
:rtype: ArrayLike
"""
target_obs_space = self._per_agent_named_obs_space[agent_idx]
translation = dict()
for name, idxs in target_obs_space.items():
@@ -98,7 +108,10 @@ class ObservationTranslator:
translation = dict(sorted(translation.items()))
return np.concatenate(list(translation.values()), axis=-3)
def translate_observations(self, observations: List[ArrayLike]):
def translate_observations(self, observations) -> List[ArrayLike]:
"""
Internal Usage
"""
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
def __call__(self, observations):
@@ -129,11 +142,26 @@ class ActionTranslator:
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
def translate_action(self, agent_idx: int, action: int):
"""
Translates the observation of the given agent.
:param agent_idx: Agent identifier.
:type agent_idx: int
:param action: The action to be translated.
:type action: int
:return: The translated action.
:rtype: ArrayLike
"""
named_action = self._per_agent_idx_actions[agent_idx][action]
translated_action = self._target_named_action_space[named_action]
return translated_action
def translate_actions(self, actions: List[int]):
"""
Intenal Usage
"""
return [self.translate_action(idx, action) for idx, action in enumerate(actions)]
def __call__(self, actions):
@@ -179,6 +207,13 @@ def one_hot_level(level, symbol: str):
def is_move(action_name: str):
"""
Check if the given action name corresponds to a movement action.
:param action_name: The name of the action to check.
:type action_name: str
:return: True if the action is a movement action, False otherwise.
"""
return action_name in MOVEMAP.keys()
@@ -208,7 +243,18 @@ def asset_str(agent):
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
"""Locate an object by name or dotted path, importing as necessary."""
"""
Locate an object by name or dotted path.
:param class_name: The class name to be imported
:type class_name: str
:param folder_path: The path to the module containing the class.
:type folder_path: Union[str, PurePath]
:return: The imported module class.
:raises AttributeError: If the specified class is not found in the provided folder path.
"""
import sys
sys.path.append("../../environment")
folder_path = Path(folder_path).resolve()
@@ -220,15 +266,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
for module_path in module_paths:
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
all_found_modules.extend([x for x in dir(mod) if (not (x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any', 'Factory',
'Move8']])
try:
model_class = mod.__getattribute__(class_name)
return model_class
module_class = mod.__getattribute__(class_name)
return module_class
except AttributeError:
continue
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
@@ -244,16 +290,33 @@ def add_pos_name(name_str, bound_e):
return name_str
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> Any | None:
"""
Get the first element from an iterable that satisfies the specified condition.
:param iterable: The iterable to search.
:type iterable: Iterable
:param filter_by: A function that filters elements, defaults to lambda _: True.
:type filter_by: Callable[[Any], bool]
:return: The first element that satisfies the condition, or None if none is found.
:rtype: Any
"""
return next((x for x in iterable if filter_by(x)), None)
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> int | None:
"""
todo
Get the index of the first element from an iterable that satisfies the specified condition.
:param iterable:
:param filter_by:
:return:
:param iterable: The iterable to search.
:type iterable: Iterable
:param filter_by: A function that filters elements, defaults to lambda _: True.
:type filter_by: Callable[[Any], bool]
:return: The index of the first element that satisfies the condition, or None if none is found.
:rtype: Optional[int]
"""
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)

View File

@@ -15,9 +15,24 @@ class LevelParser(object):
@property
def pomdp_d(self):
"""
Internal Usage
"""
return self.pomdp_r * 2 + 1
def __init__(self, level_file_path: PathLike, entity_parse_dict: Dict[Entities, dict], pomdp_r=0):
"""
Parses a level file and creates the initial state of the environment.
:param level_file_path: Path to the level file.
:type level_file_path: PathLike
:param entity_parse_dict: Dictionary specifying how to parse different entities.
:type entity_parse_dict: Dict[Entities, dict]
:param pomdp_r: The POMDP radius. Defaults to 0.
:type pomdp_r: int
"""
self.pomdp_r = pomdp_r
self.e_p_dict = entity_parse_dict
self._parsed_level = h.parse_level(Path(level_file_path))
@@ -25,14 +40,30 @@ class LevelParser(object):
self.level_shape = level_array.shape
self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
def get_coordinates_for_symbol(self, symbol, negate=False):
def get_coordinates_for_symbol(self, symbol, negate=False) -> np.ndarray:
"""
Get the coordinates for a given symbol in the parsed level.
:param symbol: The symbol to search for.
:param negate: If True, get coordinates not matching the symbol. Defaults to False.
:return: Array of coordinates.
:rtype: np.ndarray
"""
level_array = h.one_hot_level(self._parsed_level, symbol)
if negate:
return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL)
else:
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
def do_init(self):
def do_init(self) -> Entities:
"""
Initialize the environment map state by creating entities such as Walls, Agents or Machines according to the
entity parse dict.
:return: A dict of all parsed entities with their positions.
:rtype: Entities
"""
# Global Entities
list_of_all_positions = ([tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
entities = Entities(list_of_all_positions)

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object
TYPE_VALUE = 'value'
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
TYPES = [TYPE_VALUE, TYPE_REWARD]
@@ -11,10 +11,7 @@ TYPES = [TYPE_VALUE, TYPE_REWARD]
@dataclass
class InfoObject:
"""
TODO
:return:
Data class representing information about an entity or the global environment.
"""
identifier: str
val_type: str
@@ -24,10 +21,14 @@ class InfoObject:
@dataclass
class Result:
"""
TODO
A generic result class representing outcomes of operations or actions.
:return:
Attributes:
- identifier: A unique identifier for the result.
- validity: A boolean indicating whether the operation or action was successful.
- reward: The reward associated with the result, if applicable.
- value: The value associated with the result, if applicable.
- entity: The entity associated with the result, if applicable.
"""
identifier: str
validity: bool
@@ -36,6 +37,11 @@ class Result:
entity: Object = None
def get_infos(self):
"""
Get information about the result.
:return: A list of InfoObject representing different types of information.
"""
n = self.entity.name if self.entity is not None else "Global"
# Return multiple Info Dicts
return [InfoObject(identifier=f'{n}_{self.identifier}',
@@ -50,32 +56,30 @@ class Result:
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass
class TickResult(Result):
"""
TODO
"""
pass
@dataclass
class ActionResult(Result):
"""
TODO
A specific Result class representing outcomes of actions.
"""
pass
@dataclass
class ActionResult(Result):
pass
@dataclass
class State(Result):
# TODO: change identifiert to action/last_action
pass
@dataclass
class DoneResult(Result):
"""
A specific Result class representing the completion of an action or operation.
"""
pass
@dataclass
class State(Result):
# TODO: change identifier to action/last_action
pass
@dataclass
class TickResult(Result):
"""
A specific Result class representing outcomes of tick operations.
"""
pass