mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-22 23:06:43 +02:00
added fallback action attribute to agents and set standard fallback action to noop
This commit is contained in:
parent
0bbf0dafdb
commit
5ee39eba8d
@ -36,6 +36,7 @@ class TSPBaseAgent(ABC):
|
|||||||
self._position_graph = self.generate_pos_graph()
|
self._position_graph = self.generate_pos_graph()
|
||||||
self._static_route = None
|
self._static_route = None
|
||||||
self.cached_route = None
|
self.cached_route = None
|
||||||
|
self.fallback_action = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def predict(self, *_, **__) -> int:
|
def predict(self, *_, **__) -> int:
|
||||||
@ -170,8 +171,11 @@ class TSPBaseAgent(ABC):
|
|||||||
action = next(action for action, pos_diff in MOVEMAP.items() if
|
action = next(action for action, pos_diff in MOVEMAP.items() if
|
||||||
np.all(diff == pos_diff) and action in allowed_directions)
|
np.all(diff == pos_diff) and action in allowed_directions)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
print(f"No valid action found for pos diff: {diff}. Using fallback action.")
|
print(f"No valid action found for pos diff: {diff}. Using fallback action: {self.fallback_action}.")
|
||||||
action = choice(self.state.actions).name
|
if self.fallback_action and any(self.fallback_action == action.name for action in self.state.actions):
|
||||||
|
action = self.fallback_action
|
||||||
|
else:
|
||||||
|
action = choice(self.state.actions).name
|
||||||
else:
|
else:
|
||||||
action = choice(self.state.actions).name
|
action = choice(self.state.actions).name
|
||||||
# noinspection PyUnboundLocalVariable
|
# noinspection PyUnboundLocalVariable
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
|
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||||
|
|
||||||
from marl_factory_grid.modules.clean_up import constants as di
|
from marl_factory_grid.modules.clean_up import constants as di
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
future_planning = 7
|
future_planning = 7
|
||||||
|
|
||||||
@ -12,6 +13,7 @@ class TSPDirtAgent(TSPBaseAgent):
|
|||||||
Initializes a TSPDirtAgent that aims to clean dirt in the environment.
|
Initializes a TSPDirtAgent that aims to clean dirt in the environment.
|
||||||
"""
|
"""
|
||||||
super(TSPDirtAgent, self).__init__(*args, **kwargs)
|
super(TSPDirtAgent, self).__init__(*args, **kwargs)
|
||||||
|
self.fallback_action = c.NOOP
|
||||||
|
|
||||||
def predict(self, *_, **__):
|
def predict(self, *_, **__):
|
||||||
"""
|
"""
|
||||||
|
@ -3,6 +3,7 @@ import numpy as np
|
|||||||
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
|
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||||
|
|
||||||
from marl_factory_grid.modules.items import constants as i
|
from marl_factory_grid.modules.items import constants as i
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
future_planning = 7
|
future_planning = 7
|
||||||
inventory_size = 3
|
inventory_size = 3
|
||||||
@ -22,6 +23,7 @@ class TSPItemAgent(TSPBaseAgent):
|
|||||||
"""
|
"""
|
||||||
super(TSPItemAgent, self).__init__(*args, **kwargs)
|
super(TSPItemAgent, self).__init__(*args, **kwargs)
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
self.fallback_action = c.NOOP
|
||||||
|
|
||||||
def predict(self, *_, **__):
|
def predict(self, *_, **__):
|
||||||
item_at_position = self._env.state[i.ITEM].by_pos(self.state.pos)
|
item_at_position = self._env.state[i.ITEM].by_pos(self.state.pos)
|
||||||
|
@ -2,6 +2,8 @@ from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
|
|||||||
|
|
||||||
from marl_factory_grid.modules.destinations import constants as d
|
from marl_factory_grid.modules.destinations import constants as d
|
||||||
from marl_factory_grid.modules.doors import constants as do
|
from marl_factory_grid.modules.doors import constants as do
|
||||||
|
from marl_factory_grid.environment import constants as c
|
||||||
|
|
||||||
|
|
||||||
future_planning = 7
|
future_planning = 7
|
||||||
|
|
||||||
@ -13,6 +15,7 @@ class TSPTargetAgent(TSPBaseAgent):
|
|||||||
Initializes a TSPTargetAgent that aims to reach destinations.
|
Initializes a TSPTargetAgent that aims to reach destinations.
|
||||||
"""
|
"""
|
||||||
super(TSPTargetAgent, self).__init__(*args, **kwargs)
|
super(TSPTargetAgent, self).__init__(*args, **kwargs)
|
||||||
|
self.fallback_action = c.NOOP
|
||||||
|
|
||||||
def _handle_doors(self, state):
|
def _handle_doors(self, state):
|
||||||
"""
|
"""
|
||||||
|
@ -12,7 +12,7 @@ if __name__ == '__main__':
|
|||||||
render = True
|
render = True
|
||||||
|
|
||||||
# Path to config File
|
# Path to config File
|
||||||
path = Path('marl_factory_grid/configs/test_config.yaml')
|
path = Path('marl_factory_grid/configs/simple_crossing.yaml')
|
||||||
|
|
||||||
# Env Init
|
# Env Init
|
||||||
factory = Factory(path)
|
factory = Factory(path)
|
||||||
@ -23,8 +23,8 @@ if __name__ == '__main__':
|
|||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
action_spaces = factory.action_space
|
action_spaces = factory.action_space
|
||||||
agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
|
# agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
|
||||||
# agents = [TSPTargetAgent(factory, 0), TSPTargetAgent(factory, 1)]
|
agents = [TSPTargetAgent(factory, 0), TSPTargetAgent(factory, 1)]
|
||||||
# agents = [TSPTargetAgent(factory, 0)]
|
# agents = [TSPTargetAgent(factory, 0)]
|
||||||
while not done:
|
while not done:
|
||||||
a = [x.predict() for x in agents]
|
a = [x.predict() for x in agents]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user