mirror of
				https://github.com/illiumst/marl-factory-grid.git
				synced 2025-10-31 12:37:27 +01:00 
			
		
		
		
	added test hooks (like rules)
This commit is contained in:
		| @@ -87,16 +87,19 @@ class Factory(gym.Env): | ||||
|         entities = self.map.do_init() | ||||
|  | ||||
|         # Init rules | ||||
|         rules = self.conf.load_rules() | ||||
|         rules = self.conf.load_env_rules() | ||||
|         env_tests = self.conf.load_env_tests() if self.conf.tests else [] | ||||
|  | ||||
|         # Parse the agent conf | ||||
|         parsed_agents_conf = self.conf.parse_agents_conf() | ||||
|         self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose) | ||||
|         self.state = Gamestate(entities, parsed_agents_conf, rules, env_tests, self.conf.env_seed, self.conf.verbose) | ||||
|  | ||||
|         # All is set up, trigger entity init with variable pos | ||||
|         # All is set up, trigger additional init (after agent entity spawn etc) | ||||
|         self.state.rules.do_all_init(self.state, self.map) | ||||
|  | ||||
|         self.state.tests.do_all_init(self.state, self.map) | ||||
|  | ||||
|         # Build initial observations for all agents | ||||
|         # noinspection PyAttributeOutsideInit | ||||
|         self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r) | ||||
|   | ||||
| @@ -1,16 +1,14 @@ | ||||
| import ast | ||||
| from collections import defaultdict | ||||
| from os import PathLike | ||||
| from pathlib import Path | ||||
| from typing import Union | ||||
| from typing import Union, List | ||||
|  | ||||
| import yaml | ||||
|  | ||||
| from marl_factory_grid.environment.groups.agents import Agents | ||||
| from marl_factory_grid.environment.entity.agent import Agent | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.utils.helpers import locate_and_import_class | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.environment.tests import Test | ||||
| from marl_factory_grid.utils.helpers import locate_and_import_class | ||||
|  | ||||
| DEFAULT_PATH = 'environment' | ||||
| MODULE_PATH = 'modules' | ||||
| @@ -131,17 +129,25 @@ class FactoryConfigParser(object): | ||||
|             parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions) | ||||
|         return parsed_agents_conf | ||||
|  | ||||
|     def load_rules(self): | ||||
|         # entites = Entities() | ||||
|         rules_classes = dict() | ||||
|         rules = [] | ||||
|     def load_env_rules(self) -> List[Rule]: | ||||
|         rules = self.rules.copy() | ||||
|         if c.DEFAULTS in self.rules: | ||||
|             for rule in self.default_rules: | ||||
|                 if rule not in rules: | ||||
|                     rules.append(rule) | ||||
|         rules.extend(x for x in self.rules if x != c.DEFAULTS) | ||||
|                     rules.append({rule: {}}) | ||||
|  | ||||
|         for rule in rules: | ||||
|         return self._load_smth(rules, Rule) | ||||
|         pass | ||||
|  | ||||
|     def load_env_tests(self) -> List[Test]: | ||||
|         return self._load_smth(self.tests, None)  # Test | ||||
|         pass | ||||
|  | ||||
|     def _load_smth(self, config, class_obj): | ||||
|         rules = list() | ||||
|         rules_names = list() | ||||
|  | ||||
|         for rule in rules_names: | ||||
|             try: | ||||
|                 folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) | ||||
|                 rule_class = locate_and_import_class(rule, folder_path) | ||||
| @@ -152,7 +158,7 @@ class FactoryConfigParser(object): | ||||
|                 except AttributeError: | ||||
|                     rule_class = locate_and_import_class(rule, self.custom_modules_path) | ||||
|             # Fixme This check does not work! | ||||
|             #  assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".' | ||||
|             rule_kwargs = self.rules.get(rule, {}) | ||||
|             rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}}) | ||||
|         return rules_classes | ||||
|             #  assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".' | ||||
|             rule_kwargs = config.get(rule, {}) | ||||
|             rules.append(rule_class(**rule_kwargs)) | ||||
|         return rules | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import numpy as np | ||||
|  | ||||
| from marl_factory_grid.environment import constants as c | ||||
| from marl_factory_grid.environment.rules import Rule | ||||
| from marl_factory_grid.environment.tests import Test | ||||
| from marl_factory_grid.utils.results import Result | ||||
|  | ||||
|  | ||||
| @@ -59,14 +60,15 @@ class Gamestate(object): | ||||
|     def moving_entites(self): | ||||
|         return [y for x in self.entities for y in x if x.var_can_move] | ||||
|  | ||||
|     def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False): | ||||
|     def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False): | ||||
|         self.entities = entities | ||||
|         self.curr_step = 0 | ||||
|         self.curr_actions = None | ||||
|         self.agents_conf = agents_conf | ||||
|         self.verbose = verbose | ||||
|         self.rng = np.random.default_rng(env_seed) | ||||
|         self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values())) | ||||
|         self.rules = StepRules(*rules) | ||||
|         self.tests = StepTests(*tests) | ||||
|  | ||||
|     def __getitem__(self, item): | ||||
|         return self.entities[item] | ||||
| @@ -82,10 +84,13 @@ class Gamestate(object): | ||||
|  | ||||
|     def tick(self, actions) -> List[Result]: | ||||
|         results = list() | ||||
|         test_results = list() | ||||
|         self.curr_step += 1 | ||||
|  | ||||
|         # Main Agent Step | ||||
|         results.extend(self.rules.tick_pre_step_all(self)) | ||||
|         if self.tests: | ||||
|             test_results.extend(self.tests.tick_pre_step_all(self)) | ||||
|  | ||||
|         for idx, action_int in enumerate(actions): | ||||
|             agent = self[c.AGENT][idx].clear_temp_state() | ||||
| @@ -101,6 +106,10 @@ class Gamestate(object): | ||||
|         results.extend(self.rules.tick_step_all(self)) | ||||
|         results.extend(self.rules.tick_post_step_all(self)) | ||||
|  | ||||
|         if self.tests: | ||||
|             test_results.extend(self.tests.tick_step_all(self)) | ||||
|             test_results.extend(self.tests.tick_post_step_all(self)) | ||||
|  | ||||
|         return results | ||||
|  | ||||
|     def print(self, string): | ||||
| @@ -133,3 +142,48 @@ class Gamestate(object): | ||||
|         else: | ||||
|             return False | ||||
|  | ||||
|  | ||||
| class StepTests: | ||||
|     def __init__(self, *args): | ||||
|         if args: | ||||
|             self.tests = list(args) | ||||
|         else: | ||||
|             self.tests = list() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f'Tests{[x.name for x in self]}' | ||||
|  | ||||
|     def __iter__(self): | ||||
|         return iter(self.tests) | ||||
|  | ||||
|     def append(self, item): | ||||
|         assert isinstance(item, Test) | ||||
|         self.tests.append(item) | ||||
|         return True | ||||
|  | ||||
|     def do_all_init(self, state, lvl_map): | ||||
|         for test in self.tests: | ||||
|             if test_init_printline := test.on_init(state, lvl_map): | ||||
|                 state.print(test_init_printline) | ||||
|         return c.VALID | ||||
|  | ||||
|     def tick_step_all(self, state): | ||||
|         test_results = list() | ||||
|         for test in self.tests: | ||||
|             if tick_step_result := test.tick_step(state): | ||||
|                 test_results.extend(tick_step_result) | ||||
|         return test_results | ||||
|  | ||||
|     def tick_pre_step_all(self, state): | ||||
|         test_results = list() | ||||
|         for test in self.tests: | ||||
|             if tick_pre_step_result := test.tick_pre_step(state): | ||||
|                 test_results.extend(tick_pre_step_result) | ||||
|         return test_results | ||||
|  | ||||
|     def tick_post_step_all(self, state): | ||||
|         test_results = list() | ||||
|         for test in self.tests: | ||||
|             if tick_post_step_result := test.tick_post_step(state): | ||||
|                 test_results.extend(tick_post_step_result) | ||||
|         return test_results | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Chanumask
					Chanumask