Added a fallback action attribute to agents and set the default fallback action to noop

This commit is contained in:
Chanumask 2024-04-29 11:03:59 +02:00
parent 0bbf0dafdb
commit 5ee39eba8d
5 changed files with 16 additions and 5 deletions

View File

@ -36,6 +36,7 @@ class TSPBaseAgent(ABC):
self._position_graph = self.generate_pos_graph()
self._static_route = None
self.cached_route = None
self.fallback_action = None
@abstractmethod
def predict(self, *_, **__) -> int:
@ -170,8 +171,11 @@ class TSPBaseAgent(ABC):
action = next(action for action, pos_diff in MOVEMAP.items() if
np.all(diff == pos_diff) and action in allowed_directions)
except StopIteration:
print(f"No valid action found for pos diff: {diff}. Using fallback action.")
action = choice(self.state.actions).name
print(f"No valid action found for pos diff: {diff}. Using fallback action: {self.fallback_action}.")
if self.fallback_action and any(self.fallback_action == action.name for action in self.state.actions):
action = self.fallback_action
else:
action = choice(self.state.actions).name
else:
action = choice(self.state.actions).name
# noinspection PyUnboundLocalVariable

View File

@ -1,6 +1,7 @@
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
from marl_factory_grid.modules.clean_up import constants as di
from marl_factory_grid.environment import constants as c
future_planning = 7
@ -12,6 +13,7 @@ class TSPDirtAgent(TSPBaseAgent):
Initializes a TSPDirtAgent that aims to clean dirt in the environment.
"""
super(TSPDirtAgent, self).__init__(*args, **kwargs)
self.fallback_action = c.NOOP
def predict(self, *_, **__):
"""

View File

@ -3,6 +3,7 @@ import numpy as np
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c
future_planning = 7
inventory_size = 3
@ -22,6 +23,7 @@ class TSPItemAgent(TSPBaseAgent):
"""
super(TSPItemAgent, self).__init__(*args, **kwargs)
self.mode = mode
self.fallback_action = c.NOOP
def predict(self, *_, **__):
item_at_position = self._env.state[i.ITEM].by_pos(self.state.pos)

View File

@ -2,6 +2,8 @@ from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.modules.doors import constants as do
from marl_factory_grid.environment import constants as c
future_planning = 7
@ -13,6 +15,7 @@ class TSPTargetAgent(TSPBaseAgent):
Initializes a TSPTargetAgent that aims to reach destinations.
"""
super(TSPTargetAgent, self).__init__(*args, **kwargs)
self.fallback_action = c.NOOP
def _handle_doors(self, state):
"""

View File

@ -12,7 +12,7 @@ if __name__ == '__main__':
render = True
# Path to config File
path = Path('marl_factory_grid/configs/test_config.yaml')
path = Path('marl_factory_grid/configs/simple_crossing.yaml')
# Env Init
factory = Factory(path)
@ -23,8 +23,8 @@ if __name__ == '__main__':
if render:
factory.render()
action_spaces = factory.action_space
agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
# agents = [TSPTargetAgent(factory, 0), TSPTargetAgent(factory, 1)]
# agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
agents = [TSPTargetAgent(factory, 0), TSPTargetAgent(factory, 1)]
# agents = [TSPTargetAgent(factory, 0)]
while not done:
a = [x.predict() for x in agents]