mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
maintainer test grabs temp state
This commit is contained in:
parent
3a7b727ec6
commit
4c85cac50f
@ -21,7 +21,7 @@ class TSPBaseAgent(ABC):
|
||||
self.local_optimization = True
|
||||
self._env = state
|
||||
self.state = self._env.state[c.AGENT][agent_i]
|
||||
self._position_graph = points_to_graph(self._env.entities.floorlist)
|
||||
self._position_graph = points_to_graph(self._env.state.entities.floorlist)
|
||||
self._static_route = None
|
||||
|
||||
@abstractmethod
|
||||
|
@ -43,7 +43,6 @@ class Agent(Entity):
|
||||
def var_is_blocking_pos(self):
|
||||
return self._is_blocking_pos
|
||||
|
||||
|
||||
def __init__(self, actions: List[Action], observations: List[str], *args, is_blocking_pos=False, **kwargs):
|
||||
super(Agent, self).__init__(*args, **kwargs)
|
||||
self._paralyzed = set()
|
||||
|
@ -162,6 +162,7 @@ class Factory(gym.Env):
|
||||
|
||||
# Check Done Conditions
|
||||
done_results = self.state.check_done()
|
||||
done_tests = self.state.tests.check_done_all(self.state)
|
||||
|
||||
# Finalize
|
||||
reward, reward_info, done = self.summarize_step_results(tick_result, done_results)
|
||||
|
@ -1,10 +1,10 @@
|
||||
from typing import List
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from marl_factory_grid.modules import Door, Machine, Maintainer
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult, ActionResult
|
||||
import marl_factory_grid.modules.maintenance.constants as M
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.modules import Door, Machine
|
||||
from marl_factory_grid.utils.results import TickResult, DoneResult, ActionResult
|
||||
import marl_factory_grid.environment.constants as c
|
||||
|
||||
|
||||
class Test(unittest.TestCase):
|
||||
@ -14,6 +14,9 @@ class Test(unittest.TestCase):
|
||||
return self.__class__.__name__
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Base test class for unit tests.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self):
|
||||
@ -41,7 +44,11 @@ class Test(unittest.TestCase):
|
||||
class MaintainerTest(Test):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Tests whether the maintainer performs the correct actions and whether his actions register correctly in the env.
|
||||
"""
|
||||
super().__init__()
|
||||
self.temp_state_dict = {}
|
||||
pass
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
@ -55,12 +62,12 @@ class MaintainerTest(Test):
|
||||
if maintainer._closed_door_in_path(state):
|
||||
self.assertEqual(maintainer.get_move_action(state).name, 'use_door')
|
||||
|
||||
elif maintainer._path:
|
||||
elif maintainer._path and len(maintainer._path) > 1:
|
||||
# can move
|
||||
# print(maintainer.move(maintainer._path[1], state))
|
||||
self.assertTrue(maintainer.move(maintainer._path[1], state))
|
||||
print(maintainer.move(maintainer._path[1], state))
|
||||
# self.assertTrue(maintainer.move(maintainer._path[1], state))
|
||||
|
||||
if not maintainer._path:
|
||||
if maintainer._next and not maintainer._path:
|
||||
# finds valid targets when at target location
|
||||
route = maintainer.calculate_route(maintainer._last[-1], state.floortile_graph)
|
||||
if entities_at_target_location := [entity for entity in state.entities.by_pos(route[-1])]:
|
||||
@ -70,35 +77,47 @@ class MaintainerTest(Test):
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
for maintainer in state.entities[M.MAINTAINERS]:
|
||||
if maintainer._path:
|
||||
# if action was door use: was door opened successfully?
|
||||
# was door opened successfully?
|
||||
if maintainer._closed_door_in_path(state):
|
||||
door = next(
|
||||
(entity for entity in state.entities.by_pos(maintainer._path[0]) if isinstance(entity, Door)),
|
||||
None)
|
||||
self.assertEqual(door.is_open, True)
|
||||
# self.assertEqual(door.is_open, True)
|
||||
|
||||
# when stepping off machine, did maintain action work?
|
||||
return []
|
||||
|
||||
def on_check_done(self, state) -> List[DoneResult]:
|
||||
for maintainer in state.entities[M.MAINTAINERS]:
|
||||
temp_state = maintainer._status
|
||||
self.temp_state_dict[maintainer.identifier] = temp_state
|
||||
print(self.temp_state_dict)
|
||||
return []
|
||||
|
||||
|
||||
class DirtAgentTest(Test):
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Tests whether the dirt agent will perform the correct actions and whether the actions register correctly in the
|
||||
environment.
|
||||
"""
|
||||
super().__init__()
|
||||
pass
|
||||
|
||||
def on_init(self, state, lvl_map):
|
||||
# dirtagent richtig gespawnt?
|
||||
return []
|
||||
|
||||
def on_reset(self):
|
||||
return []
|
||||
|
||||
def tick_step(self, state) -> List[TickResult]:
|
||||
# check observation correct?
|
||||
# can open doors
|
||||
# can find way
|
||||
# can move
|
||||
# clean action success? action result valid
|
||||
for agent in state.entities[c.AGENT]:
|
||||
print(agent)
|
||||
# has valid actionresult
|
||||
self.assertIsInstance(agent.state, ActionResult)
|
||||
self.assertEqual(agent.state.validity, True)
|
||||
|
||||
return []
|
||||
|
||||
def tick_post_step(self, state) -> List[TickResult]:
|
||||
|
@ -83,7 +83,8 @@ class Door(Entity):
|
||||
|
||||
def tick(self, state):
|
||||
# Check if no entity is standing in the door
|
||||
if len(state.entities.pos_dict[self.pos]) <= 2:
|
||||
if not any(e for e in state.entities.by_pos(self.pos) if e.var_can_collide or e.var_is_blocking_pos):
|
||||
# if len(state.entities.pos_dict[self.pos]) <= 2: #can collide can block
|
||||
if self.is_open and self.time_to_close:
|
||||
self._decrement_timer()
|
||||
return Result(f"{d.DOOR}_tick", c.VALID, entity=self)
|
||||
|
@ -18,6 +18,7 @@ class Doors(Collection):
|
||||
def tick_doors(self, state):
|
||||
results = list()
|
||||
for door in self:
|
||||
assert(isinstance(door, Door))
|
||||
tick_result = door.tick(state)
|
||||
if tick_result is not None:
|
||||
results.append(tick_result)
|
||||
|
@ -88,3 +88,4 @@ Rules:
|
||||
|
||||
Tests:
|
||||
MaintainerTest: {}
|
||||
# DirtAgentTest: {}
|
@ -3,6 +3,7 @@ from random import randint
|
||||
|
||||
from tqdm import trange
|
||||
|
||||
from marl_factory_grid.algorithms.static.TSP_dirt_agent import TSPDirtAgent
|
||||
from marl_factory_grid.environment.factory import Factory
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -21,6 +22,7 @@ if __name__ == '__main__':
|
||||
if render:
|
||||
factory.render()
|
||||
action_spaces = factory.action_space
|
||||
# agents = [TSPDirtAgent(factory, 0)]
|
||||
while not done:
|
||||
a = [randint(0, x.n - 1) for x in action_spaces]
|
||||
obs_type, _, _, done, info = factory.step(a)
|
||||
|
@ -1,5 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
from itertools import islice
|
||||
from itertools import islice
|
||||
from typing import List, Tuple
|
||||
|
||||
@ -9,8 +8,8 @@ from marl_factory_grid.algorithms.static.utils import points_to_graph
|
||||
from marl_factory_grid.environment import constants as c
|
||||
from marl_factory_grid.environment.entity.entity import Entity
|
||||
from marl_factory_grid.environment.rules import Rule, SpawnAgents
|
||||
from marl_factory_grid.utils.results import Result, DoneResult
|
||||
from marl_factory_grid.environment.tests import Test
|
||||
from marl_factory_grid.utils.results import DoneResult
|
||||
from marl_factory_grid.utils.results import Result
|
||||
|
||||
|
||||
@ -304,34 +303,65 @@ class StepTests:
|
||||
def __iter__(self):
|
||||
return iter(self.tests)
|
||||
|
||||
def append(self, item):
|
||||
def append(self, item) -> bool:
|
||||
assert isinstance(item, Test)
|
||||
self.tests.append(item)
|
||||
return True
|
||||
|
||||
def do_all_init(self, state, lvl_map):
|
||||
def do_all_init(self, state, lvl_map) -> bool:
|
||||
for test in self.tests:
|
||||
if test_init_printline := test.on_init(state, lvl_map):
|
||||
state.print(test_init_printline)
|
||||
return c.VALID
|
||||
|
||||
def tick_step_all(self, state):
|
||||
def tick_step_all(self, state) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Tests** that override the *tick_step* hook.
|
||||
|
||||
:return: List of Results
|
||||
""" """
|
||||
Iterate all **Tests** that override the *on_check_done* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if tick_step_result := test.tick_step(state):
|
||||
test_results.extend(tick_step_result)
|
||||
return test_results
|
||||
|
||||
def tick_pre_step_all(self, state):
|
||||
def tick_pre_step_all(self, state) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Tests** that override the *pre_step* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if tick_pre_step_result := test.tick_pre_step(state):
|
||||
test_results.extend(tick_pre_step_result)
|
||||
return test_results
|
||||
|
||||
def tick_post_step_all(self, state):
|
||||
def tick_post_step_all(self, state) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Tests** that override the *post_step* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if tick_post_step_result := test.tick_post_step(state):
|
||||
test_results.extend(tick_post_step_result)
|
||||
return test_results
|
||||
|
||||
def check_done_all(self, state) -> List[DoneResult]:
|
||||
"""
|
||||
Iterate all **Tests** that override the *on_check_done* hook.
|
||||
|
||||
:return: List of Results
|
||||
"""
|
||||
test_results = list()
|
||||
for test in self.tests:
|
||||
if on_check_done_result := test.on_check_done(state):
|
||||
test_results.extend(on_check_done_result)
|
||||
return test_results
|
||||
|
Loading…
x
Reference in New Issue
Block a user