mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 14:56:43 +02:00
293 lines
12 KiB
Python
293 lines
12 KiB
Python
import importlib
|
|
from collections import defaultdict
|
|
from pathlib import PurePath, Path
|
|
from typing import Union, Dict, List, Iterable, Callable, Any
|
|
|
|
import numpy as np
|
|
from numpy.typing import ArrayLike
|
|
|
|
from marl_factory_grid.environment import constants as c
|
|
|
|
"""
|
|
This file is used for:
|
|
1. string based definition
|
|
Use a class like `Constants`, to define attributes, which then reveal strings.
|
|
These can be used for naming convention along the environments as well as keys for mappings such as dicts etc.
|
|
When defining new envs, use class inheritance.
|
|
|
|
2. utility function definition
|
|
There are static utility functions which are not bound to a specific environment.
|
|
In this file they are defined to be used across the entire package.
|
|
"""
|
|
|
|
# For use in studies and experiments.
LEVELS_DIR = 'levels'
# Defines where the step count starts, i.e. the number of the first step.
STEPS_START = 1

# For plotting: columns that are ignored when monitor files are loaded.
IGNORED_DF_COLUMNS = [
    'Episode',
    'Run',
    'train_step',
    'step',
    'index',
    'dirt_amount',
    'dirty_pos_count',
    'terminal_observation',
    'episode',
]

# 3x3 grid of 2-element offsets covering a cell and its eight neighbours;
# the first component varies along the columns, the second along the rows.
POS_MASK_8 = np.array([[[first, second] for first in (-1, 0, 1)] for second in (-1, 0, 1)])

# NOTE(review): despite the name this mask holds six offsets (including two
# diagonal ones), not four -- confirm against the callers before changing it.
POS_MASK_4 = np.array([
    [0, -1],
    [-1, 0],
    [1, 0],
    [-1, 1],
    [0, 1],
    [1, 1],
])
|
|
|
|
# Maps a movement action name to the 2-tuple offset it applies.
# Any name that is not registered below resolves to (0, 0), i.e. no displacement.
MOVEMAP = defaultdict(lambda: (0, 0))
MOVEMAP.update({
    c.NORTH: (-1, 0),
    c.NORTHEAST: (-1, 1),
    c.EAST: (0, 1),
    c.SOUTHEAST: (1, 1),
    c.SOUTH: (1, 0),
    c.SOUTHWEST: (1, -1),
    c.WEST: (0, -1),
    c.NORTHWEST: (-1, -1),
})
|
|
|
|
|
|
class ObservationTranslator:

    def __init__(self, this_named_observation_space: Dict[str, dict],
                 *per_agent_named_obs_spaces: Dict[str, dict],
                 placeholder_fill_value: Union[int, str, None] = None):
        """
        This is a helper class, which converts agent observations from joined environments.
        For example, agents trained in different environments may expect different observations.
        This class translates from larger observation spaces to smaller ones.
        A string identifier based approach is used: every named observation maps to the channel
        indices it occupies.
        Currently, it is not possible to mix different obs shapes.

        :param this_named_observation_space: `Named observation space` of the joined environment.
        :param per_agent_named_obs_spaces: `Named observation space`, one for each agent. Overloaded.
        :param placeholder_fill_value: How channels are filled that an agent expects but the joined
                                       environment does not provide. 'normal'/'n' or 'uniform'/'u' draw
                                       random values, `None` fills with zeros. Integer values are not
                                       implemented yet.

        :raises ValueError: If a string other than 'normal', 'n', 'uniform' or 'u' is given.
        :raises NotImplementedError: If an integer fill value is given.
        """

        if isinstance(placeholder_fill_value, str):
            fill_mode = placeholder_fill_value.lower()
            if fill_mode in ['normal', 'n']:
                self.random_fill = np.random.normal
            elif fill_mode in ['uniform', 'u']:
                self.random_fill = np.random.uniform
            else:
                raise ValueError('Please choose between "uniform" or "normal" ("u", "n").')
        elif isinstance(placeholder_fill_value, int):
            raise NotImplementedError('"Future Work."')
        else:
            self.random_fill = None

        self._this_named_obs_space = this_named_observation_space
        self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)

    def translate_observation(self, agent_idx, obs) -> ArrayLike:
        """
        Translates the observation of the given agent.

        Channels are looked up by name; the channel axis is always the third-to-last axis, so any
        number of leading (batch) dimensions is supported.

        :param agent_idx: Agent identifier.
        :type agent_idx: int

        :param obs: The observation to be translated.
        :type obs: ArrayLike

        :return: The translated observation.
        :rtype: ArrayLike
        """
        # Accept any array-like (e.g. nested lists), not only ndarrays.
        obs = np.asarray(obs)
        target_obs_space = self._per_agent_named_obs_space[agent_idx]
        # Shape of a single placeholder channel that keeps all other dimensions of `obs`.
        placeholder_shape = obs.shape[:-3] + (1,) + obs.shape[-2:]

        translation = dict()
        for name, idxs in target_obs_space.items():
            if name in self._this_named_obs_space:
                for target_idx, this_idx in zip(idxs, self._this_named_obs_space[name]):
                    # Slice on the same axis that is used for the placeholders and the final
                    # concatenation, independent of the number of leading dimensions.
                    translation[target_idx] = np.take(obs, [this_idx], axis=-3)
            elif self.random_fill:
                for target_idx in idxs:
                    translation[target_idx] = self.random_fill(size=placeholder_shape)
            else:
                for target_idx in idxs:
                    translation[target_idx] = np.zeros(shape=placeholder_shape)

        # Channels are emitted in ascending order of their target index.
        ordered_channels = [translation[target_idx] for target_idx in sorted(translation)]
        return np.concatenate(ordered_channels, axis=-3)

    def translate_observations(self, observations) -> List[ArrayLike]:
        """
        Internal Usage

        Translates one observation per agent; the position in `observations` is used as agent index.
        """
        return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]

    def __call__(self, observations):
        return self.translate_observations(observations)
|
|
|
|
|
|
class ActionTranslator:

    def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
        """
        This is a helper class, which converts agent action spaces to a joined environment's action space.
        For example, agents trained in different environments may have different action spaces.
        This class translates from smaller individual agent action spaces to larger joined spaces.
        A string identifier based approach is used: actions are matched by their name.

        :param target_named_action_space: Joined `Named action space` for the current environment.
        :type target_named_action_space: Dict[str, int]

        :param per_agent_named_action_space: `Named action space`, one for each agent. Overloaded.
        :type per_agent_named_action_space: Dict[str, int]
        """

        self._target_named_action_space = target_named_action_space
        # Variadic arguments always arrive as a tuple, so no further conversion is required.
        self._per_agent_named_action_space = per_agent_named_action_space
        # Inverted per-agent mapping: action index -> action name.
        self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]

    def translate_action(self, agent_idx: int, action: int):
        """
        Translates the action of the given agent into the joined action space.

        :param agent_idx: Agent identifier.
        :type agent_idx: int

        :param action: The action to be translated.
        :type action: int

        :return: The translated action, i.e. the index of the equally named action in the target space.
        :rtype: int

        :raises KeyError: If the action index is unknown to the agent or the action name is missing in
                          the target action space.
        """
        named_action = self._per_agent_idx_actions[agent_idx][action]
        return self._target_named_action_space[named_action]

    def translate_actions(self, actions: List[int]):
        """
        Internal Usage

        Translates one action per agent; the position in `actions` is used as agent index.
        """
        return [self.translate_action(idx, action) for idx, action in enumerate(actions)]

    def __call__(self, actions):
        return self.translate_actions(actions)
|
|
|
|
|
|
# Utility functions
|
|
def parse_level(path):
    """
    Given the path to a string based `level` or `map` representation, this function reads the content.
    Strips surrounding whitespace from every row, checks for equal length of each row and returns a
    list of lists.

    :param path: Path to the `level` or `map` file on the hard drive. Both `str` and path objects are
                 accepted.
    :type path: Union[str, os.PathLike]

    :return: The read string representation of the `level` or `map`, one list of characters per row.
    :rtype: List[List[str]]

    :raises AssertionError: If the rows are not all of the same length.
    """
    # Wrapping in `Path` lets plain string paths work as well as path objects.
    with Path(path).open('r') as lvl:
        level = [list(row.strip()) for row in lvl]
    if len({len(row) for row in level}) > 1:
        raise AssertionError('Every row of the level string must be of equal length.')
    return level
|
|
|
|
|
|
def one_hot_level(level, symbol: str):
    """
    Build a binary `grid` from a string based level representation.

    Every cell whose character equals `symbol` is set to `c.VALUE_OCCUPIED_CELL`; all other cells
    stay zero. The function can therefore be used to filter a level for any symbol.

    :param level: String based level representation (list of lists, see function `parse_level`).
    :param symbol: The character to look for; it is converted with `str` before comparison.

    :return: Binary numpy array of dtype int8 with the same shape as the level.
    :rtype: np.typing._array_like.ArrayLike
    """
    grid = np.array(level)
    symbol_mask = grid == str(symbol)

    binary_grid = np.zeros(grid.shape, dtype=np.int8)
    binary_grid[symbol_mask] = c.VALUE_OCCUPIED_CELL
    return binary_grid
|
|
|
|
|
|
def is_move(action_name: str):
    """
    Tell whether an action name denotes a movement action.

    An action counts as movement if its name is registered as a key of `MOVEMAP`.

    :param action_name: The name of the action to check.
    :type action_name: str
    :return: True if the action is a movement action, False otherwise.
    """
    # A membership test does not trigger the default factory of the defaultdict.
    return action_name in MOVEMAP
|
|
|
|
def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
    """
    Locate a class by name in the python modules below a folder and import it.

    All `*.py` files below `folder_path` (except files with `__init__` in their name) are imported
    one after another as modules of the `marl_factory_grid` package until one of them provides an
    attribute called `class_name`.

    :param class_name: The class name to be imported
    :type class_name: str

    :param folder_path: The path to the folder containing the modules to search.
    :type folder_path: Union[str, PurePath]

    :return: The imported module class.
    :raises AttributeError: If the specified class is not found in the provided folder path. The
                            second exception argument lists candidate names found in the modules.
    :raises StopIteration: If the path of this file contains no part named `marl_factory_grid`.
    """
    import sys
    environment_path = "../../environment"
    # Register the search path only once instead of growing `sys.path` on every call.
    if environment_path not in sys.path:
        sys.path.append(environment_path)

    folder_path = Path(folder_path).resolve()
    module_paths = [x.resolve() for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]

    # Names that are never reported as candidates in the error message.
    ignored_names = ['Entity', 'NamedTuple', 'List', 'Rule', 'Union',
                     'TickResult', 'ActionResult', 'Action', 'Agent',
                     'RenderEntity', 'TemplateRule', 'Objects', 'PositionMixin',
                     'IsBoundMixin', 'EnvObject', 'EnvObjects', 'Dict', 'Any', 'Factory',
                     'Move8']
    all_found_modules = list()
    # Index of the package root within the absolute path of this file; module paths are cut at the
    # same index to obtain their dotted module name.
    package_pos = next(idx for idx, x in enumerate(Path(__file__).resolve().parts) if x == 'marl_factory_grid')
    for module_path in module_paths:
        # Drop only the trailing `.py` suffix of the file, never a `.py` inside a directory name.
        module_parts = module_path.with_suffix('').parts[package_pos:]
        mod = importlib.import_module('.'.join(module_parts))
        all_found_modules.extend([x for x in dir(mod) if (not (x.startswith('__') or len(x) <= 2) and x.istitle())
                                  and x not in ignored_names])
        try:
            return getattr(mod, class_name)
        except AttributeError:
            continue
    raise AttributeError(f'Class "{class_name}" was not found in "{folder_path.name}"', list(set(all_found_modules)))
|
|
|
|
|
|
def add_bound_name(name_str, bound_e):
    """
    Append the name of a bound entity in parentheses to the given name.

    :param name_str: The base name.
    :param bound_e: Object providing a `name` attribute.

    :return: A string of the form `<name_str>(<bound_e.name>)`.
    :rtype: str
    """
    return '{}({})'.format(name_str, bound_e.name)
|
|
|
|
|
|
def add_pos_name(name_str, bound_e):
    """
    Append the position of an entity to the given name, if the entity has one.

    :param name_str: The base name.
    :param bound_e: Object providing `var_has_position` and, if that is truthy, `pos`.

    :return: `<name_str>@<bound_e.pos>` if the entity has a position, otherwise `name_str` unchanged.
    """
    if not bound_e.var_has_position:
        return name_str
    return '{}@{}'.format(name_str, bound_e.pos)
|
|
|
|
|
|
def get_first(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> Any | None:
|
|
"""
|
|
Get the first element from an iterable that satisfies the specified condition.
|
|
|
|
:param iterable: The iterable to search.
|
|
:type iterable: Iterable
|
|
|
|
:param filter_by: A function that filters elements, defaults to lambda _: True.
|
|
:type filter_by: Callable[[Any], bool]
|
|
|
|
:return: The first element that satisfies the condition, or None if none is found.
|
|
:rtype: Any
|
|
"""
|
|
return next((x for x in iterable if filter_by(x)), None)
|
|
|
|
|
|
def get_first_index(iterable: Iterable, filter_by: Callable[[any], bool] = lambda _: True) -> int | None:
|
|
"""
|
|
Get the index of the first element from an iterable that satisfies the specified condition.
|
|
|
|
:param iterable: The iterable to search.
|
|
:type iterable: Iterable
|
|
|
|
:param filter_by: A function that filters elements, defaults to lambda _: True.
|
|
:type filter_by: Callable[[Any], bool]
|
|
|
|
:return: The index of the first element that satisfies the condition, or None if none is found.
|
|
:rtype: Optional[int]
|
|
"""
|
|
return next((idx for idx, x in enumerate(iterable) if filter_by(x)), None)
|