mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-04 16:41:36 +02:00
Moved route caching to env level and removed print statements
This commit is contained in:
@ -47,6 +47,46 @@ class TSPBaseAgent(ABC):
|
|||||||
"""
|
"""
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def calculate_tsp_route(self, target_identifier):
|
||||||
|
"""
|
||||||
|
Calculate the TSP route to reach a target.
|
||||||
|
|
||||||
|
:param target_identifier: Identifier of the target entity
|
||||||
|
:type target_identifier: str
|
||||||
|
|
||||||
|
:return: TSP route
|
||||||
|
:rtype: List[int]
|
||||||
|
"""
|
||||||
|
target_positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
||||||
|
|
||||||
|
# if there are cached routes, search for one matching the current and target position
|
||||||
|
if self._env.state.route_cache and (
|
||||||
|
route := self._env.state.get_cached_route(self.state.pos, target_positions)) is not None:
|
||||||
|
# print(f"Retrieved cached route: {route}")
|
||||||
|
return route
|
||||||
|
# if none are found, calculate tsp route and cache it
|
||||||
|
else:
|
||||||
|
start_time = time.time()
|
||||||
|
if self.local_optimization:
|
||||||
|
nodes = \
|
||||||
|
[self.state.pos] + \
|
||||||
|
[x for x in target_positions if max(abs(np.subtract(x, self.state.pos))) < 3]
|
||||||
|
try:
|
||||||
|
while len(nodes) < 7:
|
||||||
|
nodes += [next(x for x in target_positions if x not in nodes)]
|
||||||
|
except StopIteration:
|
||||||
|
nodes = [self.state.pos] + target_positions
|
||||||
|
|
||||||
|
else:
|
||||||
|
nodes = [self.state.pos] + target_positions
|
||||||
|
|
||||||
|
route = tsp.traveling_salesman_problem(self._position_graph,
|
||||||
|
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
||||||
|
duration = time.time() - start_time
|
||||||
|
print("TSP calculation took {:.2f} seconds to execute".format(duration))
|
||||||
|
self._env.state.cache_route(route)
|
||||||
|
return route
|
||||||
|
|
||||||
def _use_door_or_move(self, door, target):
|
def _use_door_or_move(self, door, target):
|
||||||
"""
|
"""
|
||||||
Helper method to decide whether to use a door or move towards a target.
|
Helper method to decide whether to use a door or move towards a target.
|
||||||
@ -65,47 +105,6 @@ class TSPBaseAgent(ABC):
|
|||||||
action = self._predict_move(target)
|
action = self._predict_move(target)
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def calculate_tsp_route(self, target_identifier):
|
|
||||||
"""
|
|
||||||
Calculate the TSP route to reach a target.
|
|
||||||
|
|
||||||
:param target_identifier: Identifier of the target entity
|
|
||||||
:type target_identifier: str
|
|
||||||
|
|
||||||
:return: TSP route
|
|
||||||
:rtype: List[int]
|
|
||||||
"""
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
if self.cached_route is not None:
|
|
||||||
print(f" Used cached route: {self.cached_route}")
|
|
||||||
return copy.deepcopy(self.cached_route)
|
|
||||||
|
|
||||||
else:
|
|
||||||
positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
|
|
||||||
if self.local_optimization:
|
|
||||||
nodes = \
|
|
||||||
[self.state.pos] + \
|
|
||||||
[x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3]
|
|
||||||
try:
|
|
||||||
while len(nodes) < 7:
|
|
||||||
nodes += [next(x for x in positions if x not in nodes)]
|
|
||||||
except StopIteration:
|
|
||||||
nodes = [self.state.pos] + positions
|
|
||||||
|
|
||||||
else:
|
|
||||||
nodes = [self.state.pos] + positions
|
|
||||||
|
|
||||||
route = tsp.traveling_salesman_problem(self._position_graph,
|
|
||||||
nodes=nodes, cycle=True, method=tsp.greedy_tsp)
|
|
||||||
self.cached_route = copy.deepcopy(route)
|
|
||||||
print(f"Cached route: {self.cached_route}")
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
duration = end_time - start_time
|
|
||||||
print("TSP calculation took {:.2f} seconds to execute".format(duration))
|
|
||||||
return route
|
|
||||||
|
|
||||||
def _door_is_close(self, state):
|
def _door_is_close(self, state):
|
||||||
"""
|
"""
|
||||||
Check if a door is close to the agent's position.
|
Check if a door is close to the agent's position.
|
||||||
|
@ -68,10 +68,10 @@ class Renderer:
|
|||||||
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
|
self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
|
||||||
self.fill_bg()
|
self.fill_bg()
|
||||||
|
|
||||||
now = time.time()
|
# now = time.time()
|
||||||
self.font = pygame.font.Font(None, 20)
|
self.font = pygame.font.Font(None, 20)
|
||||||
self.font.set_bold(True)
|
self.font.set_bold(True)
|
||||||
print('Loading System font with pygame.font.Font took', time.time() - now)
|
# print('Loading System font with pygame.font.Font took', time.time() - now)
|
||||||
|
|
||||||
def fill_bg(self):
|
def fill_bg(self):
|
||||||
"""
|
"""
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
@ -116,6 +117,7 @@ class Gamestate(object):
|
|||||||
self.rng = np.random.default_rng(env_seed)
|
self.rng = np.random.default_rng(env_seed)
|
||||||
self.rules = StepRules(*rules)
|
self.rules = StepRules(*rules)
|
||||||
self._floortile_graph = None
|
self._floortile_graph = None
|
||||||
|
self.route_cache = []
|
||||||
self.tests = StepTests(*tests)
|
self.tests = StepTests(*tests)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@ -316,6 +318,42 @@ class Gamestate(object):
|
|||||||
# json_file.seek(0)
|
# json_file.seek(0)
|
||||||
# json.dump(existing_content, json_file, indent=4)
|
# json.dump(existing_content, json_file, indent=4)
|
||||||
|
|
||||||
|
def cache_route(self, route):
|
||||||
|
"""
|
||||||
|
Save routes in env-level cache so agents can access it.
|
||||||
|
|
||||||
|
:param route: The route to be saved
|
||||||
|
"""
|
||||||
|
self.route_cache.append(copy.deepcopy(route))
|
||||||
|
# print(f"Cached route: {route}")
|
||||||
|
|
||||||
|
def get_cached_route(self, current_pos, target_positions, route_cutting=False):
|
||||||
|
"""
|
||||||
|
Use a cached route if it includes the current position and a target
|
||||||
|
|
||||||
|
:param current_pos: The agent's current position and thus the first position of possibly cached routes
|
||||||
|
:param target_positions: The positions of targets the agent wants to visit
|
||||||
|
:param route_cutting: if true, cuts found routes to end at target. False allows target agents to loop.
|
||||||
|
|
||||||
|
:returns: A cached route from the agent's position to the first target if it exists
|
||||||
|
"""
|
||||||
|
if not self.route_cache:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for route in self.route_cache:
|
||||||
|
if current_pos in route:
|
||||||
|
targets = [target for target in target_positions if target in route]
|
||||||
|
if targets:
|
||||||
|
first_target = targets[0]
|
||||||
|
index_start = route.index(current_pos)
|
||||||
|
|
||||||
|
if route_cutting:
|
||||||
|
index_end = route.index(first_target) + 1
|
||||||
|
return copy.deepcopy(route[index_start:index_end])
|
||||||
|
else:
|
||||||
|
return copy.deepcopy(route[index_start:])
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class StepTests:
|
class StepTests:
|
||||||
def __init__(self, *args):
|
def __init__(self, *args):
|
||||||
|
@ -12,7 +12,7 @@ if __name__ == '__main__':
|
|||||||
render = True
|
render = True
|
||||||
|
|
||||||
# Path to config File
|
# Path to config File
|
||||||
path = Path('marl_factory_grid/configs/simple_crossing.yaml')
|
path = Path('marl_factory_grid/configs/test_config.yaml')
|
||||||
|
|
||||||
# Env Init
|
# Env Init
|
||||||
factory = Factory(path)
|
factory = Factory(path)
|
||||||
@ -23,8 +23,9 @@ if __name__ == '__main__':
|
|||||||
if render:
|
if render:
|
||||||
factory.render()
|
factory.render()
|
||||||
action_spaces = factory.action_space
|
action_spaces = factory.action_space
|
||||||
# agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
|
agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
|
||||||
agents = [TSPTargetAgent(factory, 0), TSPTargetAgent(factory, 1)]
|
# agents = [TSPTargetAgent(factory, 0), TSPTargetAgent(factory, 1)]
|
||||||
|
# agents = [TSPTargetAgent(factory, 0)]
|
||||||
while not done:
|
while not done:
|
||||||
a = [x.predict() for x in agents]
|
a = [x.predict() for x in agents]
|
||||||
obs_type, _, _, done, info = factory.step(a)
|
obs_type, _, _, done, info = factory.step(a)
|
||||||
|
Reference in New Issue
Block a user