mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 00:51:35 +02:00
maintainer test grabs temp state
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
from itertools import islice
|
||||
from itertools import islice
|
||||
from typing import List, Tuple
|
||||
|
||||
@ -9,8 +8,8 @@ from marl_factory_grid.algorithms.static.utils import points_to_graph
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.rules import Rule, SpawnAgents
|
||||
from marl_factory_grid.utils.results import Result, DoneResult
|
||||
from marl_factory_grid.environment.tests import Test
|
||||
from marl_factory_grid.utils.results import DoneResult
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
|
||||
@ -304,34 +303,65 @@ class StepTests:
|
||||
def __iter__(self):
|
||||
return iter(self.tests)
|
||||
|
||||
def append(self, item):
|
||||
def append(self, item) -> bool:
|
||||
assert isinstance(item, Test)
|
||||
self.tests.append(item)
|
||||
return True
|
||||
|
||||
def do_all_init(self, state, lvl_map):
|
||||
def do_all_init(self, state, lvl_map) -> bool:
|
||||
for test in self.tests:
|
||||
if test_init_printline := test.on_init(state, lvl_map):
|
||||
state.print(test_init_printline)
|
||||
return c.VALID
|
||||
|
||||
def tick_step_all(self, state):
|
||||
def tick_step_all(self, state) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Tests** that override the *tick_step* hook.
|
||||
|
||||
:return: List of Results
|
||||
""" """
|
||||
Iterate all **Tests** that override the *on_check_done* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if tick_step_result := test.tick_step(state):
|
||||
test_results.extend(tick_step_result)
|
||||
return test_results
|
||||
|
||||
def tick_pre_step_all(self, state):
|
||||
def tick_pre_step_all(self, state) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Tests** that override the *pre_step* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if tick_pre_step_result := test.tick_pre_step(state):
|
||||
test_results.extend(tick_pre_step_result)
|
||||
return test_results
|
||||
|
||||
def tick_post_step_all(self, state):
|
||||
def tick_post_step_all(self, state) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Tests** that override the *post_step* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if tick_post_step_result := test.tick_post_step(state):
|
||||
test_results.extend(tick_post_step_result)
|
||||
return test_results
|
||||
|
||||
def check_done_all(self, state) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Tests** that override the *on_check_done* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if on_check_done_result := test.on_check_done(state):
|
||||
test_results.extend(on_check_done_result)
|
||||
return test_results
|
||||
|
Reference in New Issue
Block a user