mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 09:01:36 +02:00
Merge branch 'unit_testing' into marl_refactor
This commit is contained in:
@ -3,6 +3,8 @@ from random import choice
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from networkx.algorithms.approximation import traveling_salesman as tsp
|
from networkx.algorithms.approximation import traveling_salesman as tsp
|
||||||
|
import time
|
||||||
|
import copy
|
||||||
|
|
||||||
from marl_factory_grid.algorithms.static.utils import points_to_graph
|
from marl_factory_grid.algorithms.static.utils import points_to_graph
|
||||||
from marl_factory_grid.modules.doors import constants as do
|
from marl_factory_grid.modules.doors import constants as do
|
||||||
@ -31,8 +33,9 @@ class TSPBaseAgent(ABC):
|
|||||||
self.local_optimization = True
|
self.local_optimization = True
|
||||||
self._env = state
|
self._env = state
|
||||||
self.state = self._env.state[c.AGENT][agent_i]
|
self.state = self._env.state[c.AGENT][agent_i]
|
||||||
self._position_graph = points_to_graph(self._env.state.entities.floorlist)
|
self._position_graph = self.generate_pos_graph()
|
||||||
self._static_route = None
|
self._static_route = None
|
||||||
|
self.cached_route = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def predict(self, *_, **__) -> int:
|
def predict(self, *_, **__) -> int:
|
||||||
@ -72,21 +75,35 @@ class TSPBaseAgent(ABC):
|
|||||||
:return: TSP route
|
:return: TSP route
|
||||||
:rtype: List[int]
|
:rtype: List[int]
|
||||||
"""
|
"""
|
||||||
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
start_time = time.time()
|
||||||
if self.local_optimization:
|
|
||||||
nodes = \
|
if self.cached_route is not None:
|
||||||
[self.state.pos] + \
|
print(f" Used cached route: {self.cached_route}")
|
||||||
[x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3]
|
return copy.deepcopy(self.cached_route)
|
||||||
try:
|
|
||||||
while len(nodes) < 7:
|
|
||||||
nodes += [next(x for x in positions if x not in nodes)]
|
|
||||||
except StopIteration:
|
|
||||||
nodes = [self.state.pos] + positions
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
nodes = [self.state.pos] + positions
|
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
||||||
route = tsp.traveling_salesman_problem(self._position_graph,
|
if self.local_optimization:
|
||||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
nodes = \
|
||||||
|
[self.state.pos] + \
|
||||||
|
[x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3]
|
||||||
|
try:
|
||||||
|
while len(nodes) < 7:
|
||||||
|
nodes += [next(x for x in positions if x not in nodes)]
|
||||||
|
except StopIteration:
|
||||||
|
nodes = [self.state.pos] + positions
|
||||||
|
|
||||||
|
else:
|
||||||
|
nodes = [self.state.pos] + positions
|
||||||
|
|
||||||
|
route = tsp.traveling_salesman_problem(self._position_graph,
|
||||||
|
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||||
|
self.cached_route = copy.deepcopy(route)
|
||||||
|
print(f"Cached route: {self.cached_route}")
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
duration = end_time - start_time
|
||||||
|
print("TSP calculation took {:.2f} seconds to execute".format(duration))
|
||||||
return route
|
return route
|
||||||
|
|
||||||
def _door_is_close(self, state):
|
def _door_is_close(self, state):
|
||||||
@ -135,18 +152,24 @@ class TSPBaseAgent(ABC):
|
|||||||
pass
|
pass
|
||||||
next_pos = self._static_route.pop(0)
|
next_pos = self._static_route.pop(0)
|
||||||
while next_pos == self.state.pos:
|
while next_pos == self.state.pos:
|
||||||
next_pos = self._static_route.pop(0)
|
if self._static_route:
|
||||||
|
next_pos = self._static_route.pop(0)
|
||||||
else:
|
else:
|
||||||
if not self._static_route:
|
if not self._static_route:
|
||||||
self._static_route = self.calculate_tsp_route(target_identifier)[:7]
|
self._static_route = self.calculate_tsp_route(target_identifier)[:7]
|
||||||
next_pos = self._static_route.pop(0)
|
next_pos = self._static_route.pop(0)
|
||||||
while next_pos == self.state.pos:
|
while next_pos == self.state.pos:
|
||||||
next_pos = self._static_route.pop(0)
|
if self._static_route:
|
||||||
|
next_pos = self._static_route.pop(0)
|
||||||
|
|
||||||
diff = np.subtract(next_pos, self.state.pos)
|
diff = np.subtract(next_pos, self.state.pos)
|
||||||
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
|
# Retrieve action based on the pos dif (like in: What do I have to do to get there?)
|
||||||
try:
|
try:
|
||||||
action = next(action for action, pos_diff in MOVEMAP.items() if np.all(diff == pos_diff))
|
allowed_directions = [action.name for action in self.state.actions if
|
||||||
|
action.name in ['north', 'east', 'south', 'west', 'north_east', 'south_east',
|
||||||
|
'south_west', 'north_west']]
|
||||||
|
action = next(action for action, pos_diff in MOVEMAP.items() if
|
||||||
|
np.all(diff == pos_diff) and action in allowed_directions)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
print(f"No valid action found for pos diff: {diff}. Using fallback action.")
|
print(f"No valid action found for pos diff: {diff}. Using fallback action.")
|
||||||
action = choice(self.state.actions).name
|
action = choice(self.state.actions).name
|
||||||
@ -154,3 +177,24 @@ class TSPBaseAgent(ABC):
|
|||||||
action = choice(self.state.actions).name
|
action = choice(self.state.actions).name
|
||||||
# noinspection PyUnboundLocalVariable
|
# noinspection PyUnboundLocalVariable
|
||||||
return action
|
return action
|
||||||
|
|
||||||
|
def generate_pos_graph(self):
|
||||||
|
"""
|
||||||
|
Generates a point graph based on the agents' allowed movement directions to be used in tsp route calculation.
|
||||||
|
|
||||||
|
:return: A graph with nodes that are conneceted as specified by the movement actions.
|
||||||
|
:rtype: nx.Graph
|
||||||
|
"""
|
||||||
|
action_names = {action.name for action in self.state.actions}
|
||||||
|
|
||||||
|
if {'north_east', 'south_east', 'south_west', 'north_west'}.issubset(action_names):
|
||||||
|
# print("All diagonal actions are present")
|
||||||
|
return points_to_graph(self._env.state.entities.floorlist)
|
||||||
|
|
||||||
|
elif {'north', 'east', 'south', 'west'}.issubset(action_names):
|
||||||
|
# print("All cardinal directions are present")
|
||||||
|
return points_to_graph(self._env.state.entities.floorlist, allow_euclidean_connections=False)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("Some actions are missing")
|
||||||
|
return points_to_graph(self._env.state.entities.floorlist)
|
||||||
|
@ -23,7 +23,7 @@ Agents:
|
|||||||
- (1,2)
|
- (1,2)
|
||||||
# It is okay to collide with other agents, so that
|
# It is okay to collide with other agents, so that
|
||||||
# they end up on the same position
|
# they end up on the same position
|
||||||
is_blocking_pos: true
|
is_blocking_pos: false
|
||||||
Agent_vertical:
|
Agent_vertical:
|
||||||
Actions:
|
Actions:
|
||||||
- Noop
|
- Noop
|
||||||
@ -34,7 +34,7 @@ Agents:
|
|||||||
- Destination
|
- Destination
|
||||||
Positions:
|
Positions:
|
||||||
- (2,1)
|
- (2,1)
|
||||||
is_blocking_pos: true
|
is_blocking_pos: false
|
||||||
|
|
||||||
# Other noteworthy Entitites
|
# Other noteworthy Entitites
|
||||||
Entities:
|
Entities:
|
||||||
|
Reference in New Issue
Block a user