From 64c0d0e4e9c318fb9e1b8a8df71f3f759f0086be Mon Sep 17 00:00:00 2001 From: Chanumask Date: Fri, 10 Nov 2023 10:41:41 +0100 Subject: [PATCH] added test hooks (like rules) --- marl_factory_grid/environment/factory.py | 7 ++- marl_factory_grid/utils/config_parser.py | 40 +++++++++------- marl_factory_grid/utils/states.py | 58 +++++++++++++++++++++++- 3 files changed, 84 insertions(+), 21 deletions(-) diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index d840178..651444e 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -87,16 +87,19 @@ class Factory(gym.Env): entities = self.map.do_init() # Init rules - rules = self.conf.load_rules() + rules = self.conf.load_env_rules() + env_tests = self.conf.load_env_tests() if self.conf.tests else [] # Parse the agent conf parsed_agents_conf = self.conf.parse_agents_conf() - self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose) + self.state = Gamestate(entities, parsed_agents_conf, rules, env_tests, self.conf.env_seed, self.conf.verbose) # All is set up, trigger entity init with variable pos # All is set up, trigger additional init (after agent entity spawn etc) self.state.rules.do_all_init(self.state, self.map) + self.state.tests.do_all_init(self.state, self.map) + # Build initial observations for all agents # noinspection PyAttributeOutsideInit self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r) diff --git a/marl_factory_grid/utils/config_parser.py b/marl_factory_grid/utils/config_parser.py index 093f1d0..c9223f8 100644 --- a/marl_factory_grid/utils/config_parser.py +++ b/marl_factory_grid/utils/config_parser.py @@ -1,16 +1,14 @@ import ast -from collections import defaultdict from os import PathLike from pathlib import Path -from typing import Union +from typing import Union, List import yaml -from marl_factory_grid.environment.groups.agents import Agents -from marl_factory_grid.environment.entity.agent import Agent -from marl_factory_grid.environment.rules import Rule -from marl_factory_grid.utils.helpers import locate_and_import_class from marl_factory_grid.environment import constants as c +from marl_factory_grid.environment.rules import Rule +from marl_factory_grid.environment.tests import Test +from marl_factory_grid.utils.helpers import locate_and_import_class DEFAULT_PATH = 'environment' MODULE_PATH = 'modules' @@ -131,17 +129,25 @@ class FactoryConfigParser(object): parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions) return parsed_agents_conf - def load_rules(self): - # entites = Entities() - rules_classes = dict() - rules = [] + def load_env_rules(self) -> List[Rule]: + rules = self.rules.copy() if c.DEFAULTS in self.rules: for rule in self.default_rules: if rule not in rules: - rules.append(rule) - rules.extend(x for x in self.rules if x != c.DEFAULTS) + rules.append({rule: {}}) - for rule in rules: + return self._load_smth(rules, Rule) + pass + + def load_env_tests(self) -> List[Test]: + return self._load_smth(self.tests, None) # Test + pass + + def _load_smth(self, config, class_obj): + rules = list() + rules_names = list() + + for rule in rules_names: try: folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) rule_class = locate_and_import_class(rule, folder_path) @@ -152,7 +158,7 @@ class FactoryConfigParser(object): except AttributeError: rule_class = locate_and_import_class(rule, self.custom_modules_path) # Fixme This check does not work! - # assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".' - rule_kwargs = self.rules.get(rule, {}) - rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}}) - return rules_classes + # assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".' + rule_kwargs = config.get(rule, {}) + rules.append(rule_class(**rule_kwargs)) + return rules diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index 4c1f7f2..1461826 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -4,6 +4,7 @@ import numpy as np from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.rules import Rule +from marl_factory_grid.environment.tests import Test from marl_factory_grid.utils.results import Result @@ -59,14 +60,15 @@ class Gamestate(object): def moving_entites(self): return [y for x in self.entities for y in x if x.var_can_move] - def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False): + def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False): self.entities = entities self.curr_step = 0 self.curr_actions = None self.agents_conf = agents_conf self.verbose = verbose self.rng = np.random.default_rng(env_seed) - self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values())) + self.rules = StepRules(*rules) + self.tests = StepTests(*tests) def __getitem__(self, item): return self.entities[item] @@ -82,10 +84,13 @@ class Gamestate(object): def tick(self, actions) -> List[Result]: results = list() + test_results = list() self.curr_step += 1 # Main Agent Step results.extend(self.rules.tick_pre_step_all(self)) + if self.tests: + test_results.extend(self.tests.tick_pre_step_all(self)) for idx, action_int in enumerate(actions): agent = self[c.AGENT][idx].clear_temp_state() @@ -101,6 +106,10 @@ class Gamestate(object): results.extend(self.rules.tick_step_all(self)) results.extend(self.rules.tick_post_step_all(self)) + if self.tests: + test_results.extend(self.tests.tick_step_all(self)) + test_results.extend(self.tests.tick_post_step_all(self)) + return results def print(self, string): @@ -133,3 +142,48 @@ class Gamestate(object): else: return False + +class StepTests: + def __init__(self, *args): + if args: + self.tests = list(args) + else: + self.tests = list() + + def __repr__(self): + return f'Tests{[x.name for x in self]}' + + def __iter__(self): + return iter(self.tests) + + def append(self, item): + assert isinstance(item, Test) + self.tests.append(item) + return True + + def do_all_init(self, state, lvl_map): + for test in self.tests: + if test_init_printline := test.on_init(state, lvl_map): + state.print(test_init_printline) + return c.VALID + + def tick_step_all(self, state): + test_results = list() + for test in self.tests: + if tick_step_result := test.tick_step(state): + test_results.extend(tick_step_result) + return test_results + + def tick_pre_step_all(self, state): + test_results = list() + for test in self.tests: + if tick_pre_step_result := test.tick_pre_step(state): + test_results.extend(tick_pre_step_result) + return test_results + + def tick_post_step_all(self, state): + test_results = list() + for test in self.tests: + if tick_post_step_result := test.tick_post_step(state): + test_results.extend(tick_post_step_result) + return test_results