Steffen Illium 803d0dae7f Multiple Fixes:
- Config Explainer
 - Rewards
 - Destination Reach Condition
 - Additional Step Callback
2023-11-24 14:43:49 +01:00

257 lines
11 KiB
Python

import ast
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Union, List
import yaml
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.helpers import locate_and_import_class
from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
from marl_factory_grid.environment import constants as c
class FactoryConfigParser(object):
    """Parse a factory-environment YAML config and resolve the classes it names.

    Attribute access for names that are not real attributes falls through to the config's
    ``General`` section, i.e. ``parser.some_key`` reads ``config['General']['some_key']``.
    """

    # Names spliced in wherever a section lists ``c.DEFAULTS``.
    # NOTE: the misspelled ``default_entites`` is a public class attribute and is kept as-is.
    default_entites = []
    default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
    default_actions = [c.MOVE8, c.NOOP]
    default_observations = [c.WALLS, c.AGENT]

    def __init__(self, config_path, custom_modules_path: Union[str, PathLike, None] = None):
        """Load the YAML config file.

        :param config_path: Path to the YAML config file.
        :param custom_modules_path: Optional extra directory searched for classes that are not found
            in the package's own default/module paths.
        """
        self.config_path = Path(config_path)
        self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else None
        # Context manager so the file handle is closed as soon as parsing is done.
        with self.config_path.open() as config_file:
            self.config = yaml.safe_load(config_file)
        self._n_abbr_dict = None

    def __getattr__(self, item):
        """Resolve unknown attributes as keys of the ``General`` config section.

        Python only calls this when normal lookup fails. Double-underscore names and ``config``
        itself raise AttributeError: ``copy.copy``/unpickling probe e.g. ``__setstate__`` on an
        instance whose ``config`` is not set yet, and routing that through ``self['General']``
        (which reads ``self.config``) would otherwise recurse until RecursionError.
        """
        if item.startswith('__') or item == 'config':
            raise AttributeError(item)
        return self['General'][item]

    def _get_sub_list(self, primary_key: str, sub_key: str):
        """Collect the ``sub_key`` entries for every named item of the ``primary_key`` section.

        The section may be a list of ``{name: {sub_key: [...]}}`` groups, or a single such mapping
        (the shape ``parse_agents_conf`` reads for ``Agents``), which is treated as one group.
        Returns one ``{name: [sub items]}`` dict per group; a missing or empty ``sub_key`` gives ``[]``.
        """
        groups = self.config.get(primary_key, [])
        if isinstance(groups, dict):
            groups = [groups]
        return [{name: list(item.get(sub_key) or []) for name, item in group.items()} for group in groups]

    def _n_abbr(self, n):
        """Return the English ordinal suffix for ``n`` ('st', 'nd', 'rd' or 'th').

        The suffix follows the last digit, except that 11-13 (mod 100) always take 'th'
        (1 -> 'st', 11 -> 'th', 21 -> 'st', 112 -> 'th').
        """
        assert isinstance(n, int)
        if self._n_abbr_dict is None:
            self._n_abbr_dict = defaultdict(lambda: 'th', {1: 'st', 2: 'nd', 3: 'rd'})
        if 11 <= n % 100 <= 13:
            return 'th'
        return self._n_abbr_dict[n % 10]

    @property
    def agent_actions(self):
        """Per-agent ``Actions`` entries, as produced by ``_get_sub_list``."""
        return self._get_sub_list('Agents', "Actions")

    @property
    def agent_observations(self):
        """Per-agent ``Observations`` entries, as produced by ``_get_sub_list``."""
        return self._get_sub_list('Agents', "Observations")

    @property
    def rules(self):
        """The raw ``Rules`` section (raises KeyError if absent)."""
        return self.config['Rules']

    @property
    def tests(self):
        """The raw ``Tests`` section, or ``[]`` if absent."""
        return self.config.get('Tests', [])

    @property
    def agents(self):
        """The raw ``Agents`` section (raises KeyError if absent)."""
        return self.config['Agents']

    @property
    def entities(self):
        """The raw ``Entities`` section (raises KeyError if absent)."""
        return self.config['Entities']

    def __repr__(self):
        return str(self.config)

    def __getitem__(self, item):
        """Return the top-level config section ``item``.

        A missing section prints a hint and returns None; callers that index the result
        (e.g. ``__getattr__``) then fail on the None.
        """
        try:
            return self.config[item]
        except KeyError:
            print(f'The mandatory {item} section could not be found in your .config file. Check Spelling!')
            return None

    # ----------------------------------------------------------------------------------------------
    # Class lookup helpers
    # ----------------------------------------------------------------------------------------------

    def _class_search_paths(self):
        """Return the directories searched for config-named classes, in lookup order.

        DEFAULT_PATH and MODULE_PATH (both relative to the package root), then
        ``custom_modules_path`` if one was given.
        """
        package_root = Path(__file__).parent.parent
        search_paths = [package_root / DEFAULT_PATH, package_root / MODULE_PATH]
        if self.custom_modules_path:
            search_paths.append(self.custom_modules_path)
        return search_paths

    def _locate_class(self, class_name):
        """Import ``class_name`` from the first search path that provides it.

        :raises AttributeError: ``AttributeError(message, candidates)`` when no search path
            provides it; ``candidates`` concatenates ``args[1]`` of each individual failure.
        """
        candidates = []
        for search_path in self._class_search_paths():
            try:
                return locate_and_import_class(class_name, search_path)
            except AttributeError as e:
                # Assumes locate_and_import_class lists known class names in args[1] (the error
                # reporting has always relied on this); tolerate errors that do not.
                candidates.extend(e.args[1] if len(e.args) > 1 else [])
        raise AttributeError(f'Class "{class_name}" was not found', candidates)

    def _exit_on_missing_class(self, class_name, error, framed):
        """Print where ``class_name`` was searched and which classes were found, then exit.

        ``framed=True`` prints the '#'-framed banner used for entities; ``framed=False`` prints
        the plainer layout used for rules/tests.
        """
        line = '##############################################################'
        header = '### Error ### Error ### Error ### Error ### Error ###'
        separator = line if framed else ''
        search_paths = self._class_search_paths()
        candidates = error.args[1] if len(error.args) > 1 else []
        if framed:
            print(line)
        print(header)
        print(separator)
        print(f'Class "{class_name}" was not found in "{search_paths[1].name}"')
        print(f'Class "{class_name}" was not found in "{search_paths[0].name}"')
        if framed:
            print(line)
        if self.custom_modules_path:
            print(f'Class "{class_name}" was not found in "{self.custom_modules_path}"')
        print('Possible Entitys are:', str(candidates))
        print(separator)
        print('Goodbye')
        print(separator)
        if framed:
            print(header)
            print(line)
        # SystemExit directly: the ``exit`` builtin is installed by ``site`` and may be missing.
        raise SystemExit(-99999)

    def _locate_action(self, action):
        """Import the action class (or iterable of classes) named ``action``.

        ``c.MOVE8``, ``c.NOOP`` and ``c.MOVE4`` are looked up under DEFAULT_PATH, everything else
        under MODULE_PATH; ``custom_modules_path`` (if set) is the fallback.
        """
        base = DEFAULT_PATH if action in [c.MOVE8, c.NOOP, c.MOVE4] else MODULE_PATH
        folder_path = Path(__file__).parent.parent / base
        try:
            return locate_and_import_class(action, folder_path)
        except AttributeError:
            if not self.custom_modules_path:
                # Nothing else to search: surface the "not found" error instead of passing None on.
                raise
            return locate_and_import_class(action, self.custom_modules_path)

    # ----------------------------------------------------------------------------------------------
    # Loaders
    # ----------------------------------------------------------------------------------------------

    def load_entities(self):
        """Resolve every entity named in the ``Entities`` section to its class.

        Returns ``{name: {'class': cls, 'kwargs': dict, 'symbol': cls.symbol or None}}``, where
        ``symbol`` is the class's ``symbol`` attribute if it has one. Prints a report and exits
        the process if an entity class cannot be found.
        """
        entity_classes = dict()
        entities = []
        if c.DEFAULTS in self.entities:
            entities.extend(self.default_entites)
        entities.extend(x for x in self.entities if x != c.DEFAULTS)
        for entity in entities:
            try:
                entity_class = self._locate_class(entity)
            except AttributeError as e:
                self._exit_on_missing_class(entity, e, framed=True)
            # An entity key without options (e.g. a bare ``Name:`` in YAML) yields None; use no kwargs.
            entity_kwargs = self.entities.get(entity, {}) or {}
            entity_symbol = getattr(entity_class, 'symbol', None)
            entity_classes.update({entity: {'class': entity_class, 'kwargs': entity_kwargs, 'symbol': entity_symbol}})
        return entity_classes

    def _parse_agent_actions(self, agent_name, conf_actions):
        """Resolve and instantiate one agent's ``Actions`` (a list of names or a name -> kwargs mapping).

        ``c.DEFAULTS`` expands to ``default_actions``. When a name resolves to an iterable of
        classes, each class receives the kwargs configured for that name (if any).

        :raises ValueError: if ``conf_actions`` is neither a list nor a mapping.
        """
        if isinstance(conf_actions, dict):
            # Mapping form {ActionName: kwargs}; a bare key parses as None -> no kwargs.
            conf_kwargs = {action: (kwargs or {}) for action, kwargs in conf_actions.items()}
            conf_actions = list(conf_actions.keys())
        elif isinstance(conf_actions, list):
            conf_kwargs = {}
        else:
            raise ValueError(f'Actions of agent "{agent_name}" must be a list or a mapping, '
                             f'got {type(conf_actions).__name__}.')
        actions = list()
        for action in conf_actions:
            if action == c.DEFAULTS:
                actions.extend(self.default_actions)
            else:
                actions.append(action)
        action_classes = list()
        for action in actions:
            class_or_classes = self._locate_action(action)
            try:
                group = list(class_or_classes)
            except TypeError:
                # Not iterable: a single action class.
                action_classes.append(class_or_classes)
            else:
                # Iterable of action classes: each one inherits the kwargs given for the group name.
                action_classes.extend(group)
                if action in conf_kwargs:
                    for action_class in group:
                        conf_kwargs[action_class.__name__] = conf_kwargs[action]
        return [x(**conf_kwargs.get(x.__name__, {})) for x in action_classes]

    def parse_agents_conf(self):
        """Build the per-agent configuration from the ``Agents`` section.

        For every agent, and every clone named by its ``Clones`` entry (an int count or a list of
        names), returns ``dict(actions=..., observations=..., positions=..., other=...)``: the
        instantiated actions, the observation names, ``Positions`` entries parsed with
        ``ast.literal_eval`` and all remaining keys.

        :raises ValueError: if an agent's ``Actions`` is neither a list nor a mapping.
        """
        parsed_agents_conf = dict()
        for name in self.agents:
            agent_conf = self.agents[name]
            # Actions
            parsed_actions = self._parse_agent_actions(name, agent_conf['Actions'])
            # Observations
            conf_observations = agent_conf.get('Observations')
            assert conf_observations is not None, 'Did you specify any Observation?'
            observations = list()
            if c.DEFAULTS in conf_observations:
                observations.extend(self.default_observations)
            observations.extend(x for x in conf_observations if x != c.DEFAULTS)
            positions = [ast.literal_eval(x) for x in agent_conf.get('Positions') or []]
            other_kwargs = {k: v for k, v in agent_conf.items()
                            if k not in ['Actions', 'Observations', 'Positions', 'Clones']}
            parsed_agents_conf[name] = dict(
                actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
            )
            clones = agent_conf.get('Clones', 0)
            if clones:
                if isinstance(clones, int):
                    clones = [f'{name}_the_{n}{self._n_abbr(n)}' for n in range(clones)]
                for clone in clones:
                    # Shallow copy: clones share the action instances and lists of the original agent.
                    parsed_agents_conf[clone] = parsed_agents_conf[name].copy()
        return parsed_agents_conf

    def load_env_rules(self) -> List[Rule]:
        """Instantiate the ``Rules`` section (a mapping ``{RuleName: kwargs}``).

        A ``c.DEFAULTS`` key is not a rule itself; it adds every ``default_rules`` entry that is
        not configured explicitly, with empty kwargs.
        """
        rules = {name: kwargs for name, kwargs in self.rules.items() if name != c.DEFAULTS}
        if c.DEFAULTS in self.rules:
            for rule in self.default_rules:
                rules.setdefault(rule, {})
        return self._load_smth(rules, Rule)

    def load_env_tests(self) -> List[Rule]:
        """Instantiate the ``Tests`` section; no base-class filter is applied."""
        return self._load_smth(self.tests, None)  # Test

    def _load_smth(self, config, class_obj):
        """Instantiate the classes named by the keys of ``config``, passing each value as kwargs.

        :param config: Mapping ``{ClassName: kwargs or None}`` (an empty sequence yields ``[]``).
        :param class_obj: Only subclasses of this are kept; ``None`` keeps every class.
        Prints a report and exits the process if a class cannot be found.
        """
        rules = list()
        for rule in config:
            try:
                rule_class = self._locate_class(rule)
            except AttributeError as e:
                self._exit_on_missing_class(rule, e, framed=False)
            if class_obj is None or issubclass(rule_class, class_obj):
                rule_kwargs = config.get(rule, {})
                rules.append(rule_class(**(rule_kwargs or {})))
        return rules

    def load_entity_spawn_rules(self, entities) -> List[Rule]:
        """Instantiate the spawn rules declared by ``entities``.

        Each entity may expose a truthy ``spawn_rule`` mapping ``{RuleName: kwargs}``; entities
        without that attribute, or with a falsy one, are skipped.

        :raises AttributeError: if a named rule class is found in none of the search paths.
        """
        rules_dicts = [spawn_rule for spawn_rule in (getattr(e, 'spawn_rule', None) for e in entities)
                       if spawn_rule]
        rules = list()
        for rule_dict in rules_dicts:
            for rule_name, rule_kwargs in rule_dict.items():
                rule_class = self._locate_class(rule_name)
                rules.append(rule_class(**(rule_kwargs or {})))
        return rules