Reset tsp route caching

This commit is contained in:
Julian Schönberger
2024-05-24 16:18:50 +02:00
parent 98113ea849
commit 33e40deecf
2 changed files with 43 additions and 85 deletions

View File

@ -33,11 +33,9 @@ class TSPBaseAgent(ABC):
self.local_optimization = True
self._env = state
self.state = self._env.state[c.AGENT][agent_i]
self.spawn_position = np.array(self.state.pos)
self._position_graph = self.generate_pos_graph()
self._static_route = None
self.cached_route = None
self.fallback_action = None
self.action_list = []
@abstractmethod
@ -50,46 +48,6 @@ class TSPBaseAgent(ABC):
"""
return 0
def calculate_tsp_route(self, target_identifier):
"""
Calculate the TSP route to reach a target.
:param target_identifier: Identifier of the target entity
:type target_identifier: str
:return: TSP route
:rtype: List[int]
"""
target_positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
# if there are cached routes, search for one matching the current and target position
if self._env.state.route_cache and (
route := self._env.state.get_cached_route(self.state.pos, target_positions)) is not None:
# print(f"Retrieved cached route: {route}")
return route
# if none are found, calculate tsp route and cache it
else:
start_time = time.time()
if self.local_optimization:
nodes = \
[self.state.pos] + \
[x for x in target_positions if max(abs(np.subtract(x, self.state.pos))) < 3]
try:
while len(nodes) < 7:
nodes += [next(x for x in target_positions if x not in nodes)]
except StopIteration:
nodes = [self.state.pos] + target_positions
else:
nodes = [self.state.pos] + target_positions
route = tsp.traveling_salesman_problem(self._position_graph,
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
duration = time.time() - start_time
print("TSP calculation took {:.2f} seconds to execute".format(duration))
self._env.state.cache_route(route)
return route
def _use_door_or_move(self, door, target):
"""
Helper method to decide whether to use a door or move towards a target.
@ -108,6 +66,47 @@ class TSPBaseAgent(ABC):
action = self._predict_move(target)
return action
def calculate_tsp_route(self, target_identifier):
"""
Calculate the TSP route to reach a target.
:param target_identifier: Identifier of the target entity
:type target_identifier: str
:return: TSP route
:rtype: List[int]
"""
start_time = time.time()
if self.cached_route is not None:
print(f" Used cached route: {self.cached_route}")
return copy.deepcopy(self.cached_route)
else:
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
if self.local_optimization:
nodes = \
[self.state.pos] + \
[x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3]
try:
while len(nodes) < 7:
nodes += [next(x for x in positions if x not in nodes)]
except StopIteration:
nodes = [self.state.pos] + positions
else:
nodes = [self.state.pos] + positions
route = tsp.traveling_salesman_problem(self._position_graph,
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
self.cached_route = copy.deepcopy(route)
print(f"Cached route: {self.cached_route}")
end_time = time.time()
duration = end_time - start_time
print("TSP calculation took {:.2f} seconds to execute".format(duration))
return route
def _door_is_close(self, state):
"""
Check if a door is close to the agent's position.
@ -173,11 +172,8 @@ class TSPBaseAgent(ABC):
action = next(action for action, pos_diff in MOVEMAP.items() if
np.all(diff == pos_diff) and action in allowed_directions)
except StopIteration:
print(f"No valid action found for pos diff: {diff}. Using fallback action: {self.fallback_action}.")
if self.fallback_action and any(self.fallback_action == action.name for action in self.state.actions):
action = self.fallback_action
else:
action = choice(self.state.actions).name
print(f"No valid action found for pos diff: {diff}. Using fallback action.")
action = choice(self.state.actions).name
else:
action = choice(self.state.actions).name
# noinspection PyUnboundLocalVariable

View File

@ -1,4 +1,3 @@
import copy
from itertools import islice
from typing import List, Tuple
@ -117,7 +116,6 @@ class Gamestate(object):
self.rng = np.random.default_rng(env_seed)
self.rules = StepRules(*rules)
self._floortile_graph = None
self.route_cache = []
self.tests = StepTests(*tests)
# Pointer that defines current spawn points of agents
@ -322,42 +320,6 @@ class Gamestate(object):
# json_file.seek(0)
# json.dump(existing_content, json_file, indent=4)
def cache_route(self, route):
"""
Save routes in env-level cache so agents can access it.
:param route: The route to be saved
"""
self.route_cache.append(copy.deepcopy(route))
# print(f"Cached route: {route}")
def get_cached_route(self, current_pos, target_positions, route_cutting=False):
"""
Use a cached route if it includes the current position and a target
:param current_pos: The agent's current position and thus the first position of possibly cached routes
:param target_positions: The positions of targets the agent wants to visit
:param route_cutting: if true, cuts found routes to end at target. False allows target agents to loop.
:returns: A cached route from the agent's position to the first target if it exists
"""
if not self.route_cache:
return None
for route in self.route_cache:
if current_pos in route:
targets = [target for target in target_positions if target in route]
if targets:
first_target = targets[0]
index_start = route.index(current_pos)
if route_cutting:
index_end = route.index(first_target) + 1
return copy.deepcopy(route[index_start:index_end])
else:
return copy.deepcopy(route[index_start:])
return None
class StepTests:
def __init__(self, *args):