Merge branch 'unit_testing' into marl_refactor

This commit is contained in:
Julian Schönberger
2024-04-19 09:46:46 +02:00
2 changed files with 63 additions and 19 deletions

View File

@ -3,6 +3,8 @@ from random import choice
import numpy as np import numpy as np
from networkx.algorithms.approximation import traveling_salesman as tsp from networkx.algorithms.approximation import traveling_salesman as tsp
import time
import copy
from marl_factory_grid.algorithms.static.utils import points_to_graph from marl_factory_grid.algorithms.static.utils import points_to_graph
from marl_factory_grid.modules.doors import constants as do from marl_factory_grid.modules.doors import constants as do
@ -31,8 +33,9 @@ class TSPBaseAgent(ABC):
self.local_optimization = True self.local_optimization = True
self._env = state self._env = state
self.state = self._env.state[c.AGENT][agent_i] self.state = self._env.state[c.AGENT][agent_i]
self._position_graph = points_to_graph(self._env.state.entities.floorlist) self._position_graph = self.generate_pos_graph()
self._static_route = None self._static_route = None
self.cached_route = None
@abstractmethod @abstractmethod
def predict(self, *_, **__) -> int: def predict(self, *_, **__) -> int:
@ -72,6 +75,13 @@ class TSPBaseAgent(ABC):
:return: TSP route :return: TSP route
:rtype: List[int] :rtype: List[int]
""" """
start_time = time.time()
if self.cached_route is not None:
print(f" Used cached route: {self.cached_route}")
return copy.deepcopy(self.cached_route)
else:
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS] positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
if self.local_optimization: if self.local_optimization:
nodes = \ nodes = \
@ -85,8 +95,15 @@ class TSPBaseAgent(ABC):
else: else:
nodes = [self.state.pos] + positions nodes = [self.state.pos] + positions
route = tsp.traveling_salesman_problem(self._position_graph, route = tsp.traveling_salesman_problem(self._position_graph,
nodes=nodes, cycle=True, method=tsp.greedy_tsp) nodes=nodes, cycle=True, method=tsp.greedy_tsp)
self.cached_route = copy.deepcopy(route)
print(f"Cached route: {self.cached_route}")
end_time = time.time()
duration = end_time - start_time
print("TSP calculation took {:.2f} seconds to execute".format(duration))
return route return route
def _door_is_close(self, state): def _door_is_close(self, state):
@ -135,18 +152,24 @@ class TSPBaseAgent(ABC):
pass pass
next_pos = self._static_route.pop(0) next_pos = self._static_route.pop(0)
while next_pos == self.state.pos: while next_pos == self.state.pos:
if self._static_route:
next_pos = self._static_route.pop(0) next_pos = self._static_route.pop(0)
else: else:
if not self._static_route: if not self._static_route:
self._static_route = self.calculate_tsp_route(target_identifier)[:7] self._static_route = self.calculate_tsp_route(target_identifier)[:7]
next_pos = self._static_route.pop(0) next_pos = self._static_route.pop(0)
while next_pos == self.state.pos: while next_pos == self.state.pos:
if self._static_route:
next_pos = self._static_route.pop(0) next_pos = self._static_route.pop(0)
diff = np.subtract(next_pos, self.state.pos) diff = np.subtract(next_pos, self.state.pos)
# Retrieve action based on the pos dif (like in: What do I have to do to get there?) # Retrieve action based on the pos dif (like in: What do I have to do to get there?)
try: try:
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff)) allowed_directions = [action.name for action in self.state.actions if
action.name in ['north', 'east', 'south', 'west', 'north_east', 'south_east',
'south_west', 'north_west']]
action = next(action for action, pos_diff in MOVEMAP.items() if
np.all(diff == pos_diff) and action in allowed_directions)
except StopIteration: except StopIteration:
print(f"No valid action found for pos diff: {diff}. Using fallback action.") print(f"No valid action found for pos diff: {diff}. Using fallback action.")
action = choice(self.state.actions).name action = choice(self.state.actions).name
@ -154,3 +177,24 @@ class TSPBaseAgent(ABC):
action = choice(self.state.actions).name action = choice(self.state.actions).name
# noinspection PyUnboundLocalVariable # noinspection PyUnboundLocalVariable
return action return action
def generate_pos_graph(self):
    """
    Build the position graph that the TSP route calculation runs on, picking the
    ``points_to_graph`` arguments from the names of the agent's available actions.

    * All four diagonal action names present: ``points_to_graph`` is called with its defaults.
    * Otherwise, all four cardinal action names present: ``points_to_graph`` is called with
      ``allow_euclidean_connections=False``.
    * Otherwise: a notice is printed and ``points_to_graph`` is called with its defaults.

    :return: The graph returned by ``points_to_graph`` for the environment's floor list.
    :rtype: nx.Graph
    """
    diagonal_names = {'north_east', 'south_east', 'south_west', 'north_west'}
    cardinal_names = {'north', 'east', 'south', 'west'}
    available_names = {act.name for act in self.state.actions}

    if diagonal_names <= available_names:
        # Complete diagonal action set: keep the helper's default connectivity.
        graph_options = {}
    elif cardinal_names <= available_names:
        # Only the complete cardinal set is available: switch off euclidean connections.
        graph_options = {'allow_euclidean_connections': False}
    else:
        # Neither set is complete: report it and fall back to the default graph.
        print("Some actions are missing")
        graph_options = {}

    return points_to_graph(self._env.state.entities.floorlist, **graph_options)

View File

@ -23,7 +23,7 @@ Agents:
- (1,2) - (1,2)
# It is okay to collide with other agents, so that # It is okay to collide with other agents, so that
# they end up on the same position # they end up on the same position
is_blocking_pos: true is_blocking_pos: false
Agent_vertical: Agent_vertical:
Actions: Actions:
- Noop - Noop
@ -34,7 +34,7 @@ Agents:
- Destination - Destination
Positions: Positions:
- (2,1) - (2,1)
is_blocking_pos: true is_blocking_pos: false
# Other noteworthy Entities # Other noteworthy Entities
Entities: Entities: