mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-06 01:21:36 +02:00
Merge branch 'main' into unit_testing
# Conflicts: # marl_factory_grid/modules/doors/groups.py # marl_factory_grid/utils/states.py
This commit is contained in:
@ -1,8 +1,7 @@
|
||||
import importlib
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import PurePath, Path
|
||||
from typing import Union, Dict, List, Iterable, Callable
|
||||
from typing import Union, Dict, List, Iterable, Callable, Any
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import ArrayLike
|
||||
@ -21,23 +20,22 @@ This file is used for:
|
||||
In this file they are defined to be used across the entire package.
|
||||
"""
|
||||
|
||||
LEVELS_DIR = 'levels' # for use in studies and experiments
|
||||
STEPS_START = 1 # Define where to the stepcount; which is the first step
|
||||
|
||||
LEVELS_DIR = 'levels' # for use in studies and experiments
|
||||
STEPS_START = 1 # Define where to the stepcount; which is the first step
|
||||
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
||||
IGNORED_DF_COLUMNS = ['Episode', 'Run', # For plotting, which values are ignored when loading monitor files
|
||||
'train_step', 'step', 'index', 'dirt_amount', 'dirty_pos_count', 'terminal_observation',
|
||||
'episode']
|
||||
|
||||
POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
|
||||
[[-1, 0], [0, 0], [1, 0]],
|
||||
[[-1, 1], [0, 1], [1, 1]]])
|
||||
[[-1, 0], [0, 0], [1, 0]],
|
||||
[[-1, 1], [0, 1], [1, 1]]])
|
||||
|
||||
MOVEMAP = defaultdict(lambda: (0, 0),
|
||||
MOVEMAP = defaultdict(lambda: (0, 0),
|
||||
{c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
|
||||
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
|
||||
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
|
||||
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
|
||||
c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
|
||||
c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
|
||||
c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
|
||||
}
|
||||
)
|
||||
|
||||
@ -80,7 +78,19 @@ class ObservationTranslator:
|
||||
self._this_named_obs_space = this_named_observation_space
|
||||
self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)
|
||||
|
||||
def translate_observation(self, agent_idx: int, obs):
|
||||
def translate_observation(self, agent_idx, obs) -> ArrayLike:
|
||||
"""
|
||||
Translates the observation of the given agent.
|
||||
|
||||
:param agent_idx: Agent identifier.
|
||||
:type agent_idx: int
|
||||
|
||||
:param obs: The observation to be translated.
|
||||
:type obs: ArrayLike
|
||||
|
||||
:return: The translated observation.
|
||||
:rtype: ArrayLike
|
||||
"""
|
||||
target_obs_space = self._per_agent_named_obs_space[agent_idx]
|
||||
translation = dict()
|
||||
for name, idxs in target_obs_space.items():
|
||||
@ -98,7 +108,10 @@ class ObservationTranslator:
|
||||
translation = dict(sorted(translation.items()))
|
||||
return np.concatenate(list(translation.values()), axis=-3)
|
||||
|
||||
def translate_observations(self, observations: List[ArrayLike]):
|
||||
def translate_observations(self, observations) -> List[ArrayLike]:
|
||||
"""
|
||||
Internal Usage
|
||||
"""
|
||||
return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]
|
||||
|
||||
def __call__(self, observations):
|
||||
@ -129,11 +142,26 @@ class ActionTranslator:
|
||||
self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]
|
||||
|
||||
def translate_action(self, agent_idx: int, action: int):
|
||||
"""
|
||||
Translates the observation of the given agent.
|
||||
|
||||
:param agent_idx: Agent identifier.
|
||||
:type agent_idx: int
|
||||
|
||||
:param action: The action to be translated.
|
||||
:type action: int
|
||||
|
||||
:return: The translated action.
|
||||
:rtype: ArrayLike
|
||||
"""
|
||||
named_action = self._per_agent_idx_actions[agent_idx][action]
|
||||
translated_action = self._target_named_action_space[named_action]
|
||||
return translated_action
|
||||
|
||||
def translate_actions(self, actions: List[int]):
|
||||
"""
|
||||
Intenal Usage
|
||||
"""
|
||||
return [self.translate_action(idx, action) for idx, action in enumerate(actions)]
|
||||
|
||||
def __call__(self, actions):
|
||||
@ -179,6 +207,13 @@ def one_hot_level(level, symbol: str):
|
||||
|
||||
|
||||
def is_move(action_name: str):
|
||||
"""
|
||||
Check if the given action name corresponds to a movement action.
|
||||
|
||||
:param action_name: The name of the action to check.
|
||||
:type action_name: str
|
||||
:return: True if the action is a movement action, False otherwise.
|
||||
"""
|
||||
return action_name in MOVEMAP.keys()
|
||||
|
||||
|
||||
@ -208,7 +243,18 @@ def asset_str(agent):
|
||||
|
||||
|
||||
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
"""Locate an object by name or dotted path, importing as necessary."""
|
||||
"""
|
||||
Locate an object by name or dotted path.
|
||||
|
||||
:param class_name: The class name to be imported
|
||||
:type class_name: str
|
||||
|
||||
:param folder_path: The path to the module containing the class.
|
||||
:type folder_path: Union[str, PurePath]
|
||||
|
||||
:return: The imported module class.
|
||||
:raises AttributeError: If the specified class is not found in the provided folder path.
|
||||
"""
|
||||
import sys
|
||||
sys.path.append("../../environment")
|
||||
folder_path = Path(folder_path).resolve()
|
||||
@ -220,15 +266,15 @@ def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
|
||||
for module_path in module_paths:
|
||||
module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
|
||||
mod = importlib.import_module('.'.join(module_parts))
|
||||
all_found_modules.extend([x for x in dir(mod) if (not(x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
||||
all_found_modules.extend([x for x in dir(mod) if (not (x.startswith('__') or len(x) <= 2) and x.istitle())
|
||||
and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
|
||||
'TickResult', 'ActionResult', 'Action', 'Agent',
|
||||
'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
|
||||
'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any', 'Factory',
|
||||
'Move8']])
|
||||
try:
|
||||
model_class = mod.__getattribute__(class_name)
|
||||
return model_class
|
||||
module_class = mod.__getattribute__(class_name)
|
||||
return module_class
|
||||
except AttributeError:
|
||||
continue
|
||||
raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
|
||||
@ -244,9 +290,33 @@ def add_pos_name(name_str, bound_e):
|
||||
return name_str
|
||||
|
||||
|
||||
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> Any | None:
|
||||
"""
|
||||
Get the first element from an iterable that satisfies the specified condition.
|
||||
|
||||
:param iterable: The iterable to search.
|
||||
:type iterable: Iterable
|
||||
|
||||
:param filter_by: A function that filters elements, defaults to lambda _: True.
|
||||
:type filter_by: Callable[[Any], bool]
|
||||
|
||||
:return: The first element that satisfies the condition, or None if none is found.
|
||||
:rtype: Any
|
||||
"""
|
||||
return next((x for x in iterable if filter_by(x)), None)
|
||||
|
||||
|
||||
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True):
|
||||
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> int | None:
|
||||
"""
|
||||
Get the index of the first element from an iterable that satisfies the specified condition.
|
||||
|
||||
:param iterable: The iterable to search.
|
||||
:type iterable: Iterable
|
||||
|
||||
:param filter_by: A function that filters elements, defaults to lambda _: True.
|
||||
:type filter_by: Callable[[Any], bool]
|
||||
|
||||
:return: The index of the first element that satisfies the condition, or None if none is found.
|
||||
:rtype: Optional[int]
|
||||
"""
|
||||
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)
|
||||
|
@ -15,9 +15,24 @@ class LevelParser(object):
|
||||
|
||||
@property
|
||||
def pomdp_d(self):
|
||||
"""
|
||||
Internal Usage
|
||||
"""
|
||||
return self.pomdp_r * 2 + 1
|
||||
|
||||
def __init__(self, level_file_path: PathLike, entity_parse_dict: Dict[Entities, dict], pomdp_r=0):
|
||||
"""
|
||||
Parses a level file and creates the initial state of the environment.
|
||||
|
||||
:param level_file_path: Path to the level file.
|
||||
:type level_file_path: PathLike
|
||||
|
||||
:param entity_parse_dict: Dictionary specifying how to parse different entities.
|
||||
:type entity_parse_dict: Dict[Entities, dict]
|
||||
|
||||
:param pomdp_r: The POMDP radius. Defaults to 0.
|
||||
:type pomdp_r: int
|
||||
"""
|
||||
self.pomdp_r = pomdp_r
|
||||
self.e_p_dict = entity_parse_dict
|
||||
self._parsed_level = h.parse_level(Path(level_file_path))
|
||||
@ -25,14 +40,30 @@ class LevelParser(object):
|
||||
self.level_shape = level_array.shape
|
||||
self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
|
||||
|
||||
def get_coordinates_for_symbol(self, symbol, negate=False):
|
||||
def get_coordinates_for_symbol(self, symbol, negate=False) -> np.ndarray:
|
||||
"""
|
||||
Get the coordinates for a given symbol in the parsed level.
|
||||
|
||||
:param symbol: The symbol to search for.
|
||||
:param negate: If True, get coordinates not matching the symbol. Defaults to False.
|
||||
|
||||
:return: Array of coordinates.
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
level_array = h.one_hot_level(self._parsed_level, symbol)
|
||||
if negate:
|
||||
return np.argwhere(level_array != c.VALUE_OCCUPIED_CELL)
|
||||
else:
|
||||
return np.argwhere(level_array == c.VALUE_OCCUPIED_CELL)
|
||||
|
||||
def do_init(self):
|
||||
def do_init(self) -> Entities:
|
||||
"""
|
||||
Initialize the environment map state by creating entities such as Walls, Agents or Machines according to the
|
||||
entity parse dict.
|
||||
|
||||
:return: A dict of all parsed entities with their positions.
|
||||
:rtype: Entities
|
||||
"""
|
||||
# Global Entities
|
||||
list_of_all_positions = ([tuple(f) for f in self.get_coordinates_for_symbol(c.SYMBOL_WALL, negate=True)])
|
||||
entities = Entities(list_of_all_positions)
|
||||
|
@ -18,12 +18,24 @@ class OBSBuilder(object):
|
||||
|
||||
@property
|
||||
def pomdp_d(self):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
if self.pomdp_r:
|
||||
return (self.pomdp_r * 2) + 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
self.all_obs = dict()
|
||||
self.ray_caster = dict()
|
||||
|
||||
|
@ -7,6 +7,12 @@ from numba import njit
|
||||
|
||||
class RayCaster:
|
||||
def __init__(self, agent, pomdp_r, degs=360):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
self.agent = agent
|
||||
self.pomdp_r = pomdp_r
|
||||
self.n_rays = 100 # (self.pomdp_r + 1) * 8
|
||||
|
@ -33,6 +33,12 @@ class Renderer:
|
||||
lvl_padded_shape: Union[Tuple[int, int], None] = None,
|
||||
cell_size: int = 40, fps: int = 7, factor: float = 0.9,
|
||||
grid_lines: bool = True, view_radius: int = 2):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
# TODO: Customn_assets paths
|
||||
self.grid_h, self.grid_w = lvl_shape
|
||||
self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
|
||||
|
@ -3,13 +3,16 @@ from dataclasses import dataclass
|
||||
|
||||
from marl_factory_grid.environment.entity.object import Object
|
||||
|
||||
TYPE_VALUE = 'value'
|
||||
TYPE_VALUE = 'value'
|
||||
TYPE_REWARD = 'reward'
|
||||
TYPES = [TYPE_VALUE, TYPE_REWARD]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InfoObject:
|
||||
"""
|
||||
Data class representing information about an entity or the global environment.
|
||||
"""
|
||||
identifier: str
|
||||
val_type: str
|
||||
value: Union[float, int]
|
||||
@ -17,6 +20,16 @@ class InfoObject:
|
||||
|
||||
@dataclass
|
||||
class Result:
|
||||
"""
|
||||
A generic result class representing outcomes of operations or actions.
|
||||
|
||||
Attributes:
|
||||
- identifier: A unique identifier for the result.
|
||||
- validity: A boolean indicating whether the operation or action was successful.
|
||||
- reward: The reward associated with the result, if applicable.
|
||||
- value: The value associated with the result, if applicable.
|
||||
- entity: The entity associated with the result, if applicable.
|
||||
"""
|
||||
identifier: str
|
||||
validity: bool
|
||||
reward: Union[float, None] = None
|
||||
@ -24,6 +37,11 @@ class Result:
|
||||
entity: Object = None
|
||||
|
||||
def get_infos(self):
|
||||
"""
|
||||
Get information about the result.
|
||||
|
||||
:return: A list of InfoObject representing different types of information.
|
||||
"""
|
||||
n = self.entity.name if self.entity is not None else "Global"
|
||||
# Return multiple Info Dicts
|
||||
return [InfoObject(identifier=f'{n}_{self.identifier}',
|
||||
@ -38,16 +56,37 @@ class Result:
|
||||
return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid{reward}{value}{entity})'
|
||||
|
||||
|
||||
@dataclass
|
||||
class TickResult(Result):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionResult(Result):
|
||||
def __init__(self, *args, action_introduced_collision: bool = False, **kwargs):
|
||||
"""
|
||||
A specific Result class representing outcomes of actions.
|
||||
|
||||
:param action_introduced_collision: Wether the action did introduce a colision between agents or other entities.
|
||||
These need to be able to collide.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.action_introduced_collision = action_introduced_collision
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class DoneResult(Result):
|
||||
"""
|
||||
A specific Result class representing the completion of an action or operation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class State(Result):
|
||||
# TODO: change identifier to action/last_action
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class TickResult(Result):
|
||||
"""
|
||||
A specific Result class representing outcomes of tick operations.
|
||||
"""
|
||||
pass
|
||||
|
@ -1,5 +1,4 @@
|
||||
from itertools import islice
|
||||
from itertools import islice
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -15,6 +14,12 @@ from marl_factory_grid.utils.results import Result
|
||||
|
||||
class StepRules:
|
||||
def __init__(self, *args):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
if args:
|
||||
self.rules = list(args)
|
||||
else:
|
||||
@ -80,6 +85,12 @@ class Gamestate(object):
|
||||
return [y for x in self.entities for y in x if x.var_can_move]
|
||||
|
||||
def __init__(self, entities, agents_conf, rules: List[Rule], tests: [Test], lvl_shape, env_seed=69, verbose=False):
|
||||
"""
|
||||
TODO
|
||||
|
||||
|
||||
:return:
|
||||
"""
|
||||
self.lvl_shape = lvl_shape
|
||||
self.entities = entities
|
||||
self.curr_step = 0
|
||||
|
@ -2,38 +2,68 @@ import importlib
|
||||
import inspect
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||
|
||||
ACTION = 'Action'
|
||||
GENERAL = 'General'
|
||||
ENTITIES = 'Objects'
|
||||
ACTION = 'Action'
|
||||
GENERAL = 'General'
|
||||
ENTITIES = 'Objects'
|
||||
OBSERVATIONS = 'Observations'
|
||||
RULES = 'Rule'
|
||||
ASSETS = 'Assets'
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
|
||||
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
||||
RULES = 'Rule'
|
||||
TESTS = 'Tests'
|
||||
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move', 'Agent', 'GlobalPositions', 'Walls',
|
||||
'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]
|
||||
|
||||
|
||||
class ConfigExplainer:
|
||||
|
||||
def __init__(self, custom_path: Union[None, PathLike] = None):
|
||||
self.base_path = Path(__file__).parent.parent.resolve()
|
||||
self.custom_path = custom_path
|
||||
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, ASSETS]
|
||||
def __init__(self, custom_path: None | PathLike = None):
|
||||
"""
|
||||
This utility serves as a helper for debugging and exploring available modules and classes.
|
||||
Does not do anything unless told.
|
||||
The functions get_xxxxx() retrieves and returns the information and save_xxxxx() dumps them to disk.
|
||||
|
||||
def explain_module(self, class_to_explain):
|
||||
get_all() and save_all() helps geting a general overview.
|
||||
|
||||
When provided with a custom path, your own modules become available.
|
||||
|
||||
:param custom_path: Path to your custom module folder.
|
||||
"""
|
||||
self.base_path = Path(__file__).parent.parent.resolve()
|
||||
self.custom_path = Path(custom_path) if custom_path is not None else custom_path
|
||||
self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, TESTS]
|
||||
|
||||
@staticmethod
|
||||
def _explain_module(class_to_explain):
|
||||
"""
|
||||
INTERNAL USE ONLY
|
||||
"""
|
||||
parameters = inspect.signature(class_to_explain).parameters
|
||||
explained = {class_to_explain.__name__:
|
||||
{key: val.default for key, val in parameters.items() if key not in EXCLUDED}
|
||||
}
|
||||
return explained
|
||||
|
||||
def _get_by_identifier(self, identifier):
|
||||
"""
|
||||
INTERNAL USE ONLY
|
||||
"""
|
||||
entities_base_cls = locate_and_import_class(identifier, self.base_path)
|
||||
module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||
found_entities = self._load_and_compare(entities_base_cls, module_paths)
|
||||
if self.custom_path is not None:
|
||||
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
|
||||
and '__init__' not in x.name]
|
||||
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
|
||||
return found_entities
|
||||
|
||||
def _load_and_compare(self, compare_class, paths):
|
||||
"""
|
||||
INTERNAL USE ONLY
|
||||
"""
|
||||
conf = {}
|
||||
package_pos = next(idx for idx, x in enumerate(Path(__file__).resolve().parts) if x == 'marl_factory_grid')
|
||||
for module_path in paths:
|
||||
@ -44,40 +74,97 @@ class ConfigExplainer:
|
||||
mod = mods.__getattribute__(key)
|
||||
try:
|
||||
if issubclass(mod, compare_class) and mod != compare_class:
|
||||
conf.update(self.explain_module(mod))
|
||||
conf.update(self._explain_module(mod))
|
||||
except TypeError:
|
||||
pass
|
||||
return conf
|
||||
|
||||
def save_actions(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_actions.yml'):
|
||||
self._save_to_file(self.get_entities(), output_conf_file, ACTION)
|
||||
@staticmethod
|
||||
def _save_to_file(data: dict, filepath: PathLike, tag: str = ''):
|
||||
"""
|
||||
INTERNAL USE ONLY
|
||||
"""
|
||||
filepath = Path(filepath)
|
||||
yaml.Dumper.ignore_aliases = lambda *args: True
|
||||
with filepath.open('w') as f:
|
||||
yaml.dump(data, f, encoding='utf-8')
|
||||
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
|
||||
print(f'See file: {filepath}')
|
||||
|
||||
def get_actions(self):
|
||||
def get_actions(self) -> list[str]:
|
||||
"""
|
||||
Retrieve all actions from module folders.
|
||||
|
||||
:returns: A list of all available actions.
|
||||
"""
|
||||
actions = self._get_by_identifier(ACTION)
|
||||
assert all(not x for x in actions.values()), 'Please only provide Names, no Mappings.'
|
||||
actions = list(actions.keys())
|
||||
actions.extend([c.MOVE8, c.MOVE4])
|
||||
# TODO: Print to file!
|
||||
return actions
|
||||
|
||||
def save_entities(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_entities.yml'):
|
||||
self._save_to_file(self.get_entities(), output_conf_file, ENTITIES)
|
||||
def get_all(self) -> dict[str]:
|
||||
"""
|
||||
Retrieve all available configurations from module folders.
|
||||
|
||||
:returns: A dictionary of all available configurations.
|
||||
"""
|
||||
|
||||
config_dict = {
|
||||
'General': self.get_general_section(),
|
||||
'Agents': self.get_agent_section(),
|
||||
'Entities': self.get_entities(),
|
||||
'Rules': self.get_rules()
|
||||
}
|
||||
return config_dict
|
||||
|
||||
def get_entities(self):
|
||||
"""
|
||||
Retrieve all entities from module folders.
|
||||
|
||||
:returns: A list of all available entities.
|
||||
"""
|
||||
entities = self._get_by_identifier(ENTITIES)
|
||||
return entities
|
||||
|
||||
def save_rules(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_rules.yml'):
|
||||
self._save_to_file(self.get_entities(), output_conf_file, RULES)
|
||||
@staticmethod
|
||||
def get_general_section():
|
||||
"""
|
||||
Build the general section.
|
||||
|
||||
def get_rules(self):
|
||||
:returns: A list of all available entities.
|
||||
"""
|
||||
general = {'level_name': 'rooms', 'env_seed': 69, 'verbose': False,
|
||||
'pomdp_r': 3, 'individual_rewards': True, 'tests': False}
|
||||
return general
|
||||
|
||||
def get_agent_section(self):
|
||||
"""
|
||||
Build the Agent section and retrieve all available actions and observations from module folders.
|
||||
|
||||
:returns: Agent section.
|
||||
"""
|
||||
agents = dict(
|
||||
ExampleAgentName=dict(
|
||||
Actions=self.get_actions(),
|
||||
Observations=self.get_observations())),
|
||||
return agents
|
||||
|
||||
def get_rules(self) -> dict[str]:
|
||||
"""
|
||||
Retrieve all rules from module folders.
|
||||
|
||||
:returns: All available rules.
|
||||
"""
|
||||
rules = self._get_by_identifier(RULES)
|
||||
return rules
|
||||
|
||||
def get_assets(self):
|
||||
pass
|
||||
def get_observations(self) -> list[str]:
|
||||
"""
|
||||
Retrieve all agent observations from module folders.
|
||||
|
||||
def get_observations(self):
|
||||
:returns: A list of all available observations.
|
||||
"""
|
||||
names = [c.ALL, c.COMBINED, c.SELF, c.OTHERS, "Agent['ExampleAgentName']"]
|
||||
for key, val in self.get_entities().items():
|
||||
try:
|
||||
@ -95,45 +182,47 @@ class ConfigExplainer:
|
||||
names.extend(e)
|
||||
return names
|
||||
|
||||
def _get_by_identifier(self, identifier):
|
||||
entities_base_cls = locate_and_import_class(identifier, self.base_path)
|
||||
module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
|
||||
found_entities = self._load_and_compare(entities_base_cls, module_paths)
|
||||
if self.custom_path is not None:
|
||||
module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
|
||||
and '__init__' not in x.name]
|
||||
found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
|
||||
return found_entities
|
||||
def save_actions(self, output_conf_file: PathLike = Path('../../quickstart') / 'actions.yml'):
|
||||
"""
|
||||
Write all availale actions to a file.
|
||||
:param output_conf_file: File to write to. Defaults to ../../quickstart/actions.yml
|
||||
"""
|
||||
self._save_to_file(self.get_entities(), output_conf_file, ACTION)
|
||||
|
||||
def save_all(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained.yml'):
|
||||
def save_entities(self, output_conf_file: PathLike = Path('../../quickstart') / 'entities.yml'):
|
||||
"""
|
||||
Write all availale entities to a file.
|
||||
:param output_conf_file: File to write to. Defaults to ../../quickstart/entities.yml
|
||||
"""
|
||||
self._save_to_file(self.get_entities(), output_conf_file, ENTITIES)
|
||||
|
||||
def save_observations(self, output_conf_file: PathLike = Path('../../quickstart') / 'observations.yml'):
|
||||
"""
|
||||
Write all availale observations to a file.
|
||||
:param output_conf_file: File to write to. Defaults to ../../quickstart/observations.yml
|
||||
"""
|
||||
self._save_to_file(self.get_entities(), output_conf_file, OBSERVATIONS)
|
||||
|
||||
def save_rules(self, output_conf_file: PathLike = Path('../../quickstart') / 'rules.yml'):
|
||||
"""
|
||||
Write all availale rules to a file.
|
||||
:param output_conf_file: File to write to. Defaults to ../../quickstart/rules.yml
|
||||
"""
|
||||
self._save_to_file(self.get_entities(), output_conf_file, RULES)
|
||||
|
||||
def save_all(self, output_conf_file: PathLike = Path('../../quickstart') / 'all.yml'):
|
||||
"""
|
||||
Write all availale keywords to a file.
|
||||
:param output_conf_file: File to write to. Defaults to ../../quickstart/all.yml
|
||||
"""
|
||||
self._save_to_file(self.get_all(), output_conf_file, 'ALL')
|
||||
|
||||
def get_all(self):
|
||||
config_dict = {GENERAL: {'level_name': 'rooms', 'env_seed': 69, 'verbose': False,
|
||||
'pomdp_r': 3, 'individual_rewards': True},
|
||||
'Agents': dict(
|
||||
ExampleAgentName=dict(
|
||||
Actions=self.get_actions(),
|
||||
Observations=self.get_observations())),
|
||||
'Entities': self.get_entities(),
|
||||
'Rules': self.get_rules(),
|
||||
'Assets': self.get_assets()}
|
||||
return config_dict
|
||||
|
||||
def _save_to_file(self, data: dict, filepath: PathLike, tag: str = ''):
|
||||
filepath = Path(filepath)
|
||||
yaml.Dumper.ignore_aliases = lambda *args: True
|
||||
with filepath.open('w') as f:
|
||||
yaml.dump(data, f, encoding='utf-8')
|
||||
print(f'Example config {"for " + tag + " " if tag else " "}dumped')
|
||||
print(f'See file: {filepath}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ce = ConfigExplainer()
|
||||
ce.get_actions()
|
||||
ce.get_entities()
|
||||
ce.get_rules()
|
||||
ce.get_observations()
|
||||
ce.get_assets()
|
||||
# ce.get_actions()
|
||||
# ce.get_entities()
|
||||
# ce.get_rules()
|
||||
# ce.get_observations()
|
||||
all_conf = ce.get_all()
|
||||
ce.save_all()
|
||||
|
@ -18,6 +18,10 @@ class MarlFrameStack(gym.ObservationWrapper):
|
||||
|
||||
@dataclass
|
||||
class RenderEntity:
|
||||
"""
|
||||
This class defines the interface to communicate with the Renderer. Name and pos are used to load an asset file
|
||||
named name.png and place it at the given pos.
|
||||
"""
|
||||
name: str
|
||||
pos: np.array
|
||||
value: float = 1
|
||||
@ -30,6 +34,10 @@ class RenderEntity:
|
||||
|
||||
@dataclass
|
||||
class Floor:
|
||||
"""
|
||||
This class defines Entity like Floor-Objects, which do not come with the overhead.
|
||||
Solely used for field-of-view calculation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
|
Reference in New Issue
Block a user