added test hooks (like rules)

This commit is contained in:
Chanumask 2023-11-10 10:41:41 +01:00
parent 9b9c6e0385
commit 64c0d0e4e9
3 changed files with 84 additions and 21 deletions

View File

@ -87,16 +87,19 @@ class Factory(gym.Env):
entities = self.map.do_init() entities = self.map.do_init()
# Init rules # Init rules
rules = self.conf.load_rules() rules = self.conf.load_env_rules()
env_tests = self.conf.load_env_tests() if self.conf.tests else []
# Parse the agent conf # Parse the agent conf
parsed_agents_conf = self.conf.parse_agents_conf() parsed_agents_conf = self.conf.parse_agents_conf()
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose) self.state = Gamestate(entities, parsed_agents_conf, rules, env_tests, self.conf.env_seed, self.conf.verbose)
# All is set up, trigger entity init with variable pos # All is set up, trigger entity init with variable pos
# All is set up, trigger additional init (after agent entity spawn etc) # All is set up, trigger additional init (after agent entity spawn etc)
self.state.rules.do_all_init(self.state, self.map) self.state.rules.do_all_init(self.state, self.map)
self.state.tests.do_all_init(self.state, self.map)
# Build initial observations for all agents # Build initial observations for all agents
# noinspection PyAttributeOutsideInit # noinspection PyAttributeOutsideInit
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r) self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)

View File

@ -1,16 +1,14 @@
import ast import ast
from collections import defaultdict
from os import PathLike from os import PathLike
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union, List
import yaml import yaml
from marl_factory_grid.environment.groups.agents import Agents
from marl_factory_grid.environment.entity.agent import Agent
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.utils.helpers import locate_and_import_class
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment.tests import Test
from marl_factory_grid.utils.helpers import locate_and_import_class
DEFAULT_PATH = 'environment' DEFAULT_PATH = 'environment'
MODULE_PATH = 'modules' MODULE_PATH = 'modules'
@ -131,17 +129,25 @@ class FactoryConfigParser(object):
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions) parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
return parsed_agents_conf return parsed_agents_conf
def load_rules(self): def load_env_rules(self) -> List[Rule]:
# entites = Entities() rules = self.rules.copy()
rules_classes = dict()
rules = []
if c.DEFAULTS in self.rules: if c.DEFAULTS in self.rules:
for rule in self.default_rules: for rule in self.default_rules:
if rule not in rules: if rule not in rules:
rules.append(rule) rules.append({rule: {}})
rules.extend(x for x in self.rules if x != c.DEFAULTS)
for rule in rules: return self._load_smth(rules, Rule)
pass
def load_env_tests(self) -> List[Test]:
return self._load_smth(self.tests, None) # Test
pass
def _load_smth(self, config, class_obj):
rules = list()
rules_names = list()
for rule in rules_names:
try: try:
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH) folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
rule_class = locate_and_import_class(rule, folder_path) rule_class = locate_and_import_class(rule, folder_path)
@ -152,7 +158,7 @@ class FactoryConfigParser(object):
except AttributeError: except AttributeError:
rule_class = locate_and_import_class(rule, self.custom_modules_path) rule_class = locate_and_import_class(rule, self.custom_modules_path)
# Fixme This check does not work! # Fixme This check does not work!
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".' # assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".'
rule_kwargs = self.rules.get(rule, {}) rule_kwargs = config.get(rule, {})
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}}) rules.append(rule_class(**rule_kwargs))
return rules_classes return rules

View File

