diff --git a/marl_factory_grid/configs/test_config.yaml b/marl_factory_grid/configs/test_config.yaml index cf5e2d5..ca328e9 100644 --- a/marl_factory_grid/configs/test_config.yaml +++ b/marl_factory_grid/configs/test_config.yaml @@ -1,27 +1,44 @@ Agents: - Wolfgang: + Clean test agent: Actions: - - Noop - - Charge - - Clean - - DestAction - - DoorUse - - ItemAction - - Move8 + - Noop + - Charge + - Clean + - DoorUse + - Move8 Observations: - - Combined: - - Other - - Walls - - GlobalPosition - - Battery - - ChargePods - - DirtPiles - - Destinations - - Doors - - Items - - Inventory - - DropOffLocations - - Maintainers + - Combined: + - Other + - Walls + - GlobalPosition + - Battery + - ChargePods + - DirtPiles + - Destinations + - Doors + - Maintainers + Item test agent: + Actions: + - Noop + - Charge + - DestAction + - DoorUse + - ItemAction + - Move8 + Observations: + - Combined: + - Other + - Walls + - GlobalPosition + - Battery + - ChargePods + - Destinations + - Doors + - Items + - Inventory + - DropOffLocations + - Maintainers + Entities: Batteries: @@ -89,4 +106,5 @@ Rules: Tests: MaintainerTest: {} -# DirtAgentTest: {} \ No newline at end of file + DirtAgentTest: {} + ItemAgentTest: {} diff --git a/marl_factory_grid/environment/tests.py b/marl_factory_grid/environment/tests.py index 7a5d0d4..17592e3 100644 --- a/marl_factory_grid/environment/tests.py +++ b/marl_factory_grid/environment/tests.py @@ -2,7 +2,9 @@ import unittest from typing import List import marl_factory_grid.modules.maintenance.constants as M -from marl_factory_grid.modules import Door, Machine +from marl_factory_grid.algorithms.static.TSP_dirt_agent import TSPDirtAgent +from marl_factory_grid.environment.entity.agent import Agent +from marl_factory_grid.modules import Door, Machine, DirtPile, Item, DropOffLocation, ItemAction from marl_factory_grid.utils.results import TickResult, DoneResult, ActionResult import marl_factory_grid.environment.constants as c @@ -56,18 +58,13 @@ class MaintainerTest(Test): # has valid actionresult self.assertIsInstance(maintainer.state, ActionResult) - self.assertEqual(maintainer.state.validity, True) + # self.assertEqual(maintainer.state.validity, True) # print(f"state validity {maintainer.state.validity}") # will open doors when standing in front if maintainer._closed_door_in_path(state): self.assertEqual(maintainer.get_move_action(state).name, 'use_door') - # elif maintainer._path and len(maintainer._path) > 1: - # can move - # print(f"maintainer move: {maintainer.move(maintainer._path[1], state)}") - # self.assertTrue(maintainer.move(maintainer._path[1], state)) - # if maintainer._next and not maintainer._path: # finds valid targets when at target location # route = maintainer.calculate_route(maintainer._last[-1], state.floortile_graph) @@ -76,19 +73,17 @@ class MaintainerTest(Test): return [] def tick_post_step(self, state) -> List[TickResult]: - # do maintainers actions have correct effects on environment i.e. doors open, machines heal + # do maintainers' actions have correct effects on environment i.e. doors open, machines heal for maintainer in state.entities[M.MAINTAINERS]: if maintainer._path and self.temp_state_dict != {}: last_action = self.temp_state_dict[maintainer.identifier] - print(last_action.identifier) if last_action.identifier == 'DoorUse': if door := next((entity for entity in state.entities.get_entities_near_pos(maintainer.pos) if isinstance(entity, Door)), None): self.assertTrue(door.is_open) if last_action.identifier == 'MachineAction': if machine := next((entity for entity in state.entities.get_entities_near_pos(maintainer.pos) if - isinstance(entity, Machine)), None): - print(f"machine hp: {machine.health}") + isinstance(entity, Machine)), None): self.assertEqual(machine.health, 100) return [] @@ -107,6 +102,7 @@ class DirtAgentTest(Test): environment. """ super().__init__() + self.temp_state_dict = {} pass def on_init(self, state, lvl_map): @@ -116,20 +112,101 @@ class DirtAgentTest(Test): return [] def tick_step(self, state) -> List[TickResult]: - for agent in state.entities[c.AGENT]: - print(agent) + for dirtagent in [a for a in state.entities[c.AGENT] if "Clean" in a.identifier]: # isinstance TSPDirtAgent # has valid actionresult - self.assertIsInstance(agent.state, ActionResult) + self.assertIsInstance(dirtagent.state, ActionResult) # self.assertEqual(agent.state.validity, True) + # print(f"state validity {maintainer.state.validity}") return [] def tick_post_step(self, state) -> List[TickResult]: - # action success? - # collisions? if yes, reported? + # do agents' actions have correct effects on environment i.e. doors open, dirt is cleaned + for dirtagent in [a for a in state.entities[c.AGENT] if "Clean" in a.identifier]: # isinstance TSPDirtAgent + if self.temp_state_dict != {}: # and + last_action = self.temp_state_dict[dirtagent.identifier] + if last_action.identifier == 'DoorUse': + if door := next((entity for entity in state.entities.get_entities_near_pos(dirtagent.pos) if + isinstance(entity, Door)), None): + self.assertTrue(door.is_open) # TODO catch if someone closes a door in same moment + if last_action.identifier == 'Clean': + if dirt := next((entity for entity in state.entities.get_entities_near_pos(dirtagent.pos) if + isinstance(entity, DirtPile)), None): + # print(f"dirt left on pos: {dirt.amount}") + self.assertTrue( + dirt.amount < 5) # TODO amount one step before - clean amount? return [] def on_check_done(self, state) -> List[DoneResult]: + for dirtagent in [a for a in state.entities[c.AGENT] if "Clean" in a.identifier]: # isinstance TSPDirtAgent + temp_state = dirtagent._status + self.temp_state_dict[dirtagent.identifier] = temp_state return [] -# class ItemAgentTest(Test): + +class ItemAgentTest(Test): + + def __init__(self): + """ + Tests whether the dirt agent will perform the correct actions and whether the actions register correctly in the + environment. + """ + super().__init__() + self.temp_state_dict = {} + pass + + def on_init(self, state, lvl_map): + return [] + + def on_reset(self): + return [] + + def tick_step(self, state) -> List[TickResult]: + for itemagent in [a for a in state.entities[c.AGENT] if "Item" in a.identifier]: # isinstance TSPItemAgent + # has valid actionresult + self.assertIsInstance(itemagent.state, ActionResult) + # self.assertEqual(agent.state.validity, True) + # print(f"state validity {maintainer.state.validity}") + + return [] + + def tick_post_step(self, state) -> List[TickResult]: + # do agents' actions have correct effects on environment i.e. doors open, items are picked up and dropped off + for itemagent in [a for a in state.entities[c.AGENT] if "Item" in a.identifier]: # isinstance TSPItemAgent + + if self.temp_state_dict != {}: # and + last_action = self.temp_state_dict[itemagent.identifier] + if last_action.identifier == 'DoorUse': + if door := next((entity for entity in state.entities.get_entities_near_pos(itemagent.pos) if + isinstance(entity, Door)), None): + self.assertTrue(door.is_open) + if last_action.identifier == 'ItemAction': + + print(last_action.valid_drop_off_reward) # kann man das nehmen für dropoff vs pickup? + # valid pickup? + + # If it was a pick-up action + nearby_items = [e for e in state.entities.get_entities_near_pos(itemagent.pos) if + isinstance(e, Item)] + self.assertNotIn(Item, nearby_items) + + # If the agent has the item in its inventory + self.assertTrue(itemagent.bound_entity) + + # If it was a drop-off action + nearby_drop_offs = [e for e in state.entities.get_entities_near_pos(itemagent.pos) if + isinstance(e, DropOffLocation)] + if nearby_drop_offs: + dol = nearby_drop_offs[0] + self.assertTrue(dol.bound_entity) # item in drop-off location? + + # Ensure the item is not in the inventory after dropping off + self.assertNotIn(Item, state.entities.get_entities_near_pos(itemagent.pos)) + + return [] + + def on_check_done(self, state) -> List[DoneResult]: + for itemagent in [a for a in state.entities[c.AGENT] if "Item" in a.identifier]: # isinstance TSPItemAgent + temp_state = itemagent._status + self.temp_state_dict[itemagent.identifier] = temp_state + return [] diff --git a/test_run.py b/test_run.py index 9d1f07e..8aa2eaa 100644 --- a/test_run.py +++ b/test_run.py @@ -4,6 +4,7 @@ from random import randint from tqdm import trange from marl_factory_grid.algorithms.static.TSP_dirt_agent import TSPDirtAgent +from marl_factory_grid.algorithms.static.TSP_item_agent import TSPItemAgent from marl_factory_grid.environment.factory import Factory if __name__ == '__main__': @@ -22,7 +23,7 @@ if __name__ == '__main__': if render: factory.render() action_spaces = factory.action_space - agents = [TSPDirtAgent(factory, 0)] + agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1)] while not done: a = [randint(0, x.n - 1) for x in action_spaces] obs_type, _, _, done, info = factory.step(a)