mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 09:01:36 +02:00
added test hooks (like rules)
This commit is contained in:
@ -4,6 +4,7 @@ import numpy as np
|
||||
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.rules import Rule
|
||||
from marl_factory_grid.environment.tests import Test
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
|
||||
@ -59,14 +60,15 @@ class Gamestate(object):
|
||||
def moving_entites(self):
|
||||
return [y for x in self.entities for y in x if x.var_can_move]
|
||||
|
||||
def __init__(self, entities, agents_conf, rules: Dict[str, dict], env_seed=69, verbose=False):
|
||||
def __init__(self, entities, agents_conf, rules: [Rule], tests: [Test], env_seed=69, verbose=False):
|
||||
self.entities = entities
|
||||
self.curr_step = 0
|
||||
self.curr_actions = None
|
||||
self.agents_conf = agents_conf
|
||||
self.verbose = verbose
|
||||
self.rng = np.random.default_rng(env_seed)
|
||||
self.rules = StepRules(*(v['class'](**v['kwargs']) for v in rules.values()))
|
||||
self.rules = StepRules(*rules)
|
||||
self.tests = StepTests(*tests)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.entities[item]
|
||||
@ -82,10 +84,13 @@ class Gamestate(object):
|
||||
|
||||
def tick(self, actions) -> List[Result]:
|
||||
results = list()
|
||||
test_results = list()
|
||||
self.curr_step += 1
|
||||
|
||||
# Main Agent Step
|
||||
results.extend(self.rules.tick_pre_step_all(self))
|
||||
if self.tests:
|
||||
test_results.extend(self.tests.tick_pre_step_all(self))
|
||||
|
||||
for idx, action_int in enumerate(actions):
|
||||
agent = self[c.AGENT][idx].clear_temp_state()
|
||||
@ -101,6 +106,10 @@ class Gamestate(object):
|
||||
results.extend(self.rules.tick_step_all(self))
|
||||
results.extend(self.rules.tick_post_step_all(self))
|
||||
|
||||
if self.tests:
|
||||
test_results.extend(self.tests.tick_step_all(self))
|
||||
test_results.extend(self.tests.tick_post_step_all(self))
|
||||
|
||||
return results
|
||||
|
||||
def print(self, string):
|
||||
@ -133,3 +142,48 @@ class Gamestate(object):
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class StepTests:
|
||||
def __init__(self, *args):
|
||||
if args:
|
||||
self.tests = list(args)
|
||||
else:
|
||||
self.tests = list()
|
||||
|
||||
def __repr__(self):
|
||||
return f'Tests{[x.name for x in self]}'
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.tests)
|
||||
|
||||
def append(self, item):
|
||||
assert isinstance(item, Test)
|
||||
self.tests.append(item)
|
||||
return True
|
||||
|
||||
def do_all_init(self, state, lvl_map):
|
||||
for test in self.tests:
|
||||
if test_init_printline := test.on_init(state, lvl_map):
|
||||
state.print(test_init_printline)
|
||||
return c.VALID
|
||||
|
||||
def tick_step_all(self, state):
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if tick_step_result := test.tick_step(state):
|
||||
test_results.extend(tick_step_result)
|
||||
return test_results
|
||||
|
||||
def tick_pre_step_all(self, state):
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if tick_pre_step_result := test.tick_pre_step(state):
|
||||
test_results.extend(tick_pre_step_result)
|
||||
return test_results
|
||||
|
||||
def tick_post_step_all(self, state):
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if tick_post_step_result := test.tick_post_step(state):
|
||||
test_results.extend(tick_post_step_result)
|
||||
return test_results
|
||||
|
Reference in New Issue
Block a user