Merge branch 'main' into unit_testing

# Conflicts:
#	marl_factory_grid/modules/doors/groups.py
#	marl_factory_grid/utils/states.py
This commit is contained in:
Chanumask
2023-11-23 12:58:12 +01:00
63 changed files with 1477 additions and 330 deletions

View File

@ -1,8 +1,7 @@
import importlib
from collections import defaultdict
from pathlib import PurePath, Path
from typing import Union, Dict, List, Iterable, Callable
from typing import Union, Dict, List, Iterable, Callable, Any
import numpy as np
from numpy.typing import ArrayLike
@ -21,23 +20,22 @@ This file is used for:
In this file they are defined to be used across the entire package.
"""
LEVELS_DIR = 'levels' # for use in studies and experiments
STEPS_START = 1 # Define where to the stepcount; which is the first step
LEVELS_DIR = 'levels' # for use in studies and experiments
STEPS_START = 1 # Define where to the stepcount; which is the first step
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
'episode']
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
[[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]])
[[-1, 0], [0, 0], [1, 0]],
[[-1, 1], [0, 1], [1, 1]]])
MOVEMAP = defaultdict(lambda: (0, 0),
MOVEMAP = defaultdict(lambda: (0, 0),
{c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
}
)
@ -80,7 +78,19 @@ class ObservationTranslator:
self._this_named_obs_space = this_named_observation_space
self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
def translate_observation(self, agent_idx: int, obs):
def translate_observation(self, agent_idx, obs) -> ArrayLike:
"""
Translates the observation of the given agent.
:param agent_idx: Agent identifier.
:type agent_idx: int
:param obs: The observation to be translated.
:type obs: ArrayLike
:return: The translated observation.
:rtype: ArrayLike
"""
target_obs_space = self._per_agent_named_obs_space[agent_idx]
translation = dict()
for name, idxs in target_obs_space.items():
@ -98,7 +108,10 @@ class ObservationTranslator:
translation = dict(sorted(translation.items()))
return np.concatenate(list(translation.values()), axis=-3)
def translate_observations(self, observations: List[ArrayLike]):
def translate_observations(self, observations) -> List[ArrayLike]:
"""
Internal Usage
"""
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
def __call__(self, observations):
@ -129,11 +142,26 @@ class ActionTranslator:
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
def translate_action(self, agent_idx: int, action: int):
"""
Translates the observation of the given agent.
:param agent_idx: Agent identifier.
:type agent_idx: int
:param action: The action to be translated.
:type action: int
:return: The translated action.
:rtype: ArrayLike
"""
named_action = self._per_agent_idx_actions[agent_idx][action]
translated_action = self._target_named_action_space[named_action]
return translated_action
def translate_actions(self, actions: List[int]):
"""
Intenal Usage
"""
return [self.translate_action(idx, action) for idx, action in enumerate(actions)]
def __call__(self, actions):
@ -179,6 +207,13 @@ def one_hot_level(level, symbol: str):
def is_move(action_name: str):
"""
Check if the given action name corresponds to a movement action.
:param action_name: The name of the action to check.
:type action_name: str
:return: True if the action is a movement action, False otherwise.
"""
return action_name in MOVEMAP.keys()
@ -208,7 +243,18 @@ def asset_str(agent):
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
"""Locate an object by name or dotted path, importing as necessary."""
"""
Locate an object by name or dotted path.
:param class_name: The class name to be imported
:type class_name: str
:param folder_path: The path to the module containing the class.
:type folder_path: Union[str, PurePath]
:return: The imported module class.
:raises AttributeError: If the specified class is not found in the provided folder path.
"""
import sys
sys.path.append("../../environment")
folder_path = Path(folder_path).resolve()
@ -220,15 +266,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
for module_path in module_paths:
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
mod = importlib.import_module('.'.join(module_parts))
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
all_found_modules.extend([x for x in dir(mod) if (not (x.startswith('__') or len(x) <= 2) and x.istitle())
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
'TickResult', 'ActionResult', 'Action', 'Agent',
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any', 'Factory',
'Move8']])
try:
model_class = mod.__getattribute__(class_name)
return model_class
module_class = mod.__getattribute__(class_name)
return module_class
except AttributeError:
continue
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
@ -244,9 +290,33 @@ def add_pos_name(name_str, bound_e):
return name_str
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> Any | None:
"""
Get the first element from an iterable that satisfies the specified condition.
:param iterable: The iterable to search.
:type iterable: Iterable
:param filter_by: A function that filters elements, defaults to lambda _: True.
:type filter_by: Callable[[Any], bool]
:return: The first element that satisfies the condition, or None if none is found.
:rtype: Any
"""
return next((x for x in iterable if filter_by(x)), None)
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> int | None:
"""
Get the index of the first element from an iterable that satisfies the specified condition.
:param iterable: The iterable to search.
:type iterable: Iterable
:param filter_by: A function that filters elements, defaults to lambda _: True.
:type filter_by: Callable[[Any], bool]
:return: The index of the first element that satisfies the condition, or None if none is found.
:rtype: Optional[int]
"""
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)

View File

@ -15,9 +15,24 @@ class LevelParser(object):
@property
def pomdp_d(self):
"""
Internal Usage
"""
return self.pomdp_r * 2 + 1
def __init__(self, level_file_path: PathLike, entity_parse_dict: Dict[Entities, dict], pomdp_r=0):
"""
Parses a level file and creates the initial state of the environment.
:param level_file_path: Path to the level file.
:type level_file_path: PathLike
:param entity_parse_dict: Dictionary specifying how to parse different entities.
:type entity_parse_dict: Dict[Entities, dict]
:param pomdp_r: The POMDP radius. Defaults to 0.
:type pomdp_r: int
"""
self.pomdp_r = pomdp_r
self.e_p_dict = entity_parse_dict
self._parsed_level = h.parse_level(Path(level_file_path))
@ -25,14 +40,30 @@ class LevelParser(object):
self.level_shape = level_array.shape
self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
def get_coordinates_for_symbol(self, symbol, negate=False):
def get_coordinates_for_symbol(self, symbol, negate=False) -> np.ndarray:
"""
Get the coordinates for a given symbol in the parsed level.
:param symbol: The symbol to search for.
:param negate: If True, get coordinates not matching the symbol. Defaults to False.
:return: Array of coordinates.
:rtype: np.ndarray
"""
level_array = h.one_hot_level(self._parsed_level, symbol)
if negate:
return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL)
else:
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
def do_init(self):
def do_init(self) -> Entities:
"""
Initialize the environment map state by creating entities such as Walls, Agents or Machines according to the
entity parse dict.
:return: A dict of all parsed entities with their positions.
:rtype: Entities
"""
# Global Entities
list_of_all_positions = ([tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
entities = Entities(list_of_all_positions)

View File

@ -18,12 +18,24 @@ class OBSBuilder(object):
@property
def pomdp_d(self):
"""
TODO
:return:
"""
if self.pomdp_r:
return (self.pomdp_r * 2) + 1
else:
return 0
def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
"""
TODO
:return:
"""
self.all_obs = dict()
self.ray_caster = dict()

View File

@ -7,6 +7,12 @@ from numba import njit
class RayCaster:
def __init__(self, agent, pomdp_r, degs=360):
"""
TODO
:return:
"""
self.agent = agent
self.pomdp_r = pomdp_r
self.n_rays = 100 # (self.pomdp_r + 1) * 8

View File

@ -33,6 +33,12 @@ class Renderer:
lvl_padded_shape: Union[Tuple[int, int], None] = None,
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
grid_lines: bool = True, view_radius: int = 2):
"""
TODO
:return:
"""
# TODO: Customn_assets paths
self.grid_h, self.grid_w = lvl_shape
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape

View File

@ -3,13 +3,16 @@ from dataclasses import dataclass
from marl_factory_grid.environment.entity.object import Object
TYPE_VALUE = 'value'
TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
TYPES = [TYPE_VALUE, TYPE_REWARD]
@dataclass
class InfoObject:
"""
Data class representing information about an entity or the global environment.
"""
identifier: str
val_type: str
value: Union[float, int]
@ -17,6 +20,16 @@ class InfoObject:
@dataclass
class Result:
"""
A generic result class representing outcomes of operations or actions.
Attributes:
- identifier: A unique identifier for the result.
- validity: A boolean indicating whether the operation or action was successful.
- reward: The reward associated with the result, if applicable.
- value: The value associated with the result, if applicable.
- entity: The entity associated with the result, if applicable.
"""
identifier: str
validity: bool
reward: Union[float, None] = None
@ -24,6 +37,11 @@ class Result:
entity: Object = None
def get_infos(self):
"""
Get information about the result.
:return: A list of InfoObject representing different types of information.
"""
n = self.entity.name if self.entity is not None else "Global"
# Return multiple Info Dicts
return [InfoObject(identifier=f'{n}_{self.identifier}',
@ -38,16 +56,37 @@ class Result:
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
@dataclass
class TickResult(Result):
pass
@dataclass
class ActionResult(Result):
def __init__(self, *args, action_introduced_collision: bool = False, **kwargs):
"""
A specific Result class representing outcomes of actions.
:param action_introduced_collision: Wether the action did introduce a colision between agents or other entities.
These need to be able to collide.
"""
super().__init__(*args, **kwargs)
self.action_introduced_collision = action_introduced_collision
pass
@dataclass
class DoneResult(Result):
"""
A specific Result class representing the completion of an action or operation.
"""
pass
@dataclass
class State(Result):
# TODO: change identifier to action/last_action
pass
@dataclass
class TickResult(Result):
"""
A specific Result class representing outcomes of tick operations.
"""
pass

View File

@ -1,5 +1,4 @@
from itertools import islice
from itertools import islice
from typing import List, Tuple
import numpy as np
@ -15,6 +14,12 @@ from marl_factory_grid.utils.results import Result
class StepRules:
def __init__(self, *args):
"""
TODO
:return:
"""
if args:
self.rules = list(args)
else:
@ -80,6 +85,12 @@ class Gamestate(object):
return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False):
"""
TODO
:return:
"""
self.lvl_shape = lvl_shape
self.entities = entities
self.curr_step = 0

View File

@ -2,38 +2,68 @@ import importlib
import inspect
from os import PathLike
from pathlib import Path
from typing import Union
import yaml
from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.helpers import locate_and_import_class
ACTION = 'Action'
GENERAL = 'General'
ENTITIES = 'Objects'
ACTION = 'Action'
GENERAL = 'General'
ENTITIES = 'Objects'
OBSERVATIONS = 'Observations'
RULES = 'Rule'
ASSETS = 'Assets'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
RULES = 'Rule'
TESTS = 'Tests'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
class ConfigExplainer:
def __init__(self, custom_path: Union[None, PathLike] = None):
self.base_path = Path(__file__).parent.parent.resolve()
self.custom_path = custom_path
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, ASSETS]
def __init__(self, custom_path: None | PathLike = None):
"""
This utility serves as a helper for debugging and exploring available modules and classes.
Does not do anything unless told.
The functions get_xxxxx() retrieves and returns the information and save_xxxxx() dumps them to disk.
def explain_module(self, class_to_explain):
get_all() and save_all() helps geting a general overview.
When provided with a custom path, your own modules become available.
:param custom_path: Path to your custom module folder.
"""
self.base_path = Path(__file__).parent.parent.resolve()
self.custom_path = Path(custom_path) if custom_path is not None else custom_path
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, TESTS]
@staticmethod
def _explain_module(class_to_explain):
"""
INTERNAL USE ONLY
"""
parameters = inspect.signature(class_to_explain).parameters
explained = {class_to_explain.__name__:
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
}
return explained
def _get_by_identifier(self, identifier):
"""
INTERNAL USE ONLY
"""
entities_base_cls = locate_and_import_class(identifier, self.base_path)
module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
found_entities = self._load_and_compare(entities_base_cls, module_paths)
if self.custom_path is not None:
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
and '__init__' not in x.name]
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
return found_entities
def _load_and_compare(self, compare_class, paths):
"""
INTERNAL USE ONLY
"""
conf = {}
package_pos = next(idx for idx, x in enumerate(Path(__file__).resolve().parts) if x == 'marl_factory_grid')
for module_path in paths:
@ -44,40 +74,97 @@ class ConfigExplainer:
mod = mods.__getattribute__(key)
try:
if issubclass(mod, compare_class) and mod != compare_class:
conf.update(self.explain_module(mod))
conf.update(self._explain_module(mod))
except TypeError:
pass
return conf
def save_actions(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_actions.yml'):
self._save_to_file(self.get_entities(), output_conf_file, ACTION)
@staticmethod
def _save_to_file(data: dict, filepath: PathLike, tag: str = ''):
"""
INTERNAL USE ONLY
"""
filepath = Path(filepath)
yaml.Dumper.ignore_aliases = lambda *args: True
with filepath.open('w') as f:
yaml.dump(data, f, encoding='utf-8')
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
print(f'See file: {filepath}')
def get_actions(self):
def get_actions(self) -> list[str]:
"""
Retrieve all actions from module folders.
:returns: A list of all available actions.
"""
actions = self._get_by_identifier(ACTION)
assert all(not x for x in actions.values()), 'Please only provide Names, no Mappings.'
actions = list(actions.keys())
actions.extend([c.MOVE8, c.MOVE4])
# TODO: Print to file!
return actions
def save_entities(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_entities.yml'):
self._save_to_file(self.get_entities(), output_conf_file, ENTITIES)
def get_all(self) -> dict[str]:
"""
Retrieve all available configurations from module folders.
:returns: A dictionary of all available configurations.
"""
config_dict = {
'General': self.get_general_section(),
'Agents': self.get_agent_section(),
'Entities': self.get_entities(),
'Rules': self.get_rules()
}
return config_dict
def get_entities(self):
"""
Retrieve all entities from module folders.
:returns: A list of all available entities.
"""
entities = self._get_by_identifier(ENTITIES)
return entities
def save_rules(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_rules.yml'):
self._save_to_file(self.get_entities(), output_conf_file, RULES)
@staticmethod
def get_general_section():
"""
Build the general section.
def get_rules(self):
:returns: A list of all available entities.
"""
general = {'level_name': 'rooms', 'env_seed': 69, 'verbose': False,
'pomdp_r': 3, 'individual_rewards': True, 'tests': False}
return general
def get_agent_section(self):
"""
Build the Agent section and retrieve all available actions and observations from module folders.
:returns: Agent section.
"""
agents = dict(
ExampleAgentName=dict(
Actions=self.get_actions(),
Observations=self.get_observations())),
return agents
def get_rules(self) -> dict[str]:
"""
Retrieve all rules from module folders.
:returns: All available rules.
"""
rules = self._get_by_identifier(RULES)
return rules
def get_assets(self):
pass
def get_observations(self) -> list[str]:
"""
Retrieve all agent observations from module folders.
def get_observations(self):
:returns: A list of all available observations.
"""
names = [c.ALL, c.COMBINED, c.SELF, c.OTHERS, "Agent['ExampleAgentName']"]
for key, val in self.get_entities().items():
try:
@ -95,45 +182,47 @@ class ConfigExplainer:
names.extend(e)
return names
def _get_by_identifier(self, identifier):
entities_base_cls = locate_and_import_class(identifier, self.base_path)
module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
found_entities = self._load_and_compare(entities_base_cls, module_paths)
if self.custom_path is not None:
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
and '__init__' not in x.name]
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
return found_entities
def save_actions(self, output_conf_file: PathLike = Path('../../quickstart') / 'actions.yml'):
"""
Write all availale actions to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/actions.yml
"""
self._save_to_file(self.get_entities(), output_conf_file, ACTION)
def save_all(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained.yml'):
def save_entities(self, output_conf_file: PathLike = Path('../../quickstart') / 'entities.yml'):
"""
Write all availale entities to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/entities.yml
"""
self._save_to_file(self.get_entities(), output_conf_file, ENTITIES)
def save_observations(self, output_conf_file: PathLike = Path('../../quickstart') / 'observations.yml'):
"""
Write all availale observations to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/observations.yml
"""
self._save_to_file(self.get_entities(), output_conf_file, OBSERVATIONS)
def save_rules(self, output_conf_file: PathLike = Path('../../quickstart') / 'rules.yml'):
"""
Write all availale rules to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/rules.yml
"""
self._save_to_file(self.get_entities(), output_conf_file, RULES)
def save_all(self, output_conf_file: PathLike = Path('../../quickstart') / 'all.yml'):
"""
Write all availale keywords to a file.
:param output_conf_file: File to write to. Defaults to ../../quickstart/all.yml
"""
self._save_to_file(self.get_all(), output_conf_file, 'ALL')
def get_all(self):
config_dict = {GENERAL: {'level_name': 'rooms', 'env_seed': 69, 'verbose': False,
'pomdp_r': 3, 'individual_rewards': True},
'Agents': dict(
ExampleAgentName=dict(
Actions=self.get_actions(),
Observations=self.get_observations())),
'Entities': self.get_entities(),
'Rules': self.get_rules(),
'Assets': self.get_assets()}
return config_dict
def _save_to_file(self, data: dict, filepath: PathLike, tag: str = ''):
filepath = Path(filepath)
yaml.Dumper.ignore_aliases = lambda *args: True
with filepath.open('w') as f:
yaml.dump(data, f, encoding='utf-8')
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
print(f'See file: {filepath}')
if __name__ == '__main__':
ce = ConfigExplainer()
ce.get_actions()
ce.get_entities()
ce.get_rules()
ce.get_observations()
ce.get_assets()
# ce.get_actions()
# ce.get_entities()
# ce.get_rules()
# ce.get_observations()
all_conf = ce.get_all()
ce.save_all()

View File

@ -18,6 +18,10 @@ class MarlFrameStack(gym.ObservationWrapper):
@dataclass
class RenderEntity:
"""
This class defines the interface to communicate with the Renderer. Name and pos are used to load an asset file
named name.png and place it at the given pos.
"""
name: str
pos: np.array
value: float = 1
@ -30,6 +34,10 @@ class RenderEntity:
@dataclass
class Floor:
"""
This class defines Entity like Floor-Objects, which do not come with the overhead.
Solely used for field-of-view calculation.
"""
@property
def encoding(self):