mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-11 23:42:40 +02:00
Merge branch 'route_plotting' into rl_plotting
This commit is contained in:
@ -33,9 +33,12 @@ class TSPBaseAgent(ABC):
|
||||
self.local_optimization = True
|
||||
self._env = state
|
||||
self.state = self._env.state[c.AGENT][agent_i]
|
||||
self.spawn_position = np.array(self.state.pos)
|
||||
self._position_graph = self.generate_pos_graph()
|
||||
self._static_route = None
|
||||
self.cached_route = None
|
||||
self.fallback_action = None
|
||||
self.action_list = []
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, *_, **__) -> int:
|
||||
@ -47,6 +50,46 @@ class TSPBaseAgent(ABC):
|
||||
"""
|
||||
return 0
|
||||
|
||||
def calculate_tsp_route(self, target_identifier):
|
||||
"""
|
||||
Calculate the TSP route to reach a target.
|
||||
|
||||
:param target_identifier: Identifier of the target entity
|
||||
:type target_identifier: str
|
||||
|
||||
:return: TSP route
|
||||
:rtype: List[int]
|
||||
"""
|
||||
target_positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
||||
|
||||
# if there are cached routes, search for one matching the current and target position
|
||||
if self._env.state.route_cache and (
|
||||
route := self._env.state.get_cached_route(self.state.pos, target_positions)) is not None:
|
||||
# print(f"Retrieved cached route: {route}")
|
||||
return route
|
||||
# if none are found, calculate tsp route and cache it
|
||||
else:
|
||||
start_time = time.time()
|
||||
if self.local_optimization:
|
||||
nodes = \
|
||||
[self.state.pos] + \
|
||||
[x for x in target_positions if max(abs(np.subtract(x, self.state.pos))) < 3]
|
||||
try:
|
||||
while len(nodes) < 7:
|
||||
nodes += [next(x for x in target_positions if x not in nodes)]
|
||||
except StopIteration:
|
||||
nodes = [self.state.pos] + target_positions
|
||||
|
||||
else:
|
||||
nodes = [self.state.pos] + target_positions
|
||||
|
||||
route = tsp.traveling_salesman_problem(self._position_graph,
|
||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||
duration = time.time() - start_time
|
||||
print("TSP calculation took {:.2f} seconds to execute".format(duration))
|
||||
self._env.state.cache_route(route)
|
||||
return route
|
||||
|
||||
def _use_door_or_move(self, door, target):
|
||||
"""
|
||||
Helper method to decide whether to use a door or move towards a target.
|
||||
@ -65,47 +108,6 @@ class TSPBaseAgent(ABC):
|
||||
action = self._predict_move(target)
|
||||
return action
|
||||
|
||||
def calculate_tsp_route(self, target_identifier):
|
||||
"""
|
||||
Calculate the TSP route to reach a target.
|
||||
|
||||
:param target_identifier: Identifier of the target entity
|
||||
:type target_identifier: str
|
||||
|
||||
:return: TSP route
|
||||
:rtype: List[int]
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if self.cached_route is not None:
|
||||
print(f" Used cached route: {self.cached_route}")
|
||||
return copy.deepcopy(self.cached_route)
|
||||
|
||||
else:
|
||||
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
||||
if self.local_optimization:
|
||||
nodes = \
|
||||
[self.state.pos] + \
|
||||
[x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3]
|
||||
try:
|
||||
while len(nodes) < 7:
|
||||
nodes += [next(x for x in positions if x not in nodes)]
|
||||
except StopIteration:
|
||||
nodes = [self.state.pos] + positions
|
||||
|
||||
else:
|
||||
nodes = [self.state.pos] + positions
|
||||
|
||||
route = tsp.traveling_salesman_problem(self._position_graph,
|
||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||
self.cached_route = copy.deepcopy(route)
|
||||
print(f"Cached route: {self.cached_route}")
|
||||
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
print("TSP calculation took {:.2f} seconds to execute".format(duration))
|
||||
return route
|
||||
|
||||
def _door_is_close(self, state):
|
||||
"""
|
||||
Check if a door is close to the agent's position.
|
||||
@ -171,8 +173,11 @@ class TSPBaseAgent(ABC):
|
||||
action = next(action for action, pos_diff in MOVEMAP.items() if
|
||||
np.all(diff == pos_diff) and action in allowed_directions)
|
||||
except StopIteration:
|
||||
print(f"No valid action found for pos diff: {diff}. Using fallback action.")
|
||||
action = choice(self.state.actions).name
|
||||
print(f"No valid action found for pos diff: {diff}. Using fallback action: {self.fallback_action}.")
|
||||
if self.fallback_action and any(self.fallback_action == action.name for action in self.state.actions):
|
||||
action = self.fallback_action
|
||||
else:
|
||||
action = choice(self.state.actions).name
|
||||
else:
|
||||
action = choice(self.state.actions).name
|
||||
# noinspection PyUnboundLocalVariable
|
||||
|
@ -1,6 +1,7 @@
|
||||
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||
|
||||
from marl_factory_grid.modules.clean_up import constants as di
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
future_planning = 7
|
||||
|
||||
@ -12,6 +13,7 @@ class TSPDirtAgent(TSPBaseAgent):
|
||||
Initializes a TSPDirtAgent that aims to clean dirt in the environment.
|
||||
"""
|
||||
super(TSPDirtAgent, self).__init__(*args, **kwargs)
|
||||
self.fallback_action = c.NOOP
|
||||
|
||||
def predict(self, *_, **__):
|
||||
"""
|
||||
@ -28,6 +30,7 @@ class TSPDirtAgent(TSPBaseAgent):
|
||||
action = self._use_door_or_move(door, di.DIRT)
|
||||
else:
|
||||
action = self._predict_move(di.DIRT)
|
||||
self.action_list.append(action)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
|
||||
|
@ -3,6 +3,7 @@ import numpy as np
|
||||
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||
|
||||
from marl_factory_grid.modules.items import constants as i
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
future_planning = 7
|
||||
inventory_size = 3
|
||||
@ -22,6 +23,7 @@ class TSPItemAgent(TSPBaseAgent):
|
||||
"""
|
||||
super(TSPItemAgent, self).__init__(*args, **kwargs)
|
||||
self.mode = mode
|
||||
self.fallback_action = c.NOOP
|
||||
|
||||
def predict(self, *_, **__):
|
||||
item_at_position = self._env.state[i.ITEM].by_pos(self.state.pos)
|
||||
@ -36,6 +38,7 @@ class TSPItemAgent(TSPBaseAgent):
|
||||
action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
|
||||
else:
|
||||
action = self._choose()
|
||||
self.action_list.append(action)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
|
||||
|
@ -2,6 +2,8 @@ from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent
|
||||
|
||||
from marl_factory_grid.modules.destinations import constants as d
|
||||
from marl_factory_grid.modules.doors import constants as do
|
||||
from marl_factory_grid.environment import constants as c
|
||||
|
||||
|
||||
future_planning = 7
|
||||
|
||||
@ -13,6 +15,7 @@ class TSPTargetAgent(TSPBaseAgent):
|
||||
Initializes a TSPTargetAgent that aims to reach destinations.
|
||||
"""
|
||||
super(TSPTargetAgent, self).__init__(*args, **kwargs)
|
||||
self.fallback_action = c.NOOP
|
||||
|
||||
def _handle_doors(self, state):
|
||||
"""
|
||||
@ -35,6 +38,7 @@ class TSPTargetAgent(TSPBaseAgent):
|
||||
action = self._use_door_or_move(door, d.DESTINATION)
|
||||
else:
|
||||
action = self._predict_move(d.DESTINATION)
|
||||
self.action_list.append(action)
|
||||
# Translate the action_object to an integer to have the same output as any other model
|
||||
try:
|
||||
action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
|
||||
|
Reference in New Issue
Block a user