@ -4,6 +4,7 @@ import numpy as np
from marl_factory_grid.environment import constants as c from marl_factory_grid.environment import constants as c
from marl_factory_grid.environment.rules import Rule from marl_factory_grid.environment.rules import Rule
from marl_factory_grid.environment.tests import Test
from marl_factory_grid.utils.results import Result from marl_factory_grid.utils.results import Result
@ -59,14 +60,15 @@ class Gamestate(object):
def moving_entites(self): def moving_entites(self):
return [y for x in self.entities for y in x if x.var_can_move] return [y for x in self.entities for y in x if x.var_can_move]
def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False): def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False):
self.entities = entities self.entities = entities
self.curr_step = 0 self.curr_step = 0
self.curr_actions = None self.curr_actions = None
self.agents_conf = agents_conf self.agents_conf = agents_conf
self.verbose = verbose self.verbose = verbose
self.rng = np.random.default_rng(env_seed) self.rng = np.random.default_rng(env_seed)
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values())) self.rules = StepRules(*rules)
self.tests = StepTests(*tests)
def __getitem__(self, item): def __getitem__(self, item):
return self.entities[item] return self.entities[item]
@ -82,10 +84,13 @@ class Gamestate(object):
def tick(self, actions) -> List[Result]: def tick(self, actions) -> List[Result]:
results = list() results = list()
test_results = list()
self.curr_step += 1 self.curr_step += 1
# Main Agent Step # Main Agent Step
results.extend(self.rules.tick_pre_step_all(self)) results.extend(self.rules.tick_pre_step_all(self))
if self.tests:
test_results.extend(self.tests.tick_pre_step_all(self))
for idx, action_int in enumerate(actions): for idx, action_int in enumerate(actions):
agent = self[c.AGENT][idx].clear_temp_state() agent = self[c.AGENT][idx].clear_temp_state()
@ -101,6 +106,10 @@ class Gamestate(object):
results.extend(self.rules.tick_step_all(self)) results.extend(self.rules.tick_step_all(self))
results.extend(self.rules.tick_post_step_all(self)) results.extend(self.rules.tick_post_step_all(self))
if self.tests:
test_results.extend(self.tests.tick_step_all(self))
test_results.extend(self.tests.tick_post_step_all(self))
return results return results
def print(self, string): def print(self, string):
@ -133,3 +142,48 @@ class Gamestate(object):
else: else:
return False return False
class StepTests:
    """Ordered collection of test hooks that are invoked around every step.

    Each contained object is expected to expose ``name``, ``on_init(state,
    lvl_map)``, ``tick_pre_step(state)``, ``tick_step(state)`` and
    ``tick_post_step(state)``; those are the only members used here.

    The collection supports ``len()``, so an empty ``StepTests`` is falsy and
    guards such as ``if state.tests:`` only pass when tests are registered.
    """

    def __init__(self, *args):
        """Store the given test objects in call order (may be empty)."""
        # list() of an empty tuple is already an empty list; no branch needed.
        self.tests = list(args)

    def __repr__(self):
        return f'Tests{[x.name for x in self]}'

    def __iter__(self):
        return iter(self.tests)

    def __len__(self):
        """Return the number of registered tests (drives truthiness)."""
        return len(self.tests)

    def append(self, item):
        """Register an additional test and return ``True``.

        Raises:
            AssertionError: if ``item`` is not a ``Test`` instance.
        """
        # NOTE(review): assert is stripped under ``python -O``; kept so the
        # raised exception type stays unchanged for existing callers.
        assert isinstance(item, Test), f'{item!r} is not a Test instance.'
        self.tests.append(item)
        return True

    def do_all_init(self, state, lvl_map):
        """Call ``on_init`` on every test and print each truthy return value.

        Returns:
            ``c.VALID`` once all tests have been initialised.
        """
        for test in self.tests:
            if test_init_printline := test.on_init(state, lvl_map):
                state.print(test_init_printline)
        return c.VALID

    def _collect(self, hook_name, state):
        """Call hook ``hook_name`` on every test and flatten truthy results."""
        collected = list()
        for test in self.tests:
            # Hooks may return None / an empty sequence to signal "nothing".
            if hook_result := getattr(test, hook_name)(state):
                collected.extend(hook_result)
        return collected

    def tick_step_all(self, state):
        """Return the concatenated ``tick_step`` results of all tests."""
        return self._collect('tick_step', state)

    def tick_pre_step_all(self, state):
        """Return the concatenated ``tick_pre_step`` results of all tests."""
        return self._collect('tick_pre_step', state)

    def tick_post_step_all(self, state):
        """Return the concatenated ``tick_post_step`` results of all tests."""
        return self._collect('tick_post_step', state)