mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
165 lines
6.6 KiB
Python
165 lines
6.6 KiB
Python
import ast
|
|
from os import PathLike
|
|
from pathlib import Path
|
|
from typing import Union, List
|
|
|
|
import yaml
|
|
|
|
from marl_factory_grid.environment import constants as c
|
|
from marl_factory_grid.environment.rules import Rule
|
|
from marl_factory_grid.environment.tests import Test
|
|
from marl_factory_grid.utils.helpers import locate_and_import_class
|
|
|
|
# Sub-folder (relative to the directory two levels above this file) that is searched first
# when a configured class name has to be resolved.
DEFAULT_PATH = 'environment'
# Sub-folder (relative to the same directory) that is searched when the first lookup fails.
MODULE_PATH = 'modules'
|
|
|
|
|
|
class FactoryConfigParser(object):
    """Load a YAML factory configuration and resolve the class names it contains.

    The parsed document is stored in ``self.config``. This class reads the top-level
    sections ``General``, ``Agents``, ``Entities`` and ``Rules``. Names are turned into
    classes with ``locate_and_import_class``, which is pointed at the ``environment``
    folder, then the ``modules`` folder (both relative to the directory two levels above
    this file) and finally at ``custom_modules_path``.
    """

    # NOTE: "entites" is misspelled, but the name is part of the class interface and is kept.
    default_entites = []
    default_rules = ['MaxStepsReached', 'Collision']
    default_actions = [c.MOVE8, c.NOOP]
    default_observations = [c.WALLS, c.AGENT]

    def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
        """Read and parse the configuration file.

        :param config_path: Location of the YAML configuration file.
        :param custom_modules_path: Optional extra folder that is searched last when a
                                    class name has to be resolved; ``None`` if unused.
        """
        self.config_path = Path(config_path)
        self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
        # Use a context manager so the file handle is closed as soon as parsing is done.
        with self.config_path.open() as config_file:
            self.config = yaml.safe_load(config_file)

    def __getattr__(self, item):
        """Expose the entries of the ``General`` section as attributes.

        Python only calls this hook after regular attribute lookup has failed.

        :raises AttributeError: If ``item`` is not a key of the ``General`` section.
        """
        if item == 'config':
            # ``config`` itself is missing (e.g. on an instance created without ``__init__``);
            # looking it up below would re-enter this method without end.
            raise AttributeError(item)
        try:
            return self.config['General'][item]
        except KeyError as err:
            # Follow the attribute protocol so that ``hasattr`` / ``getattr(obj, name, default)`` work.
            raise AttributeError(f'{type(self).__name__!r} has no attribute or "General" entry {item!r}') from err

    def _get_sub_list(self, primary_key: str, sub_key: str):
        """Collect the ``sub_key`` items of every entry below ``primary_key``.

        :return: One dict per element of ``self.config[primary_key]``, mapping each key of
                 that element to the list of items stored under ``sub_key`` (empty if absent).
        """
        # NOTE(review): this expects ``self.config[primary_key]`` to be a sequence of mappings,
        # whereas ``parse_agents_conf`` indexes the ``Agents`` section by name — confirm which
        # layout is the supported one.
        return [
            {key: list(val.get(sub_key, ())) for key, val in x.items()}
            for x in self.config[primary_key]
        ]

    @property
    def agent_actions(self):
        """``Actions`` of every configured agent, see ``_get_sub_list``."""
        return self._get_sub_list('Agents', "Actions")

    @property
    def agent_observations(self):
        """``Observations`` of every configured agent, see ``_get_sub_list``."""
        return self._get_sub_list('Agents', "Observations")

    @property
    def rules(self):
        """The raw ``Rules`` section of the configuration."""
        return self.config['Rules']

    @property
    def agents(self):
        """The raw ``Agents`` section of the configuration."""
        return self.config['Agents']

    @property
    def entities(self):
        """The raw ``Entities`` section of the configuration."""
        return self.config['Entities']

    def __repr__(self):
        return str(self.config)

    def __getitem__(self, item):
        """Return the top-level configuration section named ``item``."""
        return self.config[item]

    def _search_paths(self):
        """Return the folders to search for a class, in lookup order."""
        package_root = Path(__file__).parent.parent
        return [package_root / DEFAULT_PATH, package_root / MODULE_PATH, self.custom_modules_path]

    def _locate_class(self, class_name):
        """Resolve ``class_name`` by trying every search folder in order.

        :raises AttributeError: Whatever the lookup in the last folder raises if the
                                earlier folders did not provide the class either.
        """
        *package_paths, custom_path = self._search_paths()
        for folder_path in package_paths:
            try:
                return locate_and_import_class(class_name, folder_path)
            except AttributeError:
                continue
        # The last lookup is intentionally unguarded so that its error reaches the caller.
        return locate_and_import_class(class_name, custom_path)

    @staticmethod
    def _abort_unknown_entity(entity, folder_path, failures):
        """Print the 'class not found' report and terminate with a non-zero exit status."""
        # Each lookup error is expected to carry the candidate names as its second argument;
        # errors without one simply contribute nothing.
        candidates = [name for err in failures if len(err.args) > 1 for name in err.args[1]]
        # ``folder_path`` is the custom modules path here and may be None.
        searched = folder_path.name if folder_path is not None else folder_path
        print('### Error ### Error ### Error ### Error ### Error ###')
        print()
        print(f'Class "{entity}" was not found in "{searched}"')
        print('Possible Entitys are:', str(candidates))
        print()
        print('Goodbye')
        print()
        # Raise SystemExit directly: the ``exit`` helper is only installed by ``site`` and
        # would report success (status 0) for what is an error.
        raise SystemExit(1)

    @staticmethod
    def _as_mapping(section):
        """Return ``section`` as a new ``name -> kwargs`` dict.

        Accepts a mapping, a sequence whose elements are mappings or bare names, or None.
        """
        if section is None:
            return {}
        if isinstance(section, dict):
            return dict(section)
        merged = {}
        for entry in section:
            if isinstance(entry, dict):
                merged.update(entry)
            else:
                merged[entry] = {}
        return merged

    def load_entities(self):
        """Resolve every configured entity name to its class.

        If the defaults marker ``c.DEFAULTS`` is present in the ``Entities`` section, the
        names in ``default_entites`` are loaded first.

        :return: Dict mapping entity name to ``{'class': ..., 'kwargs': ..., 'symbol': ...}``;
                 ``symbol`` is None when the class defines no such attribute.
        :raises SystemExit: After printing a report, if a name cannot be found in any folder.
        """
        entity_names = []
        if c.DEFAULTS in self.entities:
            entity_names.extend(self.default_entites)
        entity_names.extend(x for x in self.entities if x != c.DEFAULTS)

        entity_classes = dict()
        for entity in entity_names:
            failures = []
            for folder_path in self._search_paths():
                try:
                    entity_class = locate_and_import_class(entity, folder_path)
                    break
                except AttributeError as err:
                    failures.append(err)
            else:
                # No folder provided the class; this call does not return.
                self._abort_unknown_entity(entity, folder_path, failures)

            entity_kwargs = self.entities.get(entity, {})
            entity_symbol = getattr(entity_class, 'symbol', None)
            entity_classes[entity] = {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}
        return entity_classes

    def parse_agents_conf(self):
        """Build the per-agent configuration from the ``Agents`` section.

        :return: Dict mapping agent name to
                 ``dict(actions=<action instances>, observations=<names>, positions=<parsed values>)``.
        """
        package_root = Path(__file__).parent.parent
        # Actions listed here are looked up in the ``environment`` folder, all others in ``modules``.
        base_env_actions = self.default_actions + [c.MOVE4]

        parsed_agents_conf = dict()
        for name in self.agents:
            agent_conf = self.agents[name]

            # Actions
            actions = list()
            if c.DEFAULTS in agent_conf['Actions']:
                actions.extend(self.default_actions)
            actions.extend(x for x in agent_conf['Actions'] if x != c.DEFAULTS)

            parsed_actions = list()
            for action in actions:
                sub_folder = DEFAULT_PATH if action in base_env_actions else MODULE_PATH
                try:
                    class_or_classes = locate_and_import_class(action, package_root / sub_folder)
                except AttributeError:
                    class_or_classes = locate_and_import_class(action, self.custom_modules_path)
                # The lookup may yield one class or an iterable of classes; a single class
                # is not iterable, which surfaces as the TypeError handled here.
                try:
                    parsed_actions.extend(class_or_classes)
                except TypeError:
                    parsed_actions.append(class_or_classes)
            parsed_actions = [x() for x in parsed_actions]

            # Observation
            observations = list()
            assert agent_conf['Observations'] is not None, 'Did you specify any Observation?'
            if c.DEFAULTS in agent_conf['Observations']:
                observations.extend(self.default_observations)
            observations.extend(x for x in agent_conf['Observations'] if x != c.DEFAULTS)

            # Position entries are evaluated as Python literals; a ``Positions`` key without
            # a value is treated like a missing key.
            positions = [ast.literal_eval(x) for x in agent_conf.get('Positions') or []]
            parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
        return parsed_agents_conf

    def load_env_rules(self) -> List[Rule]:
        """Instantiate the configured rules.

        If the defaults marker ``c.DEFAULTS`` is present in the ``Rules`` section, every
        name in ``default_rules`` that is not configured explicitly is added without kwargs.
        """
        rules = self._as_mapping(self.rules)
        if c.DEFAULTS in rules:
            for rule in self.default_rules:
                rules.setdefault(rule, {})
        return self._load_smth(rules, Rule)

    def load_env_tests(self) -> List[Test]:
        """Instantiate the configured tests."""
        # NOTE(review): this class defines no ``tests`` attribute or property, so ``self.tests``
        # is answered by ``__getattr__`` from the ``General`` section — confirm that this is
        # where the tests are meant to be configured.
        return self._load_smth(self.tests, None)  # Test

    def _load_smth(self, config, class_obj):
        """Instantiate every class named in ``config`` with its configured kwargs.

        :param config: ``name -> kwargs`` mapping (or a sequence of such mappings / bare
                       names); the defaults marker ``c.DEFAULTS`` is skipped.
        :param class_obj: Expected base class; currently unused, see the note below.
        :return: List with one instance per configured name, in configuration order.
        """
        config = self._as_mapping(config)
        instances = list()
        for name in config:
            if name == c.DEFAULTS:
                continue
            located_class = self._locate_class(name)
            # Fixme: the type check against ``class_obj`` is still disabled.
            # NOTE(review): ``located_class`` is called below to create the instance, so an
            # ``isinstance`` test on it cannot succeed; ``issubclass`` may be what was
            # intended — confirm before enabling.
            kwargs = config[name] or {}  # an entry without a value means "no kwargs"
            instances.append(located_class(**kwargs))
        return instances