added test hooks (like rules)

This commit is contained in:
Chanumask
2023-11-10 10:41:41 +01:00
parent 9b9c6e0385
commit 64c0d0e4e9
3 changed files with 84 additions and 21 deletions

View File

@ -1,16 +1,14 @@
import ast
from collections import defaultdict
from os import PathLike
from pathlib import Path
from typing import Union
from typing import Union, List
import yaml
from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.helpers import locate_and_import_class
from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment.tests import Test
from marl_factory_grid.utils.helpers import locate_and_import_class
DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules'
@ -131,17 +129,25 @@ class FactoryConfigParser(object):
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
return parsed_agents_conf
def load_rules(self):
# entites = Entities()
rules_classes = dict()
rules = []
def load_env_rules(self) -> List[Rule]:
rules = self.rules.copy()
if c.DEFAULTS in self.rules:
for rule in self.default_rules:
if rule not in rules:
rules.append(rule)
rules.extend(x for x in self.rules if x != c.DEFAULTS)
rules.append({rule: {}})
for rule in rules:
return self._load_smth(rules, Rule)
pass
def load_env_tests(self) -> List[Test]:
return self._load_smth(self.tests, None) # Test
pass
def _load_smth(self, config, class_obj):
rules = list()
rules_names = list()
for rule in rules_names:
try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule, folder_path)
@ -152,7 +158,7 @@ class FactoryConfigParser(object):
except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path)
# Fixme This check does not work!
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".'
rule_kwargs = self.rules.get(rule, {})
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
return rules_classes
# assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".'
rule_kwargs = config.get(rule, {})
rules.append(rule_class(**rule_kwargs))
return rules