From fe5a97a4136b29ad1765712faaffd21647ab7908 Mon Sep 17 00:00:00 2001 From: Chanumask Date: Thu, 4 Apr 2024 12:48:14 +0200 Subject: [PATCH 1/2] added allowed direction check for predict move --- .../algorithms/static/TSP_base_agent.py | 12 ++++++++---- marl_factory_grid/configs/simple_crossing.yaml | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/marl_factory_grid/algorithms/static/TSP_base_agent.py b/marl_factory_grid/algorithms/static/TSP_base_agent.py index aa2d966..1eff678 100644 --- a/marl_factory_grid/algorithms/static/TSP_base_agent.py +++ b/marl_factory_grid/algorithms/static/TSP_base_agent.py @@ -135,18 +135,22 @@ class TSPBaseAgent(ABC): pass next_pos = self._static_route.pop(0) while next_pos == self.state.pos: - next_pos = self._static_route.pop(0) + if self._static_route: + next_pos = self._static_route.pop(0) else: if not self._static_route: self._static_route = self.calculate_tsp_route(target_identifier)[:7] next_pos = self._static_route.pop(0) while next_pos == self.state.pos: - next_pos = self._static_route.pop(0) - + if self._static_route: + next_pos = self._static_route.pop(0) diff = np.subtract(next_pos, self.state.pos) # Retrieve action based on the pos dif (like in: What do I have to do to get there?) try: - action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff)) + allowed_directions = [action.name for action in self.state.actions if + action.name in ['north', 'east', 'south', 'west', 'north_east', 'south_east', + 'south_west', 'north_west']] + action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff) and action in allowed_directions) except StopIteration: print(f"No valid action found for pos diff: {diff}. 
Using fallback action.") action = choice(self.state.actions).name diff --git a/marl_factory_grid/configs/simple_crossing.yaml b/marl_factory_grid/configs/simple_crossing.yaml index bcc0cc8..4d336bb 100644 --- a/marl_factory_grid/configs/simple_crossing.yaml +++ b/marl_factory_grid/configs/simple_crossing.yaml @@ -13,7 +13,7 @@ Agents: Agent_horizontal: Actions: - Noop - - Move8 + - Move4 Observations: - Walls - Other @@ -27,7 +27,7 @@ Agents: Agent_vertical: Actions: - Noop - - Move8 + - Move4 Observations: - Walls - Other From 54d4e1ecb5f3abb0a3ed693ac20617c08a069333 Mon Sep 17 00:00:00 2001 From: Chanumask Date: Wed, 17 Apr 2024 15:28:10 +0200 Subject: [PATCH 2/2] added simple route caching and fixed move 4 point graph in tspbaseagent --- .../algorithms/static/TSP_base_agent.py | 70 +++++++++++++++---- .../configs/simple_crossing.yaml | 4 +- 2 files changed, 57 insertions(+), 17 deletions(-) diff --git a/marl_factory_grid/algorithms/static/TSP_base_agent.py b/marl_factory_grid/algorithms/static/TSP_base_agent.py index 1eff678..22ec04e 100644 --- a/marl_factory_grid/algorithms/static/TSP_base_agent.py +++ b/marl_factory_grid/algorithms/static/TSP_base_agent.py @@ -3,6 +3,8 @@ from random import choice import numpy as np from networkx.algorithms.approximation import traveling_salesman as tsp +import time +import copy from marl_factory_grid.algorithms.static.utils import points_to_graph from marl_factory_grid.modules.doors import constants as do @@ -31,8 +33,9 @@ class TSPBaseAgent(ABC): self.local_optimization = True self._env = state self.state = self._env.state[c.AGENT][agent_i] - self._position_graph = points_to_graph(self._env.state.entities.floorlist) + self._position_graph = self.generate_pos_graph() self._static_route = None + self.cached_route = None @abstractmethod def predict(self, *_, **__) -> int: @@ -72,21 +75,35 @@ class TSPBaseAgent(ABC): :return: TSP route :rtype: List[int] """ - positions = [x for x in self._env.state[target_identifier].positions 
if x != c.VALUE_NO_POS] - if self.local_optimization: - nodes = \ - [self.state.pos] + \ - [x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3] - try: - while len(nodes) < 7: - nodes += [next(x for x in positions if x not in nodes)] - except StopIteration: - nodes = [self.state.pos] + positions + start_time = time.time() + + if self.cached_route is not None: + print(f" Used cached route: {self.cached_route}") + return copy.deepcopy(self.cached_route) else: - nodes = [self.state.pos] + positions - route = tsp.traveling_salesman_problem(self._position_graph, - nodes=nodes, cycle=True, method=tsp.greedy_tsp) + positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS] + if self.local_optimization: + nodes = \ + [self.state.pos] + \ + [x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3] + try: + while len(nodes) < 7: + nodes += [next(x for x in positions if x not in nodes)] + except StopIteration: + nodes = [self.state.pos] + positions + + else: + nodes = [self.state.pos] + positions + + route = tsp.traveling_salesman_problem(self._position_graph, + nodes=nodes, cycle=True, method=tsp.greedy_tsp) + self.cached_route = copy.deepcopy(route) + print(f"Cached route: {self.cached_route}") + + end_time = time.time() + duration = end_time - start_time + print("TSP calculation took {:.2f} seconds to execute".format(duration)) return route def _door_is_close(self, state): @@ -144,13 +161,15 @@ class TSPBaseAgent(ABC): while next_pos == self.state.pos: if self._static_route: next_pos = self._static_route.pop(0) + diff = np.subtract(next_pos, self.state.pos) # Retrieve action based on the pos dif (like in: What do I have to do to get there?) 
try: allowed_directions = [action.name for action in self.state.actions if action.name in ['north', 'east', 'south', 'west', 'north_east', 'south_east', 'south_west', 'north_west']] - action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff) and action in allowed_directions) + action = next(action for action, pos_diff in MOVEMAP.items() if + np.all(diff == pos_diff) and action in allowed_directions) except StopIteration: print(f"No valid action found for pos diff: {diff}. Using fallback action.") action = choice(self.state.actions).name @@ -158,3 +177,24 @@ class TSPBaseAgent(ABC): action = choice(self.state.actions).name # noinspection PyUnboundLocalVariable return action + + def generate_pos_graph(self): + """ + Generates a point graph based on the agents' allowed movement directions to be used in tsp route calculation. + + :return: A graph with nodes that are connected as specified by the movement actions. + :rtype: nx.Graph + """ + action_names = {action.name for action in self.state.actions} + + if {'north_east', 'south_east', 'south_west', 'north_west'}.issubset(action_names): + # print("All diagonal actions are present") + return points_to_graph(self._env.state.entities.floorlist) + + elif {'north', 'east', 'south', 'west'}.issubset(action_names): + # print("All cardinal directions are present") + return points_to_graph(self._env.state.entities.floorlist, allow_euclidean_connections=False) + + else: + print("Some actions are missing") + return points_to_graph(self._env.state.entities.floorlist) diff --git a/marl_factory_grid/configs/simple_crossing.yaml b/marl_factory_grid/configs/simple_crossing.yaml index 4d336bb..47c1b76 100644 --- a/marl_factory_grid/configs/simple_crossing.yaml +++ b/marl_factory_grid/configs/simple_crossing.yaml @@ -23,7 +23,7 @@ Agents: - (1,2) # It is okay to collide with other agents, so that # they end up on the same position - is_blocking_pos: true + is_blocking_pos: false Agent_vertical: Actions: - 
Noop @@ -34,7 +34,7 @@ Agents: - Destination Positions: - (2,1) - is_blocking_pos: true + is_blocking_pos: false # Other noteworthy Entitites Entities: