diff --git a/marl_factory_grid/environment/tests.py b/marl_factory_grid/environment/tests.py index 17592e3..5d1f8dd 100644 --- a/marl_factory_grid/environment/tests.py +++ b/marl_factory_grid/environment/tests.py @@ -17,7 +17,8 @@ class Test(unittest.TestCase): def __init__(self): """ - Base test class for unit tests. + Base test class for unit tests that provides base functions to be overwritten that are automatically called by + the StepTests class. """ super().__init__() @@ -56,7 +57,7 @@ class MaintainerTest(Test): def tick_step(self, state) -> List[TickResult]: for maintainer in state.entities[M.MAINTAINERS]: - # has valid actionresult + # has valid action result self.assertIsInstance(maintainer.state, ActionResult) # self.assertEqual(maintainer.state.validity, True) # print(f"state validity {maintainer.state.validity}") diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index 003d913..b7a3ef4 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -309,6 +309,9 @@ class Gamestate(object): class StepTests: def __init__(self, *args): + """ + The StepTests class is responsible for calling all tests and their respective hooks at the right time. + """ if args: self.tests = list(args) else: @@ -326,6 +329,12 @@ class StepTests: return True def do_all_init(self, state, lvl_map) -> bool: + """ + Iterate all **Tests** that override the *on_check_done* hook. + + :return: valid + :rtype: bool + """ for test in self.tests: if test_init_printline := test.on_init(state, lvl_map): state.print(test_init_printline) @@ -335,11 +344,7 @@ class StepTests: """ Iterate all **Tests** that override the *tick_step* hook. - :return: List of Results - """ """ - Iterate all **Tests** that override the *on_check_done* hook. - - :return: List of Results + :return: List of Results """ test_results = list() for test in self.tests: @@ -351,7 +356,7 @@ class StepTests: """ Iterate all **Tests** that override the *pre_step* hook. - :return: List of Results + :return: List of Results """ test_results = list() for test in self.tests: @@ -363,7 +368,7 @@ class StepTests: """ Iterate all **Tests** that override the *post_step* hook. - :return: List of Results + :return: List of Results """ test_results = list() for test in self.tests: @@ -375,7 +380,7 @@ class StepTests: """ Iterate all **Tests** that override the *on_check_done* hook. - :return: List of Results + :return: List of Results """ test_results = list() for test in self.tests: