mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 09:01:36 +02:00
added test hooks (like rules)
This commit is contained in:
@ -1,16 +1,14 @@
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Union, List
|
||||
|
||||
import yaml
|
||||
|
||||
from marl_factory_grid.environment.groups.agents import Agents
|
||||
from marl_factory_grid.environment.entity.agent import Agent
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment.tests import Test
|
||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||
|
||||
DEFAULT_PATH = 'environment'
|
||||
MODULE_PATH = 'modules'
|
||||
@ -131,17 +129,25 @@ class FactoryConfigParser(object):
|
||||
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
|
||||
return parsed_agents_conf
|
||||
|
||||
def load_rules(self):
|
||||
# entites = Entities()
|
||||
rules_classes = dict()
|
||||
rules = []
|
||||
def load_env_rules(self) -> List[Rule]:
|
||||
rules = self.rules.copy()
|
||||
if c.DEFAULTS in self.rules:
|
||||
for rule in self.default_rules:
|
||||
if rule not in rules:
|
||||
rules.append(rule)
|
||||
rules.extend(x for x in self.rules if x != c.DEFAULTS)
|
||||
rules.append({rule: {}})
|
||||
|
||||
for rule in rules:
|
||||
return self._load_smth(rules, Rule)
|
||||
pass
|
||||
|
||||
def load_env_tests(self) -> List[Test]:
|
||||
return self._load_smth(self.tests, None) # Test
|
||||
pass
|
||||
|
||||
def _load_smth(self, config, class_obj):
|
||||
rules = list()
|
||||
rules_names = list()
|
||||
|
||||
for rule in rules_names:
|
||||
try:
|
||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||
rule_class = locate_and_import_class(rule, folder_path)
|
||||
@ -152,7 +158,7 @@ class FactoryConfigParser(object):
|
||||
except AttributeError:
|
||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||
# Fixme This check does not work!
|
||||
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".'
|
||||
rule_kwargs = self.rules.get(rule, {})
|
||||
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
||||
return rules_classes
|
||||
# assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".'
|
||||
rule_kwargs = config.get(rule, {})
|
||||
rules.append(rule_class(**rule_kwargs))
|
||||
return rules
|
||||
|
Reference in New Issue
Block a user