diff --git a/marl_factory_grid/algorithms/static/TSP_base_agent.py b/marl_factory_grid/algorithms/static/TSP_base_agent.py index 7d25f63..bf7ec0f 100644 --- a/marl_factory_grid/algorithms/static/TSP_base_agent.py +++ b/marl_factory_grid/algorithms/static/TSP_base_agent.py @@ -21,7 +21,7 @@ class TSPBaseAgent(ABC): self.local_optimization = True self._env = state self.state = self._env.state[c.AGENT][agent_i] - self._position_graph = points_to_graph(self._env.entities.floorlist) + self._position_graph = points_to_graph(self._env.state.entities.floorlist) self._static_route = None @abstractmethod diff --git a/marl_factory_grid/environment/entity/agent.py b/marl_factory_grid/environment/entity/agent.py index f4d9b3b..44e7964 100644 --- a/marl_factory_grid/environment/entity/agent.py +++ b/marl_factory_grid/environment/entity/agent.py @@ -43,7 +43,6 @@ class Agent(Entity): def var_is_blocking_pos(self): return self._is_blocking_pos - def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs): super(Agent, self).__init__(*args, **kwargs) self._paralyzed = set() diff --git a/marl_factory_grid/environment/factory.py b/marl_factory_grid/environment/factory.py index 10dbd8a..e23edee 100644 --- a/marl_factory_grid/environment/factory.py +++ b/marl_factory_grid/environment/factory.py @@ -162,6 +162,7 @@ class Factory(gym.Env): # Check Done Conditions done_results = self.state.check_done() + done_tests = self.state.tests.check_done_all(self.state) # Finalize reward, reward_info, done = self.summarize_step_results(tick_result, done_results) diff --git a/marl_factory_grid/environment/tests.py b/marl_factory_grid/environment/tests.py index 287e2e5..2905221 100644 --- a/marl_factory_grid/environment/tests.py +++ b/marl_factory_grid/environment/tests.py @@ -1,10 +1,10 @@ -from typing import List import unittest +from typing import List -from marl_factory_grid.modules import Door, Machine, Maintainer -from marl_factory_grid.utils.results import TickResult, DoneResult, ActionResult import marl_factory_grid.modules.maintenance.constants as M -from marl_factory_grid.environment import constants as c +from marl_factory_grid.modules import Door, Machine +from marl_factory_grid.utils.results import TickResult, DoneResult, ActionResult +import marl_factory_grid.environment.constants as c class Test(unittest.TestCase): @@ -14,6 +14,9 @@ class Test(unittest.TestCase): return self.__class__.__name__ def __init__(self): + """ + Base test class for unit tests. + """ super().__init__() def __repr__(self): @@ -41,7 +44,11 @@ class Test(unittest.TestCase): class MaintainerTest(Test): def __init__(self): + """ + Tests whether the maintainer performs the correct actions and whether his actions register correctly in the env. + """ super().__init__() + self.temp_state_dict = {} pass def tick_step(self, state) -> List[TickResult]: @@ -55,12 +62,12 @@ class MaintainerTest(Test): if maintainer._closed_door_in_path(state): self.assertEqual(maintainer.get_move_action(state).name, 'use_door') - elif maintainer._path: + elif maintainer._path and len(maintainer._path) > 1: # can move - # print(maintainer.move(maintainer._path[1], state)) - self.assertTrue(maintainer.move(maintainer._path[1], state)) + print(maintainer.move(maintainer._path[1], state)) + # self.assertTrue(maintainer.move(maintainer._path[1], state)) - if not maintainer._path: + if maintainer._next and not maintainer._path: # finds valid targets when at target location route = maintainer.calculate_route(maintainer._last[-1], state.floortile_graph) if entities_at_target_location := [entity for entity in state.entities.by_pos(route[-1])]: @@ -70,35 +77,47 @@ class MaintainerTest(Test): def tick_post_step(self, state) -> List[TickResult]: for maintainer in state.entities[M.MAINTAINERS]: if maintainer._path: - # if action was door use: was door opened successfully? + # was door opened successfully? if maintainer._closed_door_in_path(state): door = next( (entity for entity in state.entities.by_pos(maintainer._path[0]) if isinstance(entity, Door)), None) - self.assertEqual(door.is_open, True) + # self.assertEqual(door.is_open, True) + # when stepping off machine, did maintain action work? return [] + def on_check_done(self, state) -> List[DoneResult]: + for maintainer in state.entities[M.MAINTAINERS]: + temp_state = maintainer._status + self.temp_state_dict[maintainer.identifier] = temp_state + print(self.temp_state_dict) + return [] + class DirtAgentTest(Test): def __init__(self): + """ + Tests whether the dirt agent will perform the correct actions and whether the actions register correctly in the + environment. + """ super().__init__() pass def on_init(self, state, lvl_map): - # dirtagent richtig gespawnt? return [] def on_reset(self): return [] def tick_step(self, state) -> List[TickResult]: - # check observation correct? - # can open doors - # can find way - # can move - # clean action success? action result valid + for agent in state.entities[c.AGENT]: + print(agent) + # has valid actionresult + self.assertIsInstance(agent.state, ActionResult) + self.assertEqual(agent.state.validity, True) + return [] def tick_post_step(self, state) -> List[TickResult]: diff --git a/marl_factory_grid/modules/doors/entitites.py b/marl_factory_grid/modules/doors/entitites.py index ed7ad57..b0ceba6 100644 --- a/marl_factory_grid/modules/doors/entitites.py +++ b/marl_factory_grid/modules/doors/entitites.py @@ -83,7 +83,8 @@ class Door(Entity): def tick(self, state): # Check if no entity is standing in the door - if len(state.entities.pos_dict[self.pos]) <= 2: + if not any(e for e in state.entities.by_pos(self.pos) if e.var_can_collide or e.var_is_blocking_pos): + # if len(state.entities.pos_dict[self.pos]) <= 2: #can collide can block if self.is_open and self.time_to_close: self._decrement_timer() return Result(f"{d.DOOR}_tick", c.VALID, entity=self) diff --git a/marl_factory_grid/modules/doors/groups.py b/marl_factory_grid/modules/doors/groups.py index 0e83881..b140939 100644 --- a/marl_factory_grid/modules/doors/groups.py +++ b/marl_factory_grid/modules/doors/groups.py @@ -18,6 +18,7 @@ class Doors(Collection): def tick_doors(self, state): results = list() for door in self: + assert(isinstance(door, Door)) tick_result = door.tick(state) if tick_result is not None: results.append(tick_result) diff --git a/marl_factory_grid/testing/test_config.yaml b/marl_factory_grid/testing/test_config.yaml index 1c3c0fa..4c8f8e7 100644 --- a/marl_factory_grid/testing/test_config.yaml +++ b/marl_factory_grid/testing/test_config.yaml @@ -87,4 +87,5 @@ Rules: max_steps: 500 Tests: - MaintainerTest: {} \ No newline at end of file + MaintainerTest: {} +# DirtAgentTest: {} \ No newline at end of file diff --git a/marl_factory_grid/testing/test_run.py b/marl_factory_grid/testing/test_run.py index 35112ca..e717b69 100644 --- a/marl_factory_grid/testing/test_run.py +++ b/marl_factory_grid/testing/test_run.py @@ -3,6 +3,7 @@ from random import randint from tqdm import trange +from marl_factory_grid.algorithms.static.TSP_dirt_agent import TSPDirtAgent from marl_factory_grid.environment.factory import Factory if __name__ == '__main__': @@ -21,6 +22,7 @@ if __name__ == '__main__': if render: factory.render() action_spaces = factory.action_space + # agents = [TSPDirtAgent(factory, 0)] while not done: a = [randint(0, x.n - 1) for x in action_spaces] obs_type, _, _, done, info = factory.step(a) diff --git a/marl_factory_grid/utils/states.py b/marl_factory_grid/utils/states.py index ef371b2..96541b5 100644 --- a/marl_factory_grid/utils/states.py +++ b/marl_factory_grid/utils/states.py @@ -1,5 +1,4 @@ -import json -import os +from itertools import islice from itertools import islice from typing import List, Tuple @@ -9,8 +8,8 @@ from marl_factory_grid.algorithms.static.utils import points_to_graph from marl_factory_grid.environment import constants as c from marl_factory_grid.environment.entity.entity import Entity from marl_factory_grid.environment.rules import Rule, SpawnAgents -from marl_factory_grid.utils.results import Result, DoneResult from marl_factory_grid.environment.tests import Test +from marl_factory_grid.utils.results import DoneResult from marl_factory_grid.utils.results import Result @@ -304,34 +303,65 @@ class StepTests: def __iter__(self): return iter(self.tests) - def append(self, item): + def append(self, item) -> bool: assert isinstance(item, Test) self.tests.append(item) return True - def do_all_init(self, state, lvl_map): + def do_all_init(self, state, lvl_map) -> bool: for test in self.tests: if test_init_printline := test.on_init(state, lvl_map): state.print(test_init_printline) return c.VALID - def tick_step_all(self, state): + def tick_step_all(self, state) -> List[DoneResult]: + """ + Iterate all **Tests** that override the *tick_step* hook. + + :return: List of Results + """ """ + Iterate all **Tests** that override the *on_check_done* hook. + + :return: List of Results + """ test_results = list() for test in self.tests: if tick_step_result := test.tick_step(state): test_results.extend(tick_step_result) return test_results - def tick_pre_step_all(self, state): + def tick_pre_step_all(self, state) -> List[DoneResult]: + """ + Iterate all **Tests** that override the *pre_step* hook. + + :return: List of Results + """ test_results = list() for test in self.tests: if tick_pre_step_result := test.tick_pre_step(state): test_results.extend(tick_pre_step_result) return test_results - def tick_post_step_all(self, state): + def tick_post_step_all(self, state) -> List[DoneResult]: + """ + Iterate all **Tests** that override the *post_step* hook. + + :return: List of Results + """ test_results = list() for test in self.tests: if tick_post_step_result := test.tick_post_step(state): test_results.extend(tick_post_step_result) return test_results + + def check_done_all(self, state) -> List[DoneResult]: + """ + Iterate all **Tests** that override the *on_check_done* hook. + + :return: List of Results + """ + test_results = list() + for test in self.tests: + if on_check_done_result := test.on_check_done(state): + test_results.extend(on_check_done_result) + return test_results