mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 09:01:36 +02:00
Merge branch 'main' into unit_testing
# Conflicts: # marl_factory_grid/environment/factory.py # marl_factory_grid/utils/config_parser.py # marl_factory_grid/utils/states.py
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import ast
|
||||
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union, List
|
||||
@ -9,18 +10,17 @@ from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment.tests import Test
|
||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||
|
||||
DEFAULT_PATH = 'environment'
|
||||
MODULE_PATH = 'modules'
|
||||
from marl_factory_grid.environment.constants import DEFAULT_PATH, MODULE_PATH
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
class FactoryConfigParser(object):
|
||||
default_entites = []
|
||||
default_rules = ['MaxStepsReached', 'Collision']
|
||||
default_rules = ['DoneAtMaxStepsReached', 'WatchCollision']
|
||||
default_actions = [c.MOVE8, c.NOOP]
|
||||
default_observations = [c.WALLS, c.AGENT]
|
||||
|
||||
def __init__(self, config_path, custom_modules_path: Union[None, PathLike] = None):
|
||||
def __init__(self, config_path, custom_modules_path: Union[PathLike] = None):
|
||||
self.config_path = Path(config_path)
|
||||
self.custom_modules_path = Path(custom_modules_path) if custom_modules_path is not None else custom_modules_path
|
||||
self.config = yaml.safe_load(self.config_path.open())
|
||||
@ -44,6 +44,10 @@ class FactoryConfigParser(object):
|
||||
def rules(self):
|
||||
return self.config['Rules']
|
||||
|
||||
@property
|
||||
def tests(self):
|
||||
return self.config.get('Tests', [])
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
return self.config['Agents']
|
||||
@ -56,10 +60,12 @@ class FactoryConfigParser(object):
|
||||
return str(self.config)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.config[item]
|
||||
try:
|
||||
return self.config[item]
|
||||
except KeyError:
|
||||
print(f'The mandatory {item} section could not be found in your .config gile. Check Spelling!')
|
||||
|
||||
def load_entities(self):
|
||||
# entites = Entities()
|
||||
entity_classes = dict()
|
||||
entities = []
|
||||
if c.DEFAULTS in self.entities:
|
||||
@ -67,28 +73,40 @@ class FactoryConfigParser(object):
|
||||
entities.extend(x for x in self.entities if x != c.DEFAULTS)
|
||||
|
||||
for entity in entities:
|
||||
e1 = e2 = e3 = None
|
||||
try:
|
||||
folder_path = Path(__file__).parent.parent / DEFAULT_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e1:
|
||||
except AttributeError as e:
|
||||
e1 = e
|
||||
try:
|
||||
folder_path = Path(__file__).parent.parent / MODULE_PATH
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e2:
|
||||
try:
|
||||
folder_path = self.custom_modules_path
|
||||
entity_class = locate_and_import_class(entity, folder_path)
|
||||
except AttributeError as e3:
|
||||
ents = [y for x in [e1.argss[1], e2.argss[1], e3.argss[1]] for y in x]
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print()
|
||||
print(f'Class "{entity}" was not found in "{folder_path.name}"')
|
||||
print('Possible Entitys are:', str(ents))
|
||||
print()
|
||||
print('Goodbye')
|
||||
print()
|
||||
exit()
|
||||
# raise AttributeError(e1.argss[0], e2.argss[0], e3.argss[0], 'Possible Entitys are:', str(ents))
|
||||
module_path = Path(__file__).parent.parent / MODULE_PATH
|
||||
entity_class = locate_and_import_class(entity, module_path)
|
||||
except AttributeError as e:
|
||||
e2 = e
|
||||
if self.custom_modules_path:
|
||||
try:
|
||||
entity_class = locate_and_import_class(entity, self.custom_modules_path)
|
||||
except AttributeError as e:
|
||||
e3 = e
|
||||
pass
|
||||
if (e1 and e2) or e3:
|
||||
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
|
||||
print('##############################################################')
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print('##############################################################')
|
||||
print(f'Class "{entity}" was not found in "{module_path.name}"')
|
||||
print(f'Class "{entity}" was not found in "{folder_path.name}"')
|
||||
print('##############################################################')
|
||||
if self.custom_modules_path:
|
||||
print(f'Class "{entity}" was not found in "{self.custom_modules_path}"')
|
||||
print('Possible Entitys are:', str(ents))
|
||||
print('##############################################################')
|
||||
print('Goodbye')
|
||||
print('##############################################################')
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print('##############################################################')
|
||||
exit(-99999)
|
||||
|
||||
entity_kwargs = self.entities.get(entity, {})
|
||||
entity_symbol = entity_class.symbol if hasattr(entity_class, 'symbol') else None
|
||||
@ -126,7 +144,12 @@ class FactoryConfigParser(object):
|
||||
observations.extend(self.default_observations)
|
||||
observations.extend(x for x in self.agents[name]['Observations'] if x != c.DEFAULTS)
|
||||
positions = [ast.literal_eval(x) for x in self.agents[name].get('Positions', [])]
|
||||
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
|
||||
other_kwargs = {k: v for k, v in self.agents[name].items() if k not in
|
||||
['Actions', 'Observations', 'Positions']}
|
||||
parsed_agents_conf[name] = dict(
|
||||
actions=parsed_actions, observations=observations, positions=positions, other=other_kwargs
|
||||
)
|
||||
|
||||
return parsed_agents_conf
|
||||
|
||||
def load_env_rules(self) -> List[Rule]:
|
||||
@ -137,28 +160,69 @@ class FactoryConfigParser(object):
|
||||
rules.append({rule: {}})
|
||||
|
||||
return self._load_smth(rules, Rule)
|
||||
pass
|
||||
|
||||
def load_env_tests(self) -> List[Test]:
|
||||
def load_env_tests(self) -> List[Rule]:
|
||||
return self._load_smth(self.tests, None) # Test
|
||||
pass
|
||||
|
||||
def _load_smth(self, config, class_obj):
|
||||
rules = list()
|
||||
rules_names = list()
|
||||
|
||||
for rule in rules_names:
|
||||
for rule in config:
|
||||
e1 = e2 = e3 = None
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
except AttributeError:
|
||||
except AttributeError as e:
|
||||
e1 = e
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
# Fixme This check does not work!
|
||||
# assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".'
|
||||
rule_kwargs = config.get(rule, {})
|
||||
rules.append(rule_class(**rule_kwargs))
|
||||
module_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule, module_path)
|
||||
except AttributeError as e:
|
||||
e2 = e
|
||||
if self.custom_modules_path:
|
||||
try:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
except AttributeError as e:
|
||||
e3 = e
|
||||
pass
|
||||
if (e1 and e2) or e3:
|
||||
ents = [y for x in [e1, e2, e3] if x is not None for y in x.args[1]]
|
||||
print('### Error ### Error ### Error ### Error ### Error ###')
|
||||
print('')
|
||||
print(f'Class "{rule}" was not found in "{module_path.name}"')
|
||||
print(f'Class "{rule}" was not found in "{folder_path.name}"')
|
||||
if self.custom_modules_path:
|
||||
print(f'Class "{rule}" was not found in "{self.custom_modules_path}"')
|
||||
print('Possible Entitys are:', str(ents))
|
||||
print('')
|
||||
print('Goodbye')
|
||||
print('')
|
||||
exit(-99999)
|
||||
|
||||
if issubclass(rule_class, class_obj):
|
||||
rule_kwargs = config.get(rule, {})
|
||||
rules.append(rule_class(**(rule_kwargs or {})))
|
||||
return rules
|
||||
|
||||
def load_entity_spawn_rules(self, entities) -> List[Rule]:
|
||||
rules = list()
|
||||
rules_dicts = list()
|
||||
for e in entities:
|
||||
try:
|
||||
if spawn_rule := e.spawn_rule:
|
||||
rules_dicts.append(spawn_rule)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
for rule_dict in rules_dicts:
|
||||
for rule_name, rule_kwargs in rule_dict.items():
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||
rule_class = locate_and_import_class(rule_name, folder_path)
|
||||
except AttributeError:
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / MODULE_PATH)
|
||||
rule_class = locate_and_import_class(rule_name, folder_path)
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule_name, self.custom_modules_path)
|
||||
rules.append(rule_class(**rule_kwargs))
|
||||
return rules
|
||||
|
Reference in New Issue
Block a user