mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 09:01:36 +02:00
Moved route caching to env level and removed print statements
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from itertools import islice
|
||||
from typing import List, Tuple
|
||||
|
||||
@ -116,6 +117,7 @@ class Gamestate(object):
|
||||
self.rng = np.random.default_rng(env_seed)
|
||||
self.rules = StepRules(*rules)
|
||||
self._floortile_graph = None
|
||||
self.route_cache = []
|
||||
self.tests = StepTests(*tests)
|
||||
|
||||
def reset(self):
|
||||
@ -316,6 +318,42 @@ class Gamestate(object):
|
||||
# json_file.seek(0)
|
||||
# json.dump(existing_content, json_file, indent=4)
|
||||
|
||||
def cache_route(self, route):
|
||||
"""
|
||||
Save routes in env-level cache so agents can access it.
|
||||
|
||||
:param route: The route to be saved
|
||||
"""
|
||||
self.route_cache.append(copy.deepcopy(route))
|
||||
# print(f"Cached route: {route}")
|
||||
|
||||
def get_cached_route(self, current_pos, target_positions, route_cutting=False):
|
||||
"""
|
||||
Use a cached route if it includes the current position and a target
|
||||
|
||||
:param current_pos: The agent's current position and thus the first position of possibly cached routes
|
||||
:param target_positions: The positions of targets the agent wants to visit
|
||||
:param route_cutting: if true, cuts found routes to end at target. False allows target agents to loop.
|
||||
|
||||
:returns: A cached route from the agent's position to the first target if it exists
|
||||
"""
|
||||
if not self.route_cache:
|
||||
return None
|
||||
|
||||
for route in self.route_cache:
|
||||
if current_pos in route:
|
||||
targets = [target for target in target_positions if target in route]
|
||||
if targets:
|
||||
first_target = targets[0]
|
||||
index_start = route.index(current_pos)
|
||||
|
||||
if route_cutting:
|
||||
index_end = route.index(first_target) + 1
|
||||
return copy.deepcopy(route[index_start:index_end])
|
||||
else:
|
||||
return copy.deepcopy(route[index_start:])
|
||||
return None
|
||||
|
||||
|
||||
class StepTests:
|
||||
def __init__(self, *args):
|
||||
|
Reference in New Issue
Block a user