mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
added test hooks (like rules)
This commit is contained in:
parent
9b9c6e0385
commit
64c0d0e4e9
@ -87,16 +87,19 @@ class Factory(gym.Env):
|
|||||||
entities = self.map.do_init()
|
entities = self.map.do_init()
|
||||||
|
|
||||||
# Init rules
|
# Init rules
|
||||||
rules = self.conf.load_rules()
|
rules = self.conf.load_env_rules()
|
||||||
|
env_tests = self.conf.load_env_tests() if self.conf.tests else []
|
||||||
|
|
||||||
# Parse the agent conf
|
# Parse the agent conf
|
||||||
parsed_agents_conf = self.conf.parse_agents_conf()
|
parsed_agents_conf = self.conf.parse_agents_conf()
|
||||||
self.state = Gamestate(entities, parsed_agents_conf, rules, self.conf.env_seed, self.conf.verbose)
|
self.state = Gamestate(entities, parsed_agents_conf, rules, env_tests, self.conf.env_seed, self.conf.verbose)
|
||||||
|
|
||||||
# All is set up, trigger entity init with variable pos
|
# All is set up, trigger entity init with variable pos
|
||||||
# All is set up, trigger additional init (after agent entity spawn etc)
|
# All is set up, trigger additional init (after agent entity spawn etc)
|
||||||
self.state.rules.do_all_init(self.state, self.map)
|
self.state.rules.do_all_init(self.state, self.map)
|
||||||
|
|
||||||
|
self.state.tests.do_all_init(self.state, self.map)
|
||||||
|
|
||||||
# Build initial observations for all agents
|
# Build initial observations for all agents
|
||||||
# noinspection PyAttributeOutsideInit
|
# noinspection PyAttributeOutsideInit
|
||||||
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
|
self.obs_builder = OBSBuilder(self.map.level_shape, self.state, self.map.pomdp_r)
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
import ast
|
import ast
|
||||||
from collections import defaultdict
|
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union, List
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from marl_factory_grid.environment.groups.agents import Agents
|
|
||||||
from marl_factory_grid.environment.entity.agent import Agent
|
|
||||||
from marl_factory_grid.environment.rules import Rule
|
|
||||||
from marl_factory_grid.utils.helpers import locate_and_import_class
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
|
from marl_factory_grid.environment.rules import Rule
|
||||||
|
from marl_factory_grid.environment.tests import Test
|
||||||
|
from marl_factory_grid.utils.helpers import locate_and_import_class
|
||||||
|
|
||||||
DEFAULT_PATH = 'environment'
|
DEFAULT_PATH = 'environment'
|
||||||
MODULE_PATH = 'modules'
|
MODULE_PATH = 'modules'
|
||||||
@ -131,17 +129,25 @@ class FactoryConfigParser(object):
|
|||||||
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
|
parsed_agents_conf[name] = dict(actions=parsed_actions, observations=observations, positions=positions)
|
||||||
return parsed_agents_conf
|
return parsed_agents_conf
|
||||||
|
|
||||||
def load_rules(self):
|
def load_env_rules(self) -> List[Rule]:
|
||||||
# entites = Entities()
|
rules = self.rules.copy()
|
||||||
rules_classes = dict()
|
|
||||||
rules = []
|
|
||||||
if c.DEFAULTS in self.rules:
|
if c.DEFAULTS in self.rules:
|
||||||
for rule in self.default_rules:
|
for rule in self.default_rules:
|
||||||
if rule not in rules:
|
if rule not in rules:
|
||||||
rules.append(rule)
|
rules.append({rule: {}})
|
||||||
rules.extend(x for x in self.rules if x != c.DEFAULTS)
|
|
||||||
|
|
||||||
for rule in rules:
|
return self._load_smth(rules, Rule)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def load_env_tests(self) -> List[Test]:
|
||||||
|
return self._load_smth(self.tests, None) # Test
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _load_smth(self, config, class_obj):
|
||||||
|
rules = list()
|
||||||
|
rules_names = list()
|
||||||
|
|
||||||
|
for rule in rules_names:
|
||||||
try:
|
try:
|
||||||
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
folder_path = (Path(__file__).parent.parent / DEFAULT_PATH)
|
||||||
rule_class = locate_and_import_class(rule, folder_path)
|
rule_class = locate_and_import_class(rule, folder_path)
|
||||||
@ -152,7 +158,7 @@ class FactoryConfigParser(object):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
rule_class = locate_and_import_class(rule, self.custom_modules_path)
|
||||||
# Fixme This check does not work!
|
# Fixme This check does not work!
|
||||||
# assert isinstance(rule_class, Rule), f'{rule_class.__name__} is no valid "Rule".'
|
# assert isinstance(rule_class, class_obj), f'{rule_class.__name__} is no valid "class_obj.__name__".'
|
||||||
rule_kwargs = self.rules.get(rule, {})
|
rule_kwargs = config.get(rule, {})
|
||||||
rules_classes.update({rule: {'class': rule_class, 'kwargs': rule_kwargs}})
|
rules.append(rule_class(**rule_kwargs))
|
||||||
return rules_classes
|
return rules
|
||||||
|
@ -4,6 +4,7 @@ import numpy as np
|
|||||||
|
|
||||||
from marl_factory_grid.environment import constants as c
|
from marl_factory_grid.environment import constants as c
|
||||||
from marl_factory_grid.environment.rules import Rule
|
from marl_factory_grid.environment.rules import Rule
|
||||||
|
from marl_factory_grid.environment.tests import Test
|
||||||
from marl_factory_grid.utils.results import Result
|
from marl_factory_grid.utils.results import Result
|
||||||
|
|
||||||
|
|
||||||
@ -59,14 +60,15 @@ class Gamestate(object):
|
|||||||
def moving_entites(self):
|
def moving_entites(self):
|
||||||
return [y for x in self.entities for y in x if x.var_can_move]
|
return [y for x in self.entities for y in x if x.var_can_move]
|
||||||
|
|
||||||
def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
|
def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False):
|
||||||
self.entities = entities
|
self.entities = entities
|
||||||
self.curr_step = 0
|
self.curr_step = 0
|
||||||
self.curr_actions = None
|
self.curr_actions = None
|
||||||
self.agents_conf = agents_conf
|
self.agents_conf = agents_conf
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.rng = np.random.default_rng(env_seed)
|
self.rng = np.random.default_rng(env_seed)
|
||||||
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))
|
self.rules = StepRules(*rules)
|
||||||
|
self.tests = StepTests(*tests)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
return self.entities[item]
|
return self.entities[item]
|
||||||
@ -82,10 +84,13 @@ class Gamestate(object):
|
|||||||
|
|
||||||
def tick(self, actions) -> List[Result]:
|
def tick(self, actions) -> List[Result]:
|
||||||
results = list()
|
results = list()
|
||||||
|
test_results = list()
|
||||||
self.curr_step += 1
|
self.curr_step += 1
|
||||||
|
|
||||||
# Main Agent Step
|
# Main Agent Step
|
||||||
results.extend(self.rules.tick_pre_step_all(self))
|
results.extend(self.rules.tick_pre_step_all(self))
|
||||||
|
if self.tests:
|
||||||
|
test_results.extend(self.tests.tick_pre_step_all(self))
|
||||||
|
|
||||||
for idx, action_int in enumerate(actions):
|
for idx, action_int in enumerate(actions):
|
||||||
agent = self[c.AGENT][idx].clear_temp_state()
|
agent = self[c.AGENT][idx].clear_temp_state()
|
||||||
@ -101,6 +106,10 @@ class Gamestate(object):
|
|||||||
results.extend(self.rules.tick_step_all(self))
|
results.extend(self.rules.tick_step_all(self))
|
||||||
results.extend(self.rules.tick_post_step_all(self))
|
results.extend(self.rules.tick_post_step_all(self))
|
||||||
|
|
||||||
|
if self.tests:
|
||||||
|
test_results.extend(self.tests.tick_step_all(self))
|
||||||
|
test_results.extend(self.tests.tick_post_step_all(self))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def print(self, string):
|
def print(self, string):
|
||||||
@ -133,3 +142,48 @@ class Gamestate(object):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class StepTests:
|
||||||
|
def __init__(self, *args):
|
||||||
|
if args:
|
||||||
|
self.tests = list(args)
|
||||||
|
else:
|
||||||
|
self.tests = list()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'Tests{[x.name for x in self]}'
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self.tests)
|
||||||
|
|
||||||
|
def append(self, item):
|
||||||
|
assert isinstance(item, Test)
|
||||||
|
self.tests.append(item)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def do_all_init(self, state, lvl_map):
|
||||||
|
for test in self.tests:
|
||||||
|
if test_init_printline := test.on_init(state, lvl_map):
|
||||||
|
state.print(test_init_printline)
|
||||||
|
return c.VALID
|
||||||
|
|
||||||
|
def tick_step_all(self, state):
|
||||||
|
test_results = list()
|
||||||
|
for test in self.tests:
|
||||||
|
if tick_step_result := test.tick_step(state):
|
||||||
|
test_results.extend(tick_step_result)
|
||||||
|
return test_results
|
||||||
|
|
||||||
|
def tick_pre_step_all(self, state):
|
||||||
|
test_results = list()
|
||||||
|
for test in self.tests:
|
||||||
|
if tick_pre_step_result := test.tick_pre_step(state):
|
||||||
|
test_results.extend(tick_pre_step_result)
|
||||||
|
return test_results
|
||||||
|
|
||||||
|
def tick_post_step_all(self, state):
|
||||||
|
test_results = list()
|
||||||
|
for test in self.tests:
|
||||||
|
if tick_post_step_result := test.tick_post_step(state):
|
||||||
|
test_results.extend(tick_post_step_result)
|
||||||
|
return test_results
|
||||||
|
Loading…
x
Reference in New Issue
Block a user