mirror of https://github.com/illiumst/marl-factory-grid.git (synced 2025-07-12 23:52:42 +02:00)
Commit: renaming
0 marl_factory_grid/utils/__init__.py Normal file
135 marl_factory_grid/utils/config_parser.py Normal file
@@ -0,0 +1,135 @@
from os import PathLike
from pathlib import Path
from typing import Union

import yaml

from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.utils.helpers import locate_and_import_class
from marl_factory_grid.environment import constants as c

DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'


class FactoryConfigParser(object):

    default_entites = []
    default_rules = ['MaxStepsReached', 'Collision']
    default_actions = [c.MOVE8, c.NOOP]
    default_observations = [c.WALLS, c.AGENTS]

    def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
        self.config_path = Path(config_path)
        self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
        self.config = yaml.safe_load(self.config_path.open())
        self.do_record = False

    def __getattr__(self, item):
        return self['General'][item]

    def _get_sub_list(self, primary_key: str, sub_key: str):
        return [{key: [s for k, v in val.items() if k == sub_key for s in v] for key, val in x.items()
                 } for x in self.config[primary_key]]

    @property
    def agent_actions(self):
        return self._get_sub_list('Agents', "Actions")

    @property
    def agent_observations(self):
        return self._get_sub_list('Agents', "Observations")

    @property
    def rules(self):
        return self.config['Rules']

    @property
    def agents(self):
        return self.config['Agents']

    @property
    def entities(self):
        return self.config['Entities']

    def __repr__(self):
        return str(self.config)

    def __getitem__(self, item):
        return self.config[item]

    def load_entities(self):
        # entites = Entities()
        entity_classes = dict()
        entities = []
        if c.DEFAULTS in self.entities:
            entities.extend(self.default_entites)
        entities.extend(x for x in self.entities if x != c.DEFAULTS)

        for entity in entities:
            try:
                folder_path = MODULE_PATH if entity not in self.default_entites else DEFAULT_PATH
                folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
                entity_class = locate_and_import_class(entity, folder_path)
            except AttributeError:
                folder_path = self.custom_modules_path
                entity_class = locate_and_import_class(entity, folder_path)
            entity_kwargs = self.entities.get(entity, {})
            entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
            entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
        return entity_classes

    def load_agents(self, size, free_tiles):
        agents = Agents(size)
        base_env_actions = self.default_actions.copy() + [c.MOVE4]
        for name in self.agents:
            # Actions
            actions = list()
            if c.DEFAULTS in self.agents[name]['Actions']:
                actions.extend(self.default_actions)
            actions.extend(x for x in self.agents[name]['Actions'] if x != c.DEFAULTS)
            parsed_actions = list()
            for action in actions:
                folder_path = MODULE_PATH if action not in base_env_actions else DEFAULT_PATH
                folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
                try:
                    class_or_classes = locate_and_import_class(action, folder_path)
                except AttributeError:
                    class_or_classes = locate_and_import_class(action, self.custom_modules_path)
                try:
                    parsed_actions.extend(class_or_classes)
                except TypeError:
                    parsed_actions.append(class_or_classes)

            parsed_actions = [x() for x in parsed_actions]

            # Observation
            observations = list()
            if c.DEFAULTS in self.agents[name]['Observations']:
                observations.extend(self.default_observations)
            observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
            agent = Agent(parsed_actions, observations, free_tiles.pop(), str_ident=name)
            agents.add_item(agent)
        return agents

    def load_rules(self):
        # entites = Entities()
        rules_classes = dict()
        rules = []
        if c.DEFAULTS in self.rules:
            for rule in self.default_rules:
                if rule not in rules:
                    rules.append(rule)
        rules.extend(x for x in self.rules if x != c.DEFAULTS)

        for rule in rules:
            folder_path = MODULE_PATH if rule not in self.default_rules else DEFAULT_PATH
            folder_path = (Path(__file__) / '..' / '..' / '..' / folder_path)
            try:
                rule_class = locate_and_import_class(rule, folder_path)
            except AttributeError:
                rule_class = locate_and_import_class(rule, self.custom_modules_path)
            rule_kwargs = self.rules.get(rule, {})
            rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
        return rules_classes
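A minimal usage sketch, not taken from the repository: judging by the accessors above, the parser expects a YAML file with `General`, `Agents`, `Entities` and `Rules` sections. The agent name and the `Defaults` placeholder below are assumptions for illustration only.

from pathlib import Path
from marl_factory_grid.utils.config_parser import FactoryConfigParser

example_conf = """
General:
  level_name: rooms
  env_seed: 69
Agents:
  Wolfgang:                 # hypothetical agent name
    Actions:
      - Defaults            # assumed to expand to default_actions
    Observations:
      - Defaults            # assumed to expand to default_observations
Entities:
  Defaults: {}
Rules:
  Defaults: {}
"""

Path('example_config.yaml').write_text(example_conf)
parser = FactoryConfigParser('example_config.yaml')
print(parser.agents)        # raw 'Agents' section of the YAML
print(parser.level_name)    # resolved via __getattr__ -> config['General']['level_name']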
239 marl_factory_grid/utils/helpers.py Normal file
@@ -0,0 +1,239 @@
import importlib

from collections import defaultdict
from pathlib import PurePath, Path
from typing import Union, Dict, List

import numpy as np
from numpy.typing import ArrayLike

from marl_factory_grid.environment import constants as c

"""
This file is used for:
    1. string based definition
       Use a class like `Constants` to define attributes which then reveal strings.
       These can be used as naming conventions along the environments as well as keys for mappings such as dicts etc.
       When defining new envs, use class inheritance.

    2. utility function definition
       There are static utility functions which are not bound to a specific environment.
       They are defined in this file to be used across the entire package.
"""

LEVELS_DIR = 'modules/levels'  # for use in studies and experiments
STEPS_START = 1  # Defines where the step count starts, i.e. which is the first step

# Not used anymore? Clean!
# TO_BE_AVERAGED = ['dirt_amount', 'dirty_tiles']
IGNORED_DF_COLUMNS = ['Episode', 'Run',  # For plotting, which values are ignored when loading monitor files
                      'train_step', 'step', 'index', 'dirt_amount', 'dirty_tile_count', 'terminal_observation',
                      'episode']

POS_MASK = np.asarray([[[-1, -1], [0, -1], [1, -1]],
                       [[-1, 0], [0, 0], [1, 0]],
                       [[-1, 1], [0, 1], [1, 1]]])

MOVEMAP = defaultdict(lambda: (0, 0),
                      {c.NORTH: (-1, 0), c.NORTHEAST: (-1, 1),
                       c.EAST: (0, 1), c.SOUTHEAST: (1, 1),
                       c.SOUTH: (1, 0), c.SOUTHWEST: (1, -1),
                       c.WEST: (0, -1), c.NORTHWEST: (-1, -1)
                       }
                      )


class ObservationTranslator:

    def __init__(self, this_named_observation_space: Dict[str, dict],
                 *per_agent_named_obs_spaces: Dict[str, dict],
                 placeholder_fill_value: Union[int, str, None] = None):
        """
        This is a helper class which converts agent observations from joined environments.
        For example, agents trained in different environments may expect different observations.
        This class translates from larger observation spaces to smaller ones.
        A string-identifier based approach is used.
        Currently, it is not possible to mix different obs shapes.

        :param this_named_observation_space: `Named observation space` of the joined environment.
        :type this_named_observation_space: Dict[str, dict]

        :param per_agent_named_obs_spaces: `Named observation space`, one for each agent. Overloaded.
        :type per_agent_named_obs_spaces: Dict[str, dict]

        :param placeholder_fill_value: Currently not fully implemented!!!
        :type placeholder_fill_value: Union[int, str, None]
        """

        if isinstance(placeholder_fill_value, str):
            if placeholder_fill_value.lower() in ['normal', 'n']:
                self.random_fill = np.random.normal
            elif placeholder_fill_value.lower() in ['uniform', 'u']:
                self.random_fill = np.random.uniform
            else:
                raise ValueError('Please choose between "uniform" or "normal" ("u", "n").')
        elif isinstance(placeholder_fill_value, int):
            raise NotImplementedError('"Future Work."')
        else:
            self.random_fill = None

        self._this_named_obs_space = this_named_observation_space
        self._per_agent_named_obs_space = list(per_agent_named_obs_spaces)

    def translate_observation(self, agent_idx: int, obs):
        target_obs_space = self._per_agent_named_obs_space[agent_idx]
        translation = dict()
        for name, idxs in target_obs_space.items():
            if name in self._this_named_obs_space:
                for target_idx, this_idx in zip(idxs, self._this_named_obs_space[name]):
                    taken_slice = np.take(obs, [this_idx], axis=1 if obs.ndim == 4 else 0)
                    translation[target_idx] = taken_slice
            elif random_fill := self.random_fill:
                for target_idx in idxs:
                    translation[target_idx] = random_fill(size=obs.shape[:-3] + (1,) + obs.shape[-2:])
            else:
                for target_idx in idxs:
                    translation[target_idx] = np.zeros(shape=(obs.shape[:-3] + (1,) + obs.shape[-2:]))

        translation = dict(sorted(translation.items()))
        return np.concatenate(list(translation.values()), axis=-3)

    def translate_observations(self, observations: List[ArrayLike]):
        return [self.translate_observation(idx, observation) for idx, observation in enumerate(observations)]

    def __call__(self, observations):
        return self.translate_observations(observations)


class ActionTranslator:

    def __init__(self, target_named_action_space: Dict[str, int], *per_agent_named_action_space: Dict[str, int]):
        """
        This is a helper class which converts agent action spaces to a joined environment's action space.
        For example, agents trained in different environments may have different action spaces.
        This class translates from smaller individual agent action spaces to larger joined spaces.
        A string-identifier based approach is used.

        :param target_named_action_space: Joined `Named action space` for the current environment.
        :type target_named_action_space: Dict[str, dict]

        :param per_agent_named_action_space: `Named action space`, one for each agent. Overloaded.
        :type per_agent_named_action_space: Dict[str, dict]
        """

        self._target_named_action_space = target_named_action_space
        if isinstance(per_agent_named_action_space, (list, tuple)):
            self._per_agent_named_action_space = per_agent_named_action_space
        else:
            self._per_agent_named_action_space = list(per_agent_named_action_space)
        self._per_agent_idx_actions = [{idx: a for a, idx in x.items()} for x in self._per_agent_named_action_space]

    def translate_action(self, agent_idx: int, action: int):
        named_action = self._per_agent_idx_actions[agent_idx][action]
        translated_action = self._target_named_action_space[named_action]
        return translated_action

    def translate_actions(self, actions: List[int]):
        return [self.translate_action(idx, action) for idx, action in enumerate(actions)]

    def __call__(self, actions):
        return self.translate_actions(actions)


# Utility functions
def parse_level(path):
    """
    Given the path to a string based `level` or `map` representation, this function reads the content.
    Cleans `space`, checks for equal length of each row and returns a list of lists.

    :param path: Path to the `level` or `map` file on hard drive.
    :type path: os.PathLike

    :return: The read string representation of the `level` or `map`
    :rtype: List[List[str]]
    """
    with path.open('r') as lvl:
        level = list(map(lambda x: list(x.strip()), lvl.readlines()))
    if len(set([len(line) for line in level])) > 1:
        raise AssertionError('Every row of the level string must be of equal length.')
    return level


def one_hot_level(level, symbol: str):
    """
    Given a string based level representation (list of lists, see function `parse_level`), this function creates a
    binary numpy array or `grid`. Grid values that equal `symbol` become `Constants.OCCUPIED_CELL` value.
    Can be used to filter for any symbol.

    :param level: String based level representation (list of lists, see function `parse_level`).
    :param symbol: The character to filter for, e.g. the wall symbol.

    :return: Binary numpy array
    :rtype: np.typing._array_like.ArrayLike
    """

    grid = np.array(level)
    binary_grid = np.zeros(grid.shape, dtype=np.int8)
    binary_grid[grid == symbol] = c.VALUE_OCCUPIED_CELL
    return binary_grid


def is_move(action_name: str):
    return action_name in MOVEMAP.keys()


def asset_str(agent):
    """
    FIXME @ romue
    """
    # What does this abomination do?
    # if any([x is None for x in [cls._slices[j] for j in agent.collisions]]):
    #     print('error')
    if step_result := agent.step_result:
        action = step_result['action_name']
        valid = step_result['action_valid']
        col_names = [x.name for x in step_result['collisions']]
        if any(c.AGENT in name for name in col_names):
            return 'agent_collision', 'blank'
        elif not valid or c.LEVEL in col_names or c.AGENT in col_names:
            return c.AGENT, 'invalid'
        elif valid and not is_move(action):
            return c.AGENT, 'valid'
        elif valid and is_move(action):
            return c.AGENT, 'move'
        else:
            return c.AGENT, 'idle'
    else:
        return c.AGENT, 'idle'


def locate_and_import_class(class_name, folder_path: Union[str, PurePath] = ''):
    """Locate an object by name or dotted path, importing as necessary."""
    import sys
    sys.path.append("../../environment")
    folder_path = Path(folder_path).resolve()
    module_paths = [x.resolve() for x in folder_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
    # possible_package_path = folder_path / '__init__.py'
    # package = str(possible_package_path) if possible_package_path.exists() else None
    all_found_modules = list()
    package_pos = next(idx for idx, x in enumerate(Path(__file__).resolve().parts) if x == 'marl_factory_grid')
    for module_path in module_paths:
        module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
        mod = importlib.import_module('.'.join(module_parts))
        all_found_modules.extend([x for x in dir(mod) if not (x.startswith('__') or len(x) < 2 or x.isupper())
                                  and x not in ['Entity', 'NamedTuple', 'List', 'Rule', 'Union', 'random', 'Floor',
                                                'TickResult', 'ActionResult', 'Action', 'Agent', 'deque',
                                                'BoundEntityMixin', 'RenderEntity', 'TemplateRule', 'defaultdict',
                                                'is_move', 'Objects', 'PositionMixin', 'IsBoundMixin', 'EnvObject',
                                                'EnvObjects', 'Dict', 'locate_and_import_class', 'yaml', 'Any',
                                                'inspect']])
        try:
            model_class = mod.__getattribute__(class_name)
            return model_class
        except AttributeError:
            continue
    raise AttributeError(f'Class "{class_name}" was not found!!!\n'
                         f'Check the {folder_path.name} name.\n'
                         f'Possible Options are:\n{set(all_found_modules)}')
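Two small sketches of how these helpers are used, with illustrative values. The wall character and the direction names are assumptions; in the package they come from the constants module.

import numpy as np
from marl_factory_grid.utils.helpers import one_hot_level, ActionTranslator

# one_hot_level: mark every cell that carries the given symbol ('#' assumed to be c.SYMBOL_WALL).
level = [list('#####'),
         list('#   #'),
         list('#####')]
walls = one_hot_level(level, '#')
print(walls.sum())                        # 12 cells marked (assuming VALUE_OCCUPIED_CELL == 1)

# ActionTranslator: map a per-agent action index to the joined action space.
joined_space = {'Noop': 0, 'North': 1, 'South': 2}
agent_space = {'Noop': 0, 'South': 1}
translator = ActionTranslator(joined_space, agent_space)
print(translator.translate_action(0, 1))  # agent action 1 ('South') -> joined index 2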
55 marl_factory_grid/utils/level_parser.py Normal file
@@ -0,0 +1,55 @@
from os import PathLike
from pathlib import Path
from typing import Dict

import numpy as np

from marl_factory_grid.environment.groups.global_entities import Entities
from marl_factory_grid.environment.groups.wall_n_floors import Walls, Floors
from marl_factory_grid.utils import helpers as h
from marl_factory_grid.environment import constants as c


class LevelParser(object):

    @property
    def pomdp_d(self):
        return self.pomdp_r * 2 + 1

    def __init__(self, level_file_path: PathLike, entity_parse_dict: Dict[Entities, dict], pomdp_r=0):
        self.pomdp_r = pomdp_r
        self.e_p_dict = entity_parse_dict
        self._parsed_level = h.parse_level(Path(level_file_path))
        level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
        self.level_shape = level_array.shape
        self.size = self.pomdp_r**2 if self.pomdp_r else np.prod(self.level_shape)

    def do_init(self):
        entities = Entities()
        # Walls
        level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)

        walls = Walls.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL), self.size)
        entities.add_items({c.WALL: walls})

        # Floor
        floor = Floors.from_coordinates(np.argwhere(level_array == c.VALUE_FREE_CELL), self.size)
        entities.add_items({c.FLOOR: floor})

        # All other
        for es_name in self.e_p_dict:
            e_class, e_kwargs = self.e_p_dict[es_name]['class'], self.e_p_dict[es_name]['kwargs']

            if hasattr(e_class, 'symbol'):
                level_array = h.one_hot_level(self._parsed_level, symbol=e_class.symbol)
                if np.any(level_array):
                    e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
                                                 entities[c.FLOOR], self.size, entity_kwargs=e_kwargs
                                                 )
                else:
                    raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
                                     f'Check your level file!')
            else:
                e = e_class(self.size, **e_kwargs)
            entities.add_items({e.name: e})
        return entities
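A usage sketch, again illustrative: a tiny level file is written to disk and parsed. The wall character '#' is an assumption for c.SYMBOL_WALL; with an empty entity_parse_dict only Walls and Floors should be instantiated.

from pathlib import Path
from marl_factory_grid.utils.level_parser import LevelParser

Path('tiny.txt').write_text('#####\n'
                            '#   #\n'
                            '#####\n')
parser = LevelParser('tiny.txt', entity_parse_dict={}, pomdp_r=0)
print(parser.level_shape)     # (3, 5)
entities = parser.do_init()   # Entities collection holding Walls and Floors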
317 marl_factory_grid/utils/observation_builder.py Normal file
@@ -0,0 +1,317 @@
import math
from collections import defaultdict
from itertools import product
from typing import Dict, List

import numpy as np
from numba import njit

from marl_factory_grid.environment.groups.utils import Combined
from marl_factory_grid.utils.states import Gamestate

from marl_factory_grid.environment import constants as c


class OBSBuilder(object):

    default_obs = [c.WALLS, c.OTHERS]

    @property
    def pomdp_d(self):
        if self.pomdp_r:
            return (self.pomdp_r * 2) + 1
        else:
            return 0

    def __init__(self, level_shape: np.size, state: Gamestate, pomdp_r: int):
        self.all_obs = dict()
        self.light_blockers = defaultdict(lambda: False)
        self.positional = defaultdict(lambda: False)
        self.non_positional = defaultdict(lambda: False)
        self.ray_caster = dict()

        self.level_shape = level_shape
        self.pomdp_r = pomdp_r
        self.obs_shape = (self.pomdp_d, self.pomdp_d) if self.pomdp_r else self.level_shape
        self.size = np.prod(self.obs_shape)

        self.obs_layers = dict()

        self.build_structured_obs_block(state)
        self.curr_lightmaps = dict()

    def build_structured_obs_block(self, state):
        self.all_obs[c.PLACEHOLDER] = np.full(self.obs_shape, 0, dtype=float)
        self.all_obs.update({key: obj for key, obj in state.entities.obs_pairs})

    def observation_space(self, state):
        from gymnasium.spaces import Tuple, Box
        obsn = self.refresh_and_build_for_all(state)
        if len(state[c.AGENT]) == 1:
            space = Box(low=0, high=1, shape=next(x for x in obsn.values()).shape, dtype=np.float32)
        else:
            space = Tuple([Box(low=0, high=1, shape=obs.shape, dtype=np.float32) for obs in obsn.values()])
        return space

    def named_observation_space(self, state):
        return self.refresh_and_build_for_all(state)

    def refresh_and_build_for_all(self, state) -> (dict, dict):
        self.build_structured_obs_block(state)
        info = {}
        return {agent.name: self.build_for_agent(agent, state)[0] for agent in state[c.AGENT]}, info

    def refresh_and_build_named_for_all(self, state) -> Dict[str, Dict[str, np.ndarray]]:
        self.build_structured_obs_block(state)
        named_obs_dict = {}
        for agent in state[c.AGENT]:
            obs, names = self.build_for_agent(agent, state)
            named_obs_dict[agent.name] = {'observation': obs, 'names': names}
        return named_obs_dict

    def build_for_agent(self, agent, state) -> (List[str], np.ndarray):
        try:
            agent_want_obs = self.obs_layers[agent.name]
        except KeyError:
            self._sort_and_name_observation_conf(agent)
            agent_want_obs = self.obs_layers[agent.name]

        # Handle in-grid observations aka visible observations
        visible_entities = self.ray_caster[agent.name].visible_entities(state.entities)
        pre_sort_obs = defaultdict(lambda: np.zeros((self.pomdp_d, self.pomdp_d)))
        for e in set(visible_entities):
            x, y = (e.x - agent.x) + self.pomdp_r, (e.y - agent.y) + self.pomdp_r
            try:
                pre_sort_obs[e.obs_tag][x, y] += e.encoding
            except IndexError:
                # Seemed to be visible, but is out of range
                pass

        pre_sort_obs = dict(pre_sort_obs)
        obs = np.zeros((len(agent_want_obs), self.pomdp_d, self.pomdp_d))

        for idx, l_name in enumerate(agent_want_obs):
            try:
                obs[idx] = pre_sort_obs[l_name]
            except KeyError:
                if c.COMBINED in l_name:
                    if combined := [pre_sort_obs[x] for x in self.all_obs[f'{c.COMBINED}({agent.name})'].names
                                    if x in pre_sort_obs]:
                        obs[idx] = np.sum(combined, axis=0)
                elif l_name == c.PLACEHOLDER:
                    obs[idx] = self.all_obs[c.PLACEHOLDER]
                else:
                    try:
                        e = self.all_obs[l_name]
                    except KeyError:
                        try:
                            e = self.all_obs[f'{l_name}({agent.name})']
                        except KeyError:
                            try:
                                e = next(x for x in self.all_obs if l_name in x and agent.name in x)
                            except StopIteration:
                                raise KeyError(
                                    f'Check typing!\n{l_name} could not be found in:\n{dict(self.all_obs).keys()}')

                    try:
                        positional = e.has_position
                    except AttributeError:
                        positional = False
                    if positional:
                        # Seems to be not visible, so just skip it
                        # obs[idx] = np.zeros((self.pomdp_d, self.pomdp_d))
                        # All good
                        pass
                    else:
                        try:
                            v = e.encodings
                        except AttributeError:
                            try:
                                v = e.encoding
                            except AttributeError:
                                raise AttributeError(f'This env. expects Entity-Classes to report their "encoding"')
                        try:
                            np.put(obs[idx], range(len(v)), v, mode='raise')
                        except TypeError:
                            np.put(obs[idx], 0, v, mode='raise')
                        except IndexError:
                            raise ValueError(f'Max(obs.size) for {e.name}: {obs[idx].size}, but was: {len(v)}.')

        try:
            self.curr_lightmaps[agent.name] = pre_sort_obs[c.FLOORS].astype(bool)
        except KeyError:
            print()
        return obs, self.obs_layers[agent.name]

    def _sort_and_name_observation_conf(self, agent):
        self.ray_caster[agent.name] = RayCaster(agent, self.pomdp_r)
        obs_layers = []

        for obs_str in agent.observations:
            if isinstance(obs_str, dict):
                obs_str, vals = next(obs_str.items().__iter__())
            else:
                vals = None
            if obs_str == c.SELF:
                obs_layers.append(agent.name)
            elif obs_str == c.DEFAULTS:
                obs_layers.extend(self.default_obs)
            elif obs_str == c.COMBINED:
                if isinstance(vals, str):
                    vals = [vals]
                names = list()
                for val in vals:
                    if val == c.SELF:
                        names.append(agent.name)
                    elif val == c.OTHERS:
                        names.extend([x.name for x in agent.collection if x.name != agent.name])
                    else:
                        names.append(val)
                combined = Combined(names, self.pomdp_r, identifier=agent.name)
                self.all_obs[combined.name] = combined
                obs_layers.append(combined.name)
            elif obs_str == c.OTHERS:
                obs_layers.extend([x for x in self.all_obs if x != agent.name and x.startswith(f'{c.AGENT}[')])
            elif obs_str == c.AGENTS:
                obs_layers.extend([x for x in self.all_obs if x.startswith(f'{c.AGENT}[')])
            else:
                obs_layers.append(obs_str)
        self.obs_layers[agent.name] = obs_layers
        self.curr_lightmaps[agent.name] = np.zeros((self.pomdp_d or self.level_shape[0],
                                                    self.pomdp_d or self.level_shape[1]
                                                    ))


class RayCaster:
    def __init__(self, agent, pomdp_r, degs=360):
        self.agent = agent
        self.pomdp_r = pomdp_r
        self.n_rays = 100  # (self.pomdp_r + 1) * 8
        self.degs = degs
        self.ray_targets = self.build_ray_targets()
        self.obs_shape_cube = np.array([self.pomdp_r, self.pomdp_r])

    def __repr__(self):
        return f'{self.__class__.__name__}({self.agent.name})'

    def build_ray_targets(self):
        north = np.array([0, -1]) * self.pomdp_r
        thetas = [np.deg2rad(deg) for deg in np.linspace(-self.degs // 2, self.degs // 2, self.n_rays)[::-1]]
        rot_M = [
            [[math.cos(theta), -math.sin(theta)],
             [math.sin(theta), math.cos(theta)]] for theta in thetas
        ]
        rot_M = np.stack(rot_M, 0)
        rot_M = np.unique(np.round(rot_M @ north), axis=0)
        return rot_M.astype(int)

    def ray_block_cache(self, cache_dict, key, callback, ents):
        if key not in cache_dict:
            cache_dict[key] = callback()
        return cache_dict[key]

    def visible_entities(self, entities):
        visible = list()
        cache_blocking = {}

        for ray in self.get_rays():
            rx, ry = ray[0]
            for x, y in ray:
                cx, cy = x - rx, y - ry

                entities_hit = entities.pos_dict[(x, y)]
                hits = self.ray_block_cache(cache_blocking,
                                            (x, y),
                                            lambda: any(e.is_blocking_light for e in entities_hit),
                                            entities)

                try:
                    d = next(x for x in entities_hit if 'Door' in x.name)
                    if d.pos in entities.pos_dict.keys():
                        if d.is_closed and not entities.pos_dict[d.pos]:
                            print()
                except StopIteration:
                    pass

                diag_hits = any([
                    self.ray_block_cache(
                        cache_blocking,
                        key,
                        # lambda: all(False for e in entities.pos_dict[key] if not e.is_blocking_light),
                        lambda: any(e.is_blocking_light for e in entities.pos_dict[key]),
                        entities)
                    for key in ((x, y - cy), (x - cx, y))
                ]) if (cx != 0 and cy != 0) else False

                visible += entities_hit if not diag_hits else []
                if hits or diag_hits:
                    break
                rx, ry = x, y
        return visible

    def get_rays(self):
        a_pos = self.agent.pos
        outline = self.ray_targets + a_pos
        return self.bresenham_loop(a_pos, outline)

    # todo do this once and cache the points!
    def get_fov_outline(self) -> np.ndarray:
        return self.ray_targets + self.agent.pos

    def get_square_outline(self):
        agent = self.agent
        x_coords = range(agent.x - self.pomdp_r, agent.x + self.pomdp_r + 1)
        y_coords = range(agent.y - self.pomdp_r, agent.y + self.pomdp_r + 1)
        outline = list(product(x_coords, [agent.y - self.pomdp_r, agent.y + self.pomdp_r])) \
                  + list(product([agent.x - self.pomdp_r, agent.x + self.pomdp_r], y_coords))
        return outline

    @staticmethod
    @njit
    def bresenham_loop(a_pos, points):
        results = []
        for end in points:
            x1, y1 = a_pos
            x2, y2 = end
            dx = x2 - x1
            dy = y2 - y1

            # Determine how steep the line is
            is_steep = abs(dy) > abs(dx)

            # Rotate line
            if is_steep:
                x1, y1 = y1, x1
                x2, y2 = y2, x2

            # Swap start and end points if necessary and store swap state
            swapped = False
            if x1 > x2:
                x1, x2 = x2, x1
                y1, y2 = y2, y1
                swapped = True

            # Recalculate differentials
            dx = x2 - x1
            dy = y2 - y1

            # Calculate error
            error = int(dx / 2.0)
            ystep = 1 if y1 < y2 else -1

            # Iterate over bounding box generating points between start and end
            y = y1
            points = []
            for x in range(int(x1), int(x2) + 1):
                coord = [y, x] if is_steep else [x, y]
                points.append(coord)
                error -= abs(dy)
                if error < 0:
                    y += ystep
                    error += dx

            # Reverse the list if the coordinates were swapped
            if swapped:
                points.reverse()
            results.append(points)
        return results
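The ray caster builds a fixed set of target points around the agent and rasterises a Bresenham line towards each of them; entities on a ray stay visible until a light-blocking cell is hit. A standalone sketch of the line step, with arbitrarily chosen coordinates (numba must be installed, as it already is for the module import):

import numpy as np
from marl_factory_grid.utils.observation_builder import RayCaster

# bresenham_loop is a staticmethod, so it can be exercised without an agent.
# It returns, per target point, the grid cells the ray passes through.
rays = RayCaster.bresenham_loop((0, 0), np.array([[5, 3]]))
print(rays[0])   # e.g. [[0, 0], [1, 1], [2, 1], [3, 2], [4, 2], [5, 3]]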
16 marl_factory_grid/utils/render.py Normal file
@@ -0,0 +1,16 @@
from dataclasses import dataclass
from typing import Any

import numpy as np


@dataclass
class RenderEntity:
    name: str
    pos: np.array
    value: float = 1
    value_operation: str = 'none'
    state: str = None
    id: int = 0
    aux: Any = None
    real_name: str = 'none'
144 marl_factory_grid/utils/renderer.py Normal file
@@ -0,0 +1,144 @@
import sys

from pathlib import Path
from collections import deque
from itertools import product

import numpy as np
import pygame
from typing import Tuple, Union
import time

from marl_factory_grid.utils.render import RenderEntity

AGENT: str = 'agent'
STATE_IDLE: str = 'idle'
STATE_MOVE: str = 'move'
STATE_VALID: str = 'valid'
STATE_INVALID: str = 'invalid'
STATE_COLLISION: str = 'agent_collision'
BLANK: str = 'blank'
DOOR: str = 'door'
OPACITY: str = 'opacity'
SCALE: str = 'scale'


class Renderer:
    BG_COLOR = (178, 190, 195)    # (99, 110, 114)
    WHITE = (223, 230, 233)       # (200, 200, 200)
    AGENT_VIEW_COLOR = (9, 132, 227)
    ASSETS = Path(__file__).parent.parent / 'assets'
    MODULE_ASSETS = Path(__file__).parent.parent.parent / 'modules'

    def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
                 lvl_padded_shape: Union[Tuple[int, int], None] = None,
                 cell_size: int = 40, fps: int = 7,
                 grid_lines: bool = True, view_radius: int = 2):
        self.grid_h, self.grid_w = lvl_shape
        self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
        self.cell_size = cell_size
        self.fps = fps
        self.grid_lines = grid_lines
        self.view_radius = view_radius
        pygame.init()
        self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
        self.screen = pygame.display.set_mode(self.screen_size)
        self.clock = pygame.time.Clock()
        assets = list(self.ASSETS.rglob('*.png')) + list(self.MODULE_ASSETS.rglob('*.png'))
        self.assets = {path.stem: self.load_asset(str(path), 1) for path in assets}
        self.fill_bg()

        now = time.time()
        self.font = pygame.font.Font(None, 20)
        self.font.set_bold(True)
        print('Loading System font with pygame.font.Font took', time.time() - now)

    def fill_bg(self):
        self.screen.fill(Renderer.BG_COLOR)
        if self.grid_lines:
            w, h = self.screen_size
            for x in range(0, w, self.cell_size):
                for y in range(0, h, self.cell_size):
                    rect = pygame.Rect(x, y, self.cell_size, self.cell_size)
                    pygame.draw.rect(self.screen, Renderer.WHITE, rect, 1)

    def blit_params(self, entity):
        offset_r, offset_c = (self.lvl_padded_shape[0] - self.grid_h) // 2, \
                             (self.lvl_padded_shape[1] - self.grid_w) // 2

        r, c = entity.pos
        r, c = r - offset_r, c - offset_c

        img = self.assets[entity.name.lower()]
        if entity.value_operation == OPACITY:
            img.set_alpha(255*entity.value)
        elif entity.value_operation == SCALE:
            re = img.get_rect()
            img = pygame.transform.smoothscale(
                img, (int(entity.value*re.width), int(entity.value*re.height))
            )
        o = self.cell_size // 2
        r_, c_ = r*self.cell_size + o, c*self.cell_size + o
        rect = img.get_rect()
        rect.centerx, rect.centery = c_, r_
        return dict(source=img, dest=rect)

    def load_asset(self, path, factor=1.0):
        s = int(factor*self.cell_size)
        asset = pygame.image.load(path).convert_alpha()
        asset = pygame.transform.smoothscale(asset, (s, s))
        return asset

    def visibility_rects(self, bp, view):
        rects = []
        for i, j in product(range(-self.view_radius, self.view_radius+1),
                            range(-self.view_radius, self.view_radius+1)):
            if view is not None:
                if bool(view[self.view_radius+j, self.view_radius+i]):
                    visibility_rect = bp['dest'].copy()
                    visibility_rect.centerx += i*self.cell_size
                    visibility_rect.centery += j*self.cell_size
                    shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
                    pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect())
                    shape_surf.set_alpha(64)
                    rects.append(dict(source=shape_surf, dest=visibility_rect))
        return rects

    def render(self, entities):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
        self.fill_bg()
        blits = deque()
        for entity in [x for x in entities]:
            bp = self.blit_params(entity)
            blits.append(bp)
            if entity.name.lower() == AGENT:
                if self.view_radius > 0:
                    vis_rects = self.visibility_rects(bp, entity.aux)
                    blits.extendleft(vis_rects)
                if entity.state != BLANK:
                    agent_state_blits = self.blit_params(
                        RenderEntity(entity.state, (entity.pos[0] + 0.12, entity.pos[1]), 0.48, SCALE)
                    )
                    textsurface = self.font.render(str(entity.id), False, (0, 0, 0))
                    text_blit = dict(source=textsurface, dest=(bp['dest'].center[0] - .07*self.cell_size,
                                                               bp['dest'].center[1]))
                    blits += [agent_state_blits, text_blit]

        for blit in blits:
            self.screen.blit(**blit)

        pygame.display.flip()
        self.clock.tick(self.fps)
        rgb_obs = pygame.surfarray.array3d(self.screen)
        return np.transpose(rgb_obs, (2, 0, 1))
        # return torch.from_numpy(rgb_obs).permute(2, 0, 1)


if __name__ == '__main__':
    renderer = Renderer(fps=2, cell_size=40)
    for pos_i in range(15):
        entity_1 = RenderEntity('agent_collision', [5, pos_i], 1, 'idle', 'idle')
        renderer.render([entity_1])
48 marl_factory_grid/utils/results.py Normal file
@@ -0,0 +1,48 @@
from typing import Union
from dataclasses import dataclass

from marl_factory_grid.environment.entity.entity import Entity

TYPE_VALUE = 'value'
TYPE_REWARD = 'reward'
types = [TYPE_VALUE, TYPE_REWARD]


@dataclass
class InfoObject:
    identifier: str
    val_type: str
    value: Union[float, int]


@dataclass
class Result:
    identifier: str
    validity: bool
    reward: Union[float, None] = None
    value: Union[float, None] = None
    entity: Union[Entity, None] = None

    def get_infos(self):
        n = self.entity.name if self.entity is not None else "Global"
        return [InfoObject(identifier=f'{n}_{self.identifier}_{t}',
                           val_type=t, value=self.__getattribute__(t)) for t in types
                if self.__getattribute__(t) is not None]

    def __repr__(self):
        valid = "not " if not self.validity else ""
        return f'{self.__class__.__name__}({self.identifier.capitalize()} {valid}valid: {self.reward})'


@dataclass
class TickResult(Result):
    pass


@dataclass
class ActionResult(Result):
    pass


@dataclass
class DoneResult(Result):
    pass
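A short sketch of how a result turns into monitoring infos; the identifier and reward are illustrative values:

from marl_factory_grid.utils.results import ActionResult

# entity is None, so the info identifier falls back to the "Global" prefix.
res = ActionResult(identifier='move', validity=True, reward=-0.01)
print(res)               # ActionResult(Move valid: -0.01)
print(res.get_infos())   # [InfoObject(identifier='Global_move_reward', val_type='reward', value=-0.01)]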
112 marl_factory_grid/utils/states.py Normal file
@@ -0,0 +1,112 @@
from typing import List, Dict

import numpy as np

from marl_factory_grid.environment.entity.wall_floor import Floor
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import Result
from marl_factory_grid.environment import constants as c


class StepRules:
    def __init__(self, *args):
        if args:
            self.rules = list(args)
        else:
            self.rules = list()

    def __repr__(self):
        return f'Rules{[x.name for x in self]}'

    def __iter__(self):
        return iter(self.rules)

    def append(self, item):
        assert isinstance(item, Rule)
        self.rules.append(item)
        return True

    def do_all_init(self, state):
        for rule in self.rules:
            if rule_init_printline := rule.on_init(state):
                state.print(rule_init_printline)
        return c.VALID

    def tick_step_all(self, state):
        results = list()
        for rule in self.rules:
            if tick_step_result := rule.tick_step(state):
                results.extend(tick_step_result)
        return results

    def tick_pre_step_all(self, state):
        results = list()
        for rule in self.rules:
            if tick_pre_step_result := rule.tick_pre_step(state):
                results.extend(tick_pre_step_result)
        return results

    def tick_post_step_all(self, state):
        results = list()
        for rule in self.rules:
            if tick_post_step_result := rule.tick_post_step(state):
                results.extend(tick_post_step_result)
        return results


class Gamestate(object):

    @property
    def moving_entites(self):
        return [y for x in self.entities for y in x if x.can_move]

    def __init__(self, entitites, rules: Dict[str, dict], env_seed=69, verbose=False):
        self.entities = entitites
        self.NO_POS_TILE = Floor(c.VALUE_NO_POS)
        self.curr_step = 0
        self.curr_actions = None
        self.verbose = verbose
        self.rng = np.random.default_rng(env_seed)
        self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))

    def __getitem__(self, item):
        return self.entities[item]

    def __iter__(self):
        return iter(e for e in self.entities.values())

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self.entities)} Entities @ Step {self.curr_step})'

    def tick(self, actions) -> List[Result]:
        results = list()
        self.curr_step += 1

        # Main Agent Step
        results.extend(self.rules.tick_pre_step_all(self))
        for idx, action_int in enumerate(actions):
            agent = self[c.AGENT][idx].clear_temp_state()
            action = agent.actions[action_int]
            action_result = action.do(agent, self)
            results.append(action_result)
            agent.set_state(action_result)
        results.extend(self.rules.tick_step_all(self))
        results.extend(self.rules.tick_post_step_all(self))
        return results

    def print(self, string):
        if self.verbose:
            print(string)

    def check_done(self):
        results = list()
        for rule in self.rules:
            if on_check_done_result := rule.on_check_done(self):
                results.extend(on_check_done_result)
        return results

    def get_all_tiles_with_collisions(self) -> List[Floor]:
        tiles = [self[c.FLOOR].by_pos(pos) for pos, e in self.entities.pos_dict.items()
                 if sum([x.can_collide for x in e]) > 1]
        # tiles = [x for x in self[c.FLOOR] if len(x.guests_that_can_collide) > 1]
        return tiles
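The Gamestate dispatches into the rule hooks around every step (pre-step, per-agent actions, step, post-step, done check). A minimal custom rule might look like the sketch below, assuming the `Rule` base class exposes the hooks that StepRules calls:

from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.results import TickResult


class AnnounceStep(Rule):
    # Invoked via StepRules.tick_step_all() on every Gamestate.tick().
    def tick_step(self, state):
        state.print(f'Now at step {state.curr_step}')
        return [TickResult(identifier='announce', validity=True, value=1, entity=None)]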
137 marl_factory_grid/utils/tools.py Normal file
@@ -0,0 +1,137 @@
import importlib
import inspect
from os import PathLike
from pathlib import Path
from typing import Union

import yaml

from marl_factory_grid.environment import constants as c
from marl_factory_grid.utils.helpers import locate_and_import_class

ACTION = 'Action'
GENERAL = 'General'
ENTITIES = 'Objects'
OBSERVATIONS = 'Observations'
RULES = 'Rule'
ASSETS = 'Assets'
EXCLUDED = ['identifier', 'args', 'kwargs', 'Move',
            'TemplateRule', 'Entities', 'EnvObjects', 'Zones', ]


class ConfigExplainer:

    def __init__(self, custom_path: Union[None, PathLike] = None):
        self.base_path = Path(__file__).parent.parent.resolve()
        self.custom_path = custom_path
        self.searchspace = [ACTION, GENERAL, ENTITIES, OBSERVATIONS, RULES, ASSETS]

    def explain_module(self, class_to_explain):
        parameters = inspect.signature(class_to_explain).parameters
        explained = {class_to_explain.__name__:
                     {key: val.default for key, val in parameters.items() if key not in EXCLUDED}}
        return explained

    def _load_and_compare(self, compare_class, paths):
        conf = {}
        package_pos = next(idx for idx, x in enumerate(Path(__file__).resolve().parts) if x == 'marl_factory_grid')
        for module_path in paths:
            module_parts = [x.replace('.py', '') for idx, x in enumerate(module_path.parts) if idx >= package_pos]
            mods = importlib.import_module('.'.join(module_parts))
            for key in mods.__dict__.keys():
                if key not in EXCLUDED and not key.startswith('_'):
                    mod = mods.__getattribute__(key)
                    try:
                        if issubclass(mod, compare_class) and mod != compare_class:
                            conf.update(self.explain_module(mod))
                    except TypeError:
                        pass
        return conf

    def save_actions(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_actions.yml'):
        self._save_to_file(self.get_actions(), output_conf_file, ACTION)

    def get_actions(self):
        actions = self._get_by_identifier(ACTION)
        assert all(not x for x in actions.values()), 'Please only provide Names, no Mappings.'
        actions = list(actions.keys())
        actions.extend([c.MOVE8, c.MOVE4])
        # TODO: Print to file!
        return actions

    def save_entities(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_entities.yml'):
        self._save_to_file(self.get_entities(), output_conf_file, ENTITIES)

    def get_entities(self):
        entities = self._get_by_identifier(ENTITIES)
        return entities

    def save_rules(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained_rules.yml'):
        self._save_to_file(self.get_rules(), output_conf_file, RULES)

    def get_rules(self):
        rules = self._get_by_identifier(RULES)
        return rules

    def get_assets(self):
        pass

    def get_observations(self):
        names = [c.ALL, c.COMBINED, c.SELF, c.OTHERS, "Agent['ExampleAgentName']"]
        for key, val in self.get_entities().items():
            try:
                e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
            except TypeError:
                e = [key]
            except AttributeError as err:
                if self.custom_path is not None:
                    try:
                        e = locate_and_import_class(key, self.base_path)(level_shape=(0, 0), pomdp_r=0).obs_pairs
                    except TypeError:
                        e = [key]
                else:
                    raise err
            names.extend(e)
        return names

    def _get_by_identifier(self, identifier):
        entities_base_cls = locate_and_import_class(identifier, self.base_path)
        module_paths = [x.resolve() for x in self.base_path.rglob('*.py') if x.is_file() and '__init__' not in x.name]
        found_entities = self._load_and_compare(entities_base_cls, module_paths)
        if self.custom_path is not None:
            module_paths = [x.resolve() for x in self.custom_path.rglob('*.py') if x.is_file()
                            and '__init__' not in x.name]
            found_entities.update(self._load_and_compare(entities_base_cls, module_paths))
        return found_entities

    def save_all(self, output_conf_file: PathLike = Path('../../quickstart') / 'explained.yml'):
        self._save_to_file(self.get_all(), output_conf_file, 'ALL')

    def get_all(self):
        config_dict = {GENERAL: {'level_name': 'rooms', 'env_seed': 69, 'verbose': False,
                                 'pomdp_r': 3, 'individual_rewards': True},
                       'Agents': dict(
                           ExampleAgentName=dict(
                               Actions=self.get_actions(),
                               Observations=self.get_observations())),
                       'Entities': self.get_entities(),
                       'Rules': self.get_rules(),
                       'Assets': self.get_assets()}
        return config_dict

    def _save_to_file(self, data: dict, filepath: PathLike, tag: str = ''):
        filepath = Path(filepath)
        with filepath.open('w') as f:
            yaml.dump(data, f, encoding='utf-8')
        print(f'Example config {"for " + tag + " " if tag else " "}dumped')
        print(f'See file: {filepath}')


if __name__ == '__main__':
    ce = ConfigExplainer()
    ce.get_actions()
    ce.get_entities()
    ce.get_rules()
    ce.get_observations()
    ce.get_assets()
    all_conf = ce.get_all()
    print()
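Typical use is to dump a template config in order to look up the available names; a sketch, with the output filename chosen here for illustration:

from marl_factory_grid.utils.tools import ConfigExplainer

explainer = ConfigExplainer()           # pass custom_path=... to also scan your own modules
explainer.save_all('explained.yml')     # writes a YAML template with all discovered Entities, Rules and Actions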
27 marl_factory_grid/utils/utility_classes.py Normal file
@@ -0,0 +1,27 @@
import gymnasium as gym


class EnvCombiner(object):

    def __init__(self, *envs_cls):
        self._env_dict = {env_cls.__name__: env_cls for env_cls in envs_cls}

    @staticmethod
    def combine_cls(name, *envs_cls):
        return type(name, envs_cls, {})

    def build(self):
        name = f'{"".join([x.lower().replace("factory", "").capitalize() for x in self._env_dict.keys()])}Factory'

        return self.combine_cls(name, *self._env_dict.values())


class MarlFrameStack(gym.ObservationWrapper):
    """todo @romue404"""
    def __init__(self, env):
        super().__init__(env)

    def observation(self, observation):
        if isinstance(self.env, gym.wrappers.FrameStack) and self.env.unwrapped.n_agents > 1:
            return observation[0:].swapaxes(0, 1)
        return observation
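EnvCombiner derives the combined class name by stripping the 'factory' suffix from each class name and re-appending it once. A sketch with two dummy env classes; the class names are assumptions for illustration:

from marl_factory_grid.utils.utility_classes import EnvCombiner


class DirtFactory:
    pass


class ItemFactory:
    pass


combined_cls = EnvCombiner(DirtFactory, ItemFactory).build()
print(combined_cls.__name__)      # 'DirtItemFactory'
print(combined_cls.__mro__[1:3])  # (DirtFactory, ItemFactory) as bases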