import ast from collections import defaultdict from os import PathLike from pathlib import Path from typing import Union, List import yaml from marl_factory_grid.environment.rules import Rule from marl_factory_grid.utils.helpers import locate_and_import_class from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH from marl_factory_grid.environment import constants as c class FactoryConfigParser(object): default_entites = [] default_rules = ['DoneAtMaxStepsReached', 'WatchCollision'] default_actions = [c.MOVE8, c.NOOP] default_observations = [c.WALLS, c.AGENT] def __init__(self, config_path, custom_modules_path: Union[PathLike] = None): self.config_path = Path(config_path) self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path self.config = yaml.safe_load(self.config_path.open()) self._n_abbr_dict = None def __getattr__(self, item): return self['General'][item] def _get_sub_list(self, primary_key: str, sub_key: str): return [{key: [s for k, v in val.items() if k == sub_key for s in v] for key, val in x.items() } for x in self.config.get(primary_key, [])] def _n_abbr(self, n): assert isinstance(n, int) if self._n_abbr_dict is None: self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'}) return self._n_abbr_dict[n] @property def agent_actions(self): return self._get_sub_list('Agents', "Actions") @property def agent_observations(self): return self._get_sub_list('Agents', "Observations") @property def rules(self): return self.config['Rules'] @property def tests(self): return self.config.get('Tests', []) @property def agents(self): return self.config['Agents'] @property def entities(self): return self.config['Entities'] def __repr__(self): return str(self.config) def __getitem__(self, item): try: return self.config[item] except KeyError: print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!') def load_entities(self): entity_classes = dict() entities = [] if c.DEFAULTS in self.entities: entities.extend(self.default_entites) entities.extend(x for x in self.entities if x != c.DEFAULTS) for entity in entities: e1 = e2 = e3 = None try: folder_path = Path(__file__).parent.parent / DEFAULT_PATH entity_class = locate_and_import_class(entity, folder_path) except AttributeError as e: e1 = e try: module_path = Path(__file__).parent.parent / MODULE_PATH entity_class = locate_and_import_class(entity, module_path) except AttributeError as e: e2 = e if self.custom_modules_path: try: entity_class = locate_and_import_class(entity, self.custom_modules_path) except AttributeError as e: e3 = e pass if (e1 and e2) or e3: ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]] print('##############################################################') print('### Error ### Error ### Error ### Error ### Error ###') print('##############################################################') print(f'Class "{entity}" was not found in "{module_path.name}"') print(f'Class "{entity}" was not found in "{folder_path.name}"') print('##############################################################') if self.custom_modules_path: print(f'Class "{entity}" was not found in "{self.custom_modules_path}"') print('Possible Entitys are:', str(ents)) print('##############################################################') print('Goodbye') print('##############################################################') print('### Error ### Error ### Error ### Error ### Error ###') print('##############################################################') exit(-99999) entity_kwargs = self.entities.get(entity, {}) entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}}) return entity_classes def parse_agents_conf(self): parsed_agents_conf = dict() for name in self.agents: # Actions conf_actions = self.agents[name]['Actions'] actions = list() if isinstance(conf_actions, dict): conf_kwargs = conf_actions.copy() conf_actions = list(conf_actions.keys()) elif isinstance(conf_actions, list): conf_kwargs = {} if isinstance(conf_actions, dict): raise ValueError pass for action in conf_actions: if action == c.DEFAULTS: actions.extend(self.default_actions) else: actions.append(action) parsed_actions = list() for action in actions: folder_path = MODULE_PATH if action not in [c.MOVE8, c.NOOP, c.MOVE4] else DEFAULT_PATH folder_path = Path(__file__).parent.parent / folder_path try: class_or_classes = locate_and_import_class(action, folder_path) except AttributeError: class_or_classes = locate_and_import_class(action, self.custom_modules_path) try: parsed_actions.extend(class_or_classes) for actions_class in class_or_classes: conf_kwargs[actions_class.__name__] = conf_kwargs[action] except TypeError: parsed_actions.append(class_or_classes) parsed_actions = [x(**conf_kwargs.get(x.__name__, {})) for x in parsed_actions] # Observation observations = list() assert self.agents[name]['Observations'] is not None, 'Did you specify any Observation?' if c.DEFAULTS in self.agents[name]['Observations']: observations.extend(self.default_observations) observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS) positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])] other_kwargs = {k: v for k, v in self.agents[name].items() if k not in ['Actions', 'Observations', 'Positions', 'Clones']} parsed_agents_conf[name] = dict( actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs ) clones = self.agents[name].get('Clones', 0) if clones: if isinstance(clones, int): clones = [f'{name}_the_{n}{self._n_abbr(n)}' for n in range(clones)] for clone in clones: parsed_agents_conf[clone] = parsed_agents_conf[name].copy() return parsed_agents_conf def load_env_rules(self) -> List[Rule]: rules = self.rules.copy() if c.DEFAULTS in self.rules: for rule in self.default_rules: if rule not in rules: rules.append({rule: {}}) return self._load_smth(rules, Rule) def load_env_tests(self) -> List[Rule]: return self._load_smth(self.tests, None) # Test def _load_smth(self, config, class_obj): rules = list() for rule in config: e1 = e2 = e3 = None try: folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) rule_class = locate_and_import_class(rule, folder_path) except AttributeError as e: e1 = e try: module_path = (Path(__file__).parent.parent / MODULE_PATH) rule_class = locate_and_import_class(rule, module_path) except AttributeError as e: e2 = e if self.custom_modules_path: try: rule_class = locate_and_import_class(rule, self.custom_modules_path) except AttributeError as e: e3 = e pass if (e1 and e2) or e3: ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]] print('### Error ### Error ### Error ### Error ### Error ###') print('') print(f'Class "{rule}" was not found in "{module_path.name}"') print(f'Class "{rule}" was not found in "{folder_path.name}"') if self.custom_modules_path: print(f'Class "{rule}" was not found in "{self.custom_modules_path}"') print('Possible Entitys are:', str(ents)) print('') print('Goodbye') print('') exit(-99999) if issubclass(rule_class, class_obj): rule_kwargs = config.get(rule, {}) rules.append(rule_class(**(rule_kwargs or {}))) return rules def load_entity_spawn_rules(self, entities) -> List[Rule]: rules = list() rules_dicts = list() for e in entities: try: if spawn_rule := e.spawn_rule: rules_dicts.append(spawn_rule) except AttributeError: pass for rule_dict in rules_dicts: for rule_name, rule_kwargs in rule_dict.items(): try: folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) rule_class = locate_and_import_class(rule_name, folder_path) except AttributeError: try: folder_path = (Path(__file__).parent.parent / MODULE_PATH) rule_class = locate_and_import_class(rule_name, folder_path) except AttributeError: rule_class = locate_and_import_class(rule_name, self.custom_modules_path) rules.append(rule_class(**rule_kwargs)) return rules