Merge remote-tracking branch 'origin/rl_plotting' into marl_refactor

@@ -18,6 +18,7 @@ from collections import deque
from marl_factory_grid.environment.actions import Noop
from marl_factory_grid.modules import Clean, DoorUse
from marl_factory_grid.utils.plotting.plot_single_runs import plot_action_maps


class Names:

@@ -583,8 +584,7 @@ class A2C:
        if self.cfg[nms.ENV]["save_and_log"]:
            self.create_info_maps(env, used_actions)
            self.save_agent_models()

            plot_action_maps(env, [self], self.results_path)

    @torch.inference_mode(True)
    def eval_loop(self, n_episodes, render=False):
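For orientation, the hunk above only wires the new plotting helper into the A2C training teardown. A minimal sketch of that call order, written as a hypothetical helper (the names env, used_actions and the agent attributes are taken from the hunks in this diff, not from any other API):

    from marl_factory_grid.utils.plotting.plot_single_runs import plot_action_maps

    def finalize_training_run(agent, env, used_actions):
        # Same order as in the A2C hunk: dump info maps, persist models, then plot.
        agent.create_info_maps(env, used_actions)
        agent.save_agent_models()
        plot_action_maps(env, [agent], agent.results_path)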
@@ -33,9 +33,12 @@ class TSPBaseAgent(ABC):
        self.local_optimization = True
        self._env = state
        self.state = self._env.state[c.AGENT][agent_i]
        self.spawn_position = np.array(self.state.pos)
        self._position_graph = self.generate_pos_graph()
        self._static_route = None
        self.cached_route = None
        self.fallback_action = None
        self.action_list = []

    @abstractmethod
    def predict(self, *_, **__) -> int:
@@ -47,6 +50,46 @@ class TSPBaseAgent(ABC):
        """
        return 0

    def calculate_tsp_route(self, target_identifier):
        """
        Calculate the TSP route to reach a target.

        :param target_identifier: Identifier of the target entity
        :type target_identifier: str

        :return: TSP route
        :rtype: List[int]
        """
        target_positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]

        # if there are cached routes, search for one matching the current and target position
        if self._env.state.route_cache and (
                route := self._env.state.get_cached_route(self.state.pos, target_positions)) is not None:
            # print(f"Retrieved cached route: {route}")
            return route
        # if none are found, calculate tsp route and cache it
        else:
            start_time = time.time()
            if self.local_optimization:
                nodes = \
                    [self.state.pos] + \
                    [x for x in target_positions if max(abs(np.subtract(x, self.state.pos))) < 3]
                try:
                    while len(nodes) < 7:
                        nodes += [next(x for x in target_positions if x not in nodes)]
                except StopIteration:
                    nodes = [self.state.pos] + target_positions

            else:
                nodes = [self.state.pos] + target_positions

            route = tsp.traveling_salesman_problem(self._position_graph,
                                                   nodes=nodes, cycle=True, method=tsp.greedy_tsp)
            duration = time.time() - start_time
            print("TSP calculation took {:.2f} seconds to execute".format(duration))
            self._env.state.cache_route(route)
            return route

    def _use_door_or_move(self, door, target):
        """
        Helper method to decide whether to use a door or move towards a target.
@@ -65,47 +108,6 @@ class TSPBaseAgent(ABC):
        action = self._predict_move(target)
        return action

    def calculate_tsp_route(self, target_identifier):
        """
        Calculate the TSP route to reach a target.

        :param target_identifier: Identifier of the target entity
        :type target_identifier: str

        :return: TSP route
        :rtype: List[int]
        """
        start_time = time.time()

        if self.cached_route is not None:
            print(f" Used cached route: {self.cached_route}")
            return copy.deepcopy(self.cached_route)

        else:
            positions = [x for x in self._env.state[target_identifier].positions if x != c.VALUE_NO_POS]
            if self.local_optimization:
                nodes = \
                    [self.state.pos] + \
                    [x for x in positions if max(abs(np.subtract(x, self.state.pos))) < 3]
                try:
                    while len(nodes) < 7:
                        nodes += [next(x for x in positions if x not in nodes)]
                except StopIteration:
                    nodes = [self.state.pos] + positions

            else:
                nodes = [self.state.pos] + positions

            route = tsp.traveling_salesman_problem(self._position_graph,
                                                   nodes=nodes, cycle=True, method=tsp.greedy_tsp)
            self.cached_route = copy.deepcopy(route)
            print(f"Cached route: {self.cached_route}")

            end_time = time.time()
            duration = end_time - start_time
            print("TSP calculation took {:.2f} seconds to execute".format(duration))
            return route

    def _door_is_close(self, state):
        """
        Check if a door is close to the agent's position.
@@ -171,8 +173,11 @@ class TSPBaseAgent(ABC):
                action = next(action for action, pos_diff in MOVEMAP.items() if
                              np.all(diff == pos_diff) and action in allowed_directions)
            except StopIteration:
                print(f"No valid action found for pos diff: {diff}. Using fallback action.")
                action = choice(self.state.actions).name
                print(f"No valid action found for pos diff: {diff}. Using fallback action: {self.fallback_action}.")
                if self.fallback_action and any(self.fallback_action == action.name for action in self.state.actions):
                    action = self.fallback_action
                else:
                    action = choice(self.state.actions).name
        else:
            action = choice(self.state.actions).name
        # noinspection PyUnboundLocalVariable
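The new calculate_tsp_route above is cache-first: it asks the shared Gamestate cache before running the greedy TSP solver and only caches freshly computed routes. A standalone sketch of that decision flow with toy data (no factory needed; the naive solver below just stands in for networkx's traveling_salesman_problem):

    def route_or_solve(route_cache, pos, target_positions, solve):
        """Reuse a cached route whose tail starts at pos and hits a target; otherwise solve and cache."""
        for route in route_cache:
            if pos in route and any(t in route for t in target_positions):
                return route[route.index(pos):]      # cache hit: reuse the remaining tail
        route = solve(pos, target_positions)         # cache miss: run the (expensive) solver
        route_cache.append(route)
        return route

    cache = []
    naive_solver = lambda pos, targets: [pos] + sorted(targets)
    print(route_or_solve(cache, (1, 1), [(2, 2), (3, 1)], naive_solver))  # solver runs, result cached
    print(route_or_solve(cache, (2, 2), [(3, 1)], naive_solver))          # served from the cache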
@@ -1,6 +1,7 @@
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent

from marl_factory_grid.modules.clean_up import constants as di
from marl_factory_grid.environment import constants as c

future_planning = 7

@@ -12,6 +13,7 @@ class TSPDirtAgent(TSPBaseAgent):
        Initializes a TSPDirtAgent that aims to clean dirt in the environment.
        """
        super(TSPDirtAgent, self).__init__(*args, **kwargs)
        self.fallback_action = c.NOOP

    def predict(self, *_, **__):
        """
@@ -28,6 +30,7 @@ class TSPDirtAgent(TSPBaseAgent):
            action = self._use_door_or_move(door, di.DIRT)
        else:
            action = self._predict_move(di.DIRT)
        self.action_list.append(action)
        # Translate the action_object to an integer to have the same output as any other model
        try:
            action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
@@ -3,6 +3,7 @@ import numpy as np
from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent

from marl_factory_grid.modules.items import constants as i
from marl_factory_grid.environment import constants as c

future_planning = 7
inventory_size = 3
@@ -22,6 +23,7 @@ class TSPItemAgent(TSPBaseAgent):
        """
        super(TSPItemAgent, self).__init__(*args, **kwargs)
        self.mode = mode
        self.fallback_action = c.NOOP

    def predict(self, *_, **__):
        item_at_position = self._env.state[i.ITEM].by_pos(self.state.pos)
@@ -36,6 +38,7 @@ class TSPItemAgent(TSPBaseAgent):
            action = self._use_door_or_move(door, i.DROP_OFF if self.mode == MODE_BRING else i.ITEM)
        else:
            action = self._choose()
        self.action_list.append(action)
        # Translate the action_object to an integer to have the same output as any other model
        try:
            action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
@@ -2,6 +2,8 @@ from marl_factory_grid.algorithms.static.TSP_base_agent import TSPBaseAgent

from marl_factory_grid.modules.destinations import constants as d
from marl_factory_grid.modules.doors import constants as do
from marl_factory_grid.environment import constants as c


future_planning = 7

@@ -13,6 +15,7 @@ class TSPTargetAgent(TSPBaseAgent):
        Initializes a TSPTargetAgent that aims to reach destinations.
        """
        super(TSPTargetAgent, self).__init__(*args, **kwargs)
        self.fallback_action = c.NOOP

    def _handle_doors(self, state):
        """
@@ -35,6 +38,7 @@ class TSPTargetAgent(TSPBaseAgent):
            action = self._use_door_or_move(door, d.DESTINATION)
        else:
            action = self._predict_move(d.DESTINATION)
        self.action_list.append(action)
        # Translate the action_object to an integer to have the same output as any other model
        try:
            action_obj = next(action_i for action_i, a in enumerate(self.state.actions) if a.name == action)
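All three TSP agents above now set fallback_action = c.NOOP, which the base class prefers over a random action when no movement matches the position difference. A standalone sketch of that selection logic (names mirror the _predict_move hunk earlier in this diff; here actions are plain name strings rather than action objects):

    from random import choice

    def pick_fallback(fallback_action, available_action_names):
        # Prefer the configured fallback if the agent can actually perform it,
        # otherwise fall back to a random available action (the previous behaviour).
        if fallback_action and fallback_action in available_action_names:
            return fallback_action
        return choice(available_action_names)

    print(pick_fallback('Noop', ['Noop', 'Move8', 'Clean']))   # -> 'Noop'
    print(pick_fallback('Noop', ['Move8', 'Clean']))           # -> random valid action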
@@ -20,7 +20,7 @@ Agents:
      - Destination
    # Available Spawn Positions as list
    Positions:
      - (1,2)
      - (2,1)
    # It is okay to collide with other agents, so that
    # they end up on the same position
    is_blocking_pos: false
@@ -33,7 +33,7 @@ Agents:
      - Other
      - Destination
    Positions:
      - (2,1)
      - (1,2)
    is_blocking_pos: false

# Other noteworthy Entities
@@ -45,9 +45,9 @@ Entities:
  SpawnDestinationsPerAgent:
    coords_or_quantity:
      Agent_horizontal:
        - (3,2)
      Agent_vertical:
        - (2,3)
      Agent_vertical:
        - (3,2)
  # Whether you want to provide a numeric Position observation.
  # GlobalPositions:
  #   normalized: false
@@ -1,45 +1,45 @@
Agents:
  Clean test agent:
    Actions:
      - Noop
      - Charge
      - Clean
      - DoorUse
      - Move8
    Observations:
      - Combined:
          - Other
          - Walls
      - GlobalPosition
      - Battery
      - ChargePods
      - DirtPiles
      - Destinations
      - Doors
      - Maintainers
    Clones: 0
  Item test agent:
    Actions:
      - Noop
      - Charge
      - DestAction
      - DoorUse
      - ItemAction
      - Move8
    Observations:
      - Combined:
          - Other
          - Walls
      - GlobalPosition
      - Battery
      - ChargePods
      - Destinations
      - Doors
      - Items
      - Inventory
      - DropOffLocations
      - Maintainers
    Clones: 0
#  Clean test agent:
#    Actions:
#      - Noop
#      - Charge
#      - Clean
#      - DoorUse
#      - Move8
#    Observations:
#      - Combined:
#          - Other
#          - Walls
#      - GlobalPosition
#      - Battery
#      - ChargePods
#      - DirtPiles
#      - Destinations
#      - Doors
#      - Maintainers
#    Clones: 0
#  Item test agent:
#    Actions:
#      - Noop
#      - Charge
#      - DestAction
#      - DoorUse
#      - ItemAction
#      - Move8
#    Observations:
#      - Combined:
#          - Other
#          - Walls
#      - GlobalPosition
#      - Battery
#      - ChargePods
#      - Destinations
#      - Doors
#      - Items
#      - Inventory
#      - DropOffLocations
#      - Maintainers
#    Clones: 0
  Target test agent:
    Actions:
      - Noop
@@ -55,7 +55,7 @@ Agents:
      - Destinations
      - Doors
      - Maintainers
    Clones: 0
    Clones: 1

Entities:

@@ -90,7 +90,7 @@ Entities:
General:
  env_seed: 69
  individual_rewards: true
  level_name: large
  level_name: quadrant
  pomdp_r: 3
  verbose: false
  tests: false
@@ -115,10 +115,10 @@ Rules:

  # Done Conditions
  DoneAtMaxStepsReached:
    max_steps: 500
    max_steps: 20

Tests:
  MaintainerTest: {}
  DirtAgentTest: {}
  ItemAgentTest: {}
  TargetAgentTest: {}
  # MaintainerTest: {}
  # DirtAgentTest: {}
  # ItemAgentTest: {}
  # TargetAgentTest: {}
@@ -186,6 +186,11 @@ class SpawnAgents(Rule):
            if isinstance(rule, marl_factory_grid.environment.rules.AgentSpawnRule):
                spawn_rule = rule.spawn_rule

        if not hasattr(state, 'agent_spawn_positions'):
            state.agent_spawn_positions = []
        else:
            state.agent_spawn_positions.clear()

        agents = state[c.AGENT]
        for agent_name, agent_conf in state.agents_conf.items():
            empty_positions = state.entities.empty_positions
@@ -198,11 +203,14 @@ class SpawnAgents(Rule):
            if position := self._get_position(spawn_rule, positions, empty_positions, positions_pointer):
                assert state.check_pos_validity(position), 'smth went wrong....'
                agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
                state.agent_spawn_positions.append(position)
            elif positions:
                raise ValueError(f'It was not possible to spawn an Agent on the available position: '
                                 f'\n{agent_conf["positions"].copy()}')
            else:
                agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
                chosen_position = empty_positions.pop()
                agents.add_item(Agent(actions, observations, chosen_position, str_ident=agent_name, **other))
                state.agent_spawn_positions.append(chosen_position)
        return []

    def _get_position(self, spawn_rule, positions, empty_positions, positions_pointer):
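The SpawnAgents hunk records every chosen spawn position on the Gamestate so the plotting code can mark it later. A tiny standalone illustration of the init-or-clear pattern used above (a plain object stands in for the Gamestate here):

    from types import SimpleNamespace

    state = SimpleNamespace()                      # stand-in for the Gamestate
    for episode_positions in [[(1, 2), (2, 1)], [(3, 3)]]:
        if not hasattr(state, 'agent_spawn_positions'):
            state.agent_spawn_positions = []       # first reset: create the list
        else:
            state.agent_spawn_positions.clear()    # later resets: reuse and empty it
        for pos in episode_positions:
            state.agent_spawn_positions.append(pos)
        print(state.agent_spawn_positions)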
@@ -42,6 +42,7 @@ class LevelParser(object):
        level_array = h.one_hot_level(self._parsed_level, c.SYMBOL_WALL)
        self.level_shape = level_array.shape
        self.size = self.pomdp_r ** 2 if self.pomdp_r else np.prod(self.level_shape)
        self.walls = None

    def get_coordinates_for_symbol(self, symbol, negate=False) -> np.ndarray:
        """
@@ -74,6 +75,7 @@ class LevelParser(object):
        # Walls
        walls = Walls.from_coordinates(self.get_coordinates_for_symbol(c.SYMBOL_WALL), self.size)
        entities.add_items({c.WALLS: walls})
        self.walls = self.get_coordinates_for_symbol(c.SYMBOL_WALL)

        # Agents
        entities.add_items({c.AGENT: Agents(self.size)})
@@ -90,12 +92,12 @@ class LevelParser(object):
                for symbol in symbols:
                    level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
                    if np.any(level_array):
                        # TODO: Get rid of this!
                        e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
                                                     self.size, entity_kwargs=e_kwargs)
                    else:
                        raise ValueError(f'No {e_class} (Symbol: {e_class.symbol}) could be found!\n'
                                         f'Check your level file!')
                        print(f'Warning: No {e_class.__name__} (Symbol: {symbol}) found in level file.'
                              f' Initializing with empty position.')
                        e = e_class.from_coordinates([], self.size, entity_kwargs=e_kwargs)
            else:
                e = e_class(self.size, **e_kwargs)
            entities.add_items({e.name: e})
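The parser hunk replaces the hard ValueError with a warning plus an empty entity collection when a symbol does not occur in the level file. A standalone sketch of the check itself (numpy only; the boolean comparison is a rough stand-in for h.one_hot_level):

    import numpy as np

    parsed_level = np.array([['#', '#', '#'],
                             ['#', '-', '#'],
                             ['#', '#', '#']])
    symbol = 'D'                                         # e.g. a door symbol this level simply lacks
    level_array = (parsed_level == symbol).astype(int)   # stand-in for h.one_hot_level
    if np.any(level_array):
        coordinates = np.argwhere(level_array == 1).tolist()
    else:
        print(f"Warning: no '{symbol}' found in level file, initializing with empty positions.")
        coordinates = []                                 # entity collection is then built from an empty list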
@@ -48,7 +48,7 @@ class EnvRecorder(Wrapper):
        """
        obs_type, obs, reward, done, info = self.env.step(actions)
        if not self.episodes or self._curr_episode in self.episodes:
            summary: dict = self.env.summarize_state()
            summary: dict = self.env.unwrapped.summarize_state()
            # summary.update(done=done)
            # summary.update({'episode': self._curr_episode})
            # TODO Protobuff Adjustments ######
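The EnvRecorder change calls summarize_state() on self.env.unwrapped instead of self.env, because custom methods live on the base factory env rather than on whatever wrapper sits on top of it. A quick illustration of the same idea with a Gymnasium-style wrapper stack (Gymnasium is an assumption here; the project's own Wrapper is expected to behave the same way):

    import gymnasium as gym

    env = gym.wrappers.TimeLimit(gym.make("CartPole-v1"), max_episode_steps=10)
    # Attribute lookups on the wrapper only see wrapper attributes; .unwrapped
    # skips every wrapper layer and returns the innermost environment object.
    print(type(env).__name__)            # TimeLimit
    print(type(env.unwrapped).__name__)  # CartPoleEnv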
BIN  marl_factory_grid/utils/plotting/action_assets/cardinal.png        new file  (1.9 KiB)
BIN  marl_factory_grid/utils/plotting/action_assets/charge_action.png   new file  (1.4 KiB)
BIN  marl_factory_grid/utils/plotting/action_assets/clean_action.png    new file  (6.6 KiB)
BIN  marl_factory_grid/utils/plotting/action_assets/default.png         new file  (1.9 KiB)
BIN  (action asset, filename not shown)                                 new file  (4.9 KiB)
BIN  marl_factory_grid/utils/plotting/action_assets/diagonal.png        new file  (5.7 KiB)
BIN  marl_factory_grid/utils/plotting/action_assets/door_action.png     new file  (990 B)
BIN  marl_factory_grid/utils/plotting/action_assets/green_arrow.png     new file  (455 B)
BIN  marl_factory_grid/utils/plotting/action_assets/grey_arrow.png      new file  (425 B)
BIN  (action asset, filename not shown)                                 new file  (3.2 KiB)
BIN  marl_factory_grid/utils/plotting/action_assets/noop.png            new file  (1.4 KiB)
BIN  marl_factory_grid/utils/plotting/action_assets/red_arrow.png       new file  (439 B)
BIN  marl_factory_grid/utils/plotting/action_assets/spawn_pos.png       new file  (672 B)
BIN  marl_factory_grid/utils/plotting/action_assets/target_dirt.png     new file  (291 B)
BIN  marl_factory_grid/utils/plotting/action_assets/yellow_arrow.png    new file  (443 B)
@@ -1,13 +1,21 @@
import json
import os
import pickle
from os import PathLike
from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd

from marl_factory_grid.utils.helpers import IGNORED_DF_COLUMNS
from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot

from marl_factory_grid.utils.renderer import Renderer
from marl_factory_grid.utils.utility_classes import RenderEntity

from marl_factory_grid.modules.clean_up import constants as d


def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
                    file_key: str = 'monitor', file_ext: str = 'pkl'):
@@ -60,3 +68,169 @@ def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, colum
    prepare_plot(run_path.parent / f'{run_path.parent.name}_monitor_lineplot.png', df_melted, use_tex=use_tex)
    print('Plotting done.')


def plot_routes(factory, agents):
    """
    Plots the agents' actions on the level map by creating a Renderer and RenderEntities that hold the icon
    corresponding to each action. For deterministic agents this simply displays the agent's recorded list of actions,
    while RL agents can instead supply an action map or action probabilities from their policy net.
    """
    renderer = Renderer(factory.map.level_shape, custom_assets_path={
        'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png',
        'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png',
        'use_door': 'marl_factory_grid/utils/plotting/action_assets/door_action.png',
        'wall': 'marl_factory_grid/environment/assets/wall.png',
        'machine_action': 'marl_factory_grid/utils/plotting/action_assets/machine_action.png',
        'clean_action': 'marl_factory_grid/utils/plotting/action_assets/clean_action.png',
        'destination_action': 'marl_factory_grid/utils/plotting/action_assets/destination_action.png',
        'noop': 'marl_factory_grid/utils/plotting/action_assets/noop.png',
        'charge_action': 'marl_factory_grid/utils/plotting/action_assets/charge_action.png'})

    wall_positions = swap_coordinates(factory.map.walls)
    wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
    action_entities = list(wall_entities)

    for index, agent in enumerate(agents):
        current_position = swap_coordinates(agent.spawn_position)

        if hasattr(agent, 'action_probabilities'):
            # Handle RL agents with action probabilities
            top_actions = sorted(agent.action_probabilities.items(), key=lambda x: -x[1])[:4]
        else:
            # Handle deterministic agents by iterating through all actions in the list
            top_actions = [(action, 0) for action in agent.action_list]

        for action, probability in top_actions:
            if action.lower() in rotation_mapping:
                base_icon, rotation = rotation_mapping[action.lower()]
                icon_name = 'cardinal' if 'diagonal' not in base_icon else 'diagonal'
                new_position = action_to_coords(current_position, action.lower())
            else:
                icon_name = action.lower()
                rotation = 0
                new_position = current_position

            action_entity = RenderEntity(
                name=icon_name,
                pos=np.array(current_position),
                probability=probability,
                rotation=rotation
            )
            action_entities.append(action_entity)
            current_position = new_position

    renderer.render_single_action_icons(action_entities)  # move in/out loop for graph per agent or not


def plot_action_maps(factory, agents, result_path):
    base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    assets_path = {
        'green_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'green_arrow.png'),
        'yellow_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'yellow_arrow.png'),
        'red_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'red_arrow.png'),
        'grey_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'grey_arrow.png'),
        'wall': os.path.join(base_dir, 'environment', 'assets', 'wall.png'),
        'target_dirt': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'target_dirt.png'),
        'spawn_pos': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'spawn_pos.png')
    }
    renderer = Renderer(factory.map.level_shape, cell_size=80, custom_assets_path=assets_path)

    directions = ['north', 'east', 'south', 'west']
    wall_positions = swap_coordinates(factory.map.walls)

    for agent_index, agent in enumerate(agents):
        if hasattr(agent, 'action_probabilities'):
            action_probabilities = unpack_action_probabilities(agent.action_probabilities)
            for action_map_index, probabilities_map in enumerate(action_probabilities[agent_index]):

                wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
                action_entities = list(wall_entities)
                target_dirt_pos = factory.state.entities[d.DIRT][action_map_index].pos
                action_entities.append(
                    RenderEntity(name='target_dirt', probability=0, pos=swap_coordinates(target_dirt_pos)))
                action_entities.append(RenderEntity(name='spawn_pos', probability=0, pos=swap_coordinates(
                    factory.state.agent_spawn_positions[agent_index])))

                for position, probabilities in probabilities_map.items():
                    if position not in wall_positions:
                        if np.any(probabilities) > 0:  # Ensure it's not all zeros which would indicate a wall
                            sorted_indices = sorted(range(len(probabilities)), key=lambda i: -probabilities[i])
                            colors = ['green_arrow', 'yellow_arrow', 'red_arrow', 'grey_arrow']

                            for rank, direction_index in enumerate(sorted_indices):
                                action = directions[direction_index]
                                probability = probabilities[direction_index]
                                arrow_color = colors[rank]
                                if probability > 0:
                                    action_entity = RenderEntity(
                                        name=arrow_color,
                                        pos=position,
                                        probability=probability,
                                        rotation=direction_index * 90
                                    )
                                    action_entities.append(action_entity)

                renderer.render_multi_action_icons(action_entities, result_path)


def unpack_action_probabilities(action_probabilities):
    unpacked = {}
    for agent_index, maps in action_probabilities.items():
        unpacked[agent_index] = []
        for map_index, probability_map in enumerate(maps):
            single_map = {}
            for y in range(len(probability_map)):
                for x in range(len(probability_map[y])):
                    position = (x, y)
                    probabilities = probability_map[y][x]
                    single_map[position] = probabilities
            unpacked[agent_index].append(single_map)
    return unpacked


def swap_coordinates(positions):
    """
    Swaps x and y coordinates of single positions, lists or arrays.
    """
    if isinstance(positions, tuple) or (isinstance(positions, list) and len(positions) == 2):
        return positions[1], positions[0]
    elif isinstance(positions, np.ndarray) and positions.ndim == 1 and positions.shape[0] == 2:
        return positions[1], positions[0]
    else:
        return [(y, x) for x, y in positions]


def action_to_coords(current_position, action):
    """
    Calculates new coordinates based on the current position and a movement action.
    """
    delta = direction_mapping.get(action)
    if delta is not None:
        new_position = [current_position[0] + delta[0], current_position[1] + delta[1]]
        return new_position
    print(f"No valid movement action found for {action}.")
    return current_position


rotation_mapping = {
    'north': ('cardinal', 0),
    'east': ('cardinal', 270),
    'south': ('cardinal', 180),
    'west': ('cardinal', 90),
    'north_east': ('diagonal', 0),
    'south_east': ('diagonal', 270),
    'south_west': ('diagonal', 180),
    'north_west': ('diagonal', 90)
}

direction_mapping = {
    'north': (0, -1),
    'south': (0, 1),
    'east': (1, 0),
    'west': (-1, 0),
    'north_east': (1, -1),
    'north_west': (-1, -1),
    'south_east': (1, 1),
    'south_west': (-1, 1)
}
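A short usage example for the two coordinate helpers defined above; both are pure functions, so this runs without a factory or pygame (the positions are made up):

    from marl_factory_grid.utils.plotting.plot_single_runs import swap_coordinates, action_to_coords

    spawn = (3, 5)                                        # (row, col) as stored on the factory side
    screen_pos = swap_coordinates(spawn)                  # -> (5, 3): the (x, y) order the renderer expects
    one_step_up = action_to_coords(screen_pos, 'north')   # -> [5, 2], via direction_mapping
    print(screen_pos, one_step_up)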
@@ -1,7 +1,8 @@
import os
import sys

from pathlib import Path
from collections import deque
from collections import deque, defaultdict
from itertools import product

import numpy as np
@@ -24,15 +25,14 @@ SCALE: str = 'scale'


class Renderer:
    BG_COLOR = (178, 190, 195) # (99, 110, 114)
    WHITE = (223, 230, 233) # (200, 200, 200)
    BG_COLOR = (178, 190, 195)  # (99, 110, 114)
    WHITE = (223, 230, 233)  # (200, 200, 200)
    AGENT_VIEW_COLOR = (9, 132, 227)
    ASSETS = Path(__file__).parent.parent

    def __init__(self, lvl_shape: Tuple[int, int] = (16, 16),
                 lvl_padded_shape: Union[Tuple[int, int], None] = None,
                 cell_size: int = 40, fps: int = 7, factor: float = 0.9,
                 grid_lines: bool = True, view_radius: int = 2):
    def __init__(self, lvl_shape: Tuple[int, int] = (16, 16), lvl_padded_shape: Union[Tuple[int, int], None] = None,
                 cell_size: int = 40, fps: int = 7, factor: float = 0.9, grid_lines: bool = True, view_radius: int = 2,
                 custom_assets_path=None):
        """
        The Renderer class initializes and manages the rendering environment for the simulation,
        providing methods for preparing entities for display, loading assets, calculating visibility rectangles and
@@ -53,7 +53,6 @@ class Renderer:
        :param view_radius: Radius for agent's field of view.
        :type view_radius: int
        """
        # TODO: Custom_assets paths
        self.grid_h, self.grid_w = lvl_shape
        self.lvl_padded_shape = lvl_padded_shape if lvl_padded_shape is not None else lvl_shape
        self.cell_size = cell_size
@@ -61,17 +60,18 @@ class Renderer:
        self.grid_lines = grid_lines
        self.view_radius = view_radius
        pygame.init()
        self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
        self.screen_size = (self.grid_w * cell_size, self.grid_h * cell_size)
        self.screen = pygame.display.set_mode(self.screen_size)
        self.clock = pygame.time.Clock()
        assets = list(self.ASSETS.rglob('*.png'))
        self.assets = {path.stem: self.load_asset(str(path), factor) for path in assets}
        self.custom_assets_path = custom_assets_path
        self.assets = self.load_assets(custom_assets_path)
        self.save_counter = 1
        self.fill_bg()

        now = time.time()
        # now = time.time()
        self.font = pygame.font.Font(None, 20)
        self.font.set_bold(True)
        print('Loading System font with pygame.font.Font took', time.time() - now)
        # print('Loading System font with pygame.font.Font took', time.time() - now)

    def fill_bg(self):
        """
@@ -100,22 +100,44 @@ class Renderer:
                               (self.lvl_padded_shape[1] - self.grid_w) // 2

        r, c = entity.pos
        r, c = r - offset_r, c-offset_c
        r, c = r - offset_r, c - offset_c

        img = self.assets[entity.name.lower()]
        if entity.value_operation == OPACITY:
            img.set_alpha(255*entity.value)
            img.set_alpha(255 * entity.value)
        elif entity.value_operation == SCALE:
            re = img.get_rect()
            img = pygame.transform.smoothscale(
                img, (int(entity.value*re.width), int(entity.value*re.height))
                img, (int(entity.value * re.width), int(entity.value * re.height))
            )
        o = self.cell_size//2
        r_, c_ = r*self.cell_size + o, c*self.cell_size + o
        o = self.cell_size // 2
        r_, c_ = r * self.cell_size + o, c * self.cell_size + o
        rect = img.get_rect()
        rect.centerx, rect.centery = c_, r_
        return dict(source=img, dest=rect)

    def load_assets(self, custom_assets_path):
        """
        Loads assets from the custom path if provided, otherwise from the default path.
        """
        assets_directory = custom_assets_path if custom_assets_path else self.ASSETS
        assets = {}
        if isinstance(assets_directory, dict):
            for key, path in assets_directory.items():
                asset = self.load_asset(path)
                if asset is not None:
                    assets[key] = asset
                else:
                    print(f"Warning: Asset for key '{key}' is missing and was not loaded.")
        else:
            for path in Path(assets_directory).rglob('*.png'):
                asset = self.load_asset(str(path))
                if asset is not None:
                    assets[path.stem] = asset
                else:
                    print(f"Warning: Asset '{path.stem}' is missing and was not loaded.")
        return assets

    def load_asset(self, path, factor=1.0):
        """
        Loads and resizes an asset from the specified path.
@@ -126,10 +148,28 @@ class Renderer:
        :type factor: float
        :return: Resized asset.
        """
        s = int(factor*self.cell_size)
        asset = pygame.image.load(path).convert_alpha()
        asset = pygame.transform.smoothscale(asset, (s, s))
        return asset
        try:
            s = int(factor * self.cell_size)
            asset = pygame.image.load(path).convert_alpha()
            asset = pygame.transform.smoothscale(asset, (s, s))
            return asset
        except pygame.error as e:
            print(f"Failed to load asset {path}: {e}")
            return self.load_default_asset()

    def load_default_asset(self, factor=1.0):
        """
        Loads a default asset to be used when specific assets fail to load.
        """
        default_path = 'marl_factory_grid/utils/plotting/action_assets/default.png'
        try:
            s = int(factor * self.cell_size)
            default_asset = pygame.image.load(default_path).convert_alpha()
            default_asset = pygame.transform.smoothscale(default_asset, (s, s))
            return default_asset
        except pygame.error as e:
            print(f"Failed to load default asset: {e}")
            return None

    def visibility_rects(self, bp, view):
        """
@@ -143,13 +183,13 @@ class Renderer:
        :rtype: List[dict]
        """
        rects = []
        for i, j in product(range(-self.view_radius, self.view_radius+1),
                            range(-self.view_radius, self.view_radius+1)):
        for i, j in product(range(-self.view_radius, self.view_radius + 1),
                            range(-self.view_radius, self.view_radius + 1)):
            if view is not None:
                if bool(view[self.view_radius+j, self.view_radius+i]):
                if bool(view[self.view_radius + j, self.view_radius + i]):
                    visibility_rect = bp['dest'].copy()
                    visibility_rect.centerx += i*self.cell_size
                    visibility_rect.centery += j*self.cell_size
                    visibility_rect.centerx += i * self.cell_size
                    visibility_rect.centery += j * self.cell_size
                    shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
                    pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect())
                    shape_surf.set_alpha(64)
@@ -183,7 +223,7 @@ class Renderer:
                    RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
                )
                textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
                text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
                text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0] - .07 * self.cell_size,
                                                           agent_blit['dest'].center[1]))
                blits += [agent_blit, state_blit, text_blit]
@@ -201,9 +241,118 @@ class Renderer:
        return np.transpose(rgb_obs, (2, 0, 1))
        # return torch.from_numpy(rgb_obs).permute(2, 0, 1)

    def render_single_action_icons(self, action_entities):
        """
        Renders an action icon for each entity, based on the entity's action name, position, rotation and
        probability. Probabilities other than 0 are rendered next to the icon.

        :param action_entities: List of entities representing actions.
        :type action_entities: List[RenderEntity]
        """
        self.fill_bg()

        for action_entity in action_entities:
            if not isinstance(action_entity.pos, np.ndarray) or action_entity.pos.ndim != 1:
                print(f"Invalid position format for entity: {action_entity.pos}")
                continue

            # Load and potentially rotate the icon based on action name
            img = self.assets[action_entity.name.lower()]
            if img is None:
                print(f"Error: No asset available for '{action_entity.name}'. Skipping rendering this entity.")
                continue
            if hasattr(action_entity, 'rotation'):
                img = pygame.transform.rotate(img, action_entity.rotation)

            # Blit the icon image
            img_rect = img.get_rect(center=(action_entity.pos[0] * self.cell_size + self.cell_size // 2,
                                            action_entity.pos[1] * self.cell_size + self.cell_size // 2))
            self.screen.blit(img, img_rect)

            # Render the probability next to the icon if it exists
            if hasattr(action_entity, 'probability') and action_entity.probability != 0:
                prob_text = self.font.render(f"{action_entity.probability:.2f}", True, (255, 0, 0))
                prob_text_rect = prob_text.get_rect(top=img_rect.bottom, left=img_rect.left)
                self.screen.blit(prob_text, prob_text_rect)

        pygame.display.flip()  # Update the display with all new blits
        self.save_screen("route_graph")

    def render_multi_action_icons(self, action_entities, result_path):
        """
        Renders multiple action icons at the same position without overlap and arranges them based on direction,
        except for walls, spawn and target positions, which cover the entire grid cell.
        """
        self.fill_bg()
        font = pygame.font.Font(None, 20)

        # prepare position dict to iterate over
        position_dict = defaultdict(list)
        for entity in action_entities:
            position_dict[tuple(entity.pos)].append(entity)

        for position, entities in position_dict.items():
            entity_size = self.cell_size // 2
            entities.sort(key=lambda x: x.rotation)

            for entity in entities:
                img = self.assets[entity.name.lower()]
                if img is None:
                    print(f"Error: No asset available for '{entity.name}'. Skipping rendering this entity.")
                    continue

                # Check if the entity is a wall and adjust the size and position accordingly
                if entity.name in ['wall', 'target_dirt', 'spawn_pos']:
                    img = pygame.transform.scale(img, (self.cell_size, self.cell_size))
                    img_rect = img.get_rect(center=(position[0] * self.cell_size + self.cell_size // 2,
                                                    position[1] * self.cell_size + self.cell_size // 2))
                else:
                    # Define offsets for each direction based on a quadrant layout within the cell
                    offsets = {
                        0: (0, -entity_size // 2),    # North
                        90: (-entity_size // 2, 0),   # West
                        180: (0, entity_size // 2),   # South
                        270: (entity_size // 2, 0)    # East
                    }
                    img = pygame.transform.scale(img, (int(entity_size), entity_size))
                    offset = offsets.get(entity.rotation, (0, 0))
                    img_rect = img.get_rect(center=(
                        position[0] * self.cell_size + self.cell_size // 2 + offset[0],
                        position[1] * self.cell_size + self.cell_size // 2 + offset[1]
                    ))

                img = pygame.transform.rotate(img, entity.rotation)
                self.screen.blit(img, img_rect)

                # Render the probability next to the icon if it exists and is non-zero
                if entity.probability > 0 and entity.name != 'wall':
                    formatted_probability = f"{entity.probability * 100:.2f}"
                    prob_text = font.render(formatted_probability, True, (0, 0, 0))
                    prob_text_rect = prob_text.get_rect(center=img_rect.center)  # Center text on the arrow
                    self.screen.blit(prob_text, prob_text_rect)

        pygame.display.flip()
        self.save_screen("multi_action_graph", result_path)

    def save_screen(self, filename, result_path=''):
        """
        Saves the current screen to a PNG file, appending a counter to ensure uniqueness.

        :param filename: The base filename to save the image under.
        :param result_path: Path to the output folder. Defaults to '' so callers that have no result
                            directory (e.g. render_single_action_icons) still work.
        """
        base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        out_dir = os.path.join(base_dir, 'study_out', result_path)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        unique_filename = f"{filename}_agent_{self.save_counter}.png"
        self.save_counter += 1
        full_path = os.path.join(out_dir, unique_filename)
        pygame.image.save(self.screen, full_path)
        print(f"Image saved as {unique_filename}")


if __name__ == '__main__':
    renderer = Renderer(fps=2, cell_size=40)
    renderer = Renderer(cell_size=40, fps=2)
    for pos_i in range(15):
        entity_1 = RenderEntity('agent_collision', [5, pos_i], 1, 'idle', 'idle')
        renderer.render([entity_1])
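Putting the new Renderer pieces together, plot_routes earlier in this diff is essentially: build a Renderer with a custom_assets_path dict, turn walls and actions into RenderEntity objects, and call render_single_action_icons. A trimmed sketch of that call sequence (it needs pygame and the asset files on disk, so treat it as illustrative rather than a test):

    import numpy as np
    from marl_factory_grid.utils.renderer import Renderer
    from marl_factory_grid.utils.utility_classes import RenderEntity

    assets = {'noop': 'marl_factory_grid/utils/plotting/action_assets/noop.png',
              'wall': 'marl_factory_grid/environment/assets/wall.png'}
    renderer = Renderer((5, 5), cell_size=40, custom_assets_path=assets)

    entities = [RenderEntity(name='wall', probability=0, pos=np.array((0, 0))),
                RenderEntity(name='noop', probability=0.9, pos=np.array((2, 2)), rotation=0)]
    renderer.render_single_action_icons(entities)   # draws the icons and saves "route_graph_agent_1.png"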
@@ -1,3 +1,4 @@
import copy
from itertools import islice
from typing import List, Tuple

@@ -116,6 +117,7 @@ class Gamestate(object):
        self.rng = np.random.default_rng(env_seed)
        self.rules = StepRules(*rules)
        self._floortile_graph = None
        self.route_cache = []
        self.tests = StepTests(*tests)

        # Pointer that defines current spawn points of agents
@@ -320,6 +322,42 @@ class Gamestate(object):
        # json_file.seek(0)
        # json.dump(existing_content, json_file, indent=4)

    def cache_route(self, route):
        """
        Save a route in the env-level cache so that all agents can access it.

        :param route: The route to be saved
        """
        self.route_cache.append(copy.deepcopy(route))
        # print(f"Cached route: {route}")

    def get_cached_route(self, current_pos, target_positions, route_cutting=False):
        """
        Look up a cached route that includes the current position and at least one target.

        :param current_pos: The agent's current position, i.e. the first position of a possibly cached route
        :param target_positions: The positions of targets the agent wants to visit
        :param route_cutting: If True, the returned route is cut to end at the first target.
                              If False, the full remaining route is returned, which allows target agents to loop.

        :returns: A cached route from the agent's position to the first target, if one exists
        """
        if not self.route_cache:
            return None

        for route in self.route_cache:
            if current_pos in route:
                targets = [target for target in target_positions if target in route]
                if targets:
                    first_target = targets[0]
                    index_start = route.index(current_pos)

                    if route_cutting:
                        index_end = route.index(first_target) + 1
                        return copy.deepcopy(route[index_start:index_end])
                    else:
                        return copy.deepcopy(route[index_start:])
        return None


class StepTests:
    def __init__(self, *args):
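A worked example of the cache semantics implemented above, with a hand-written route (pure Python; it mirrors the slicing in get_cached_route):

    route = [(1, 1), (1, 2), (2, 2), (1, 3), (2, 3)]     # a previously cached TSP route
    current_pos, targets = (2, 2), [(1, 3)]

    start = route.index(current_pos)
    end = route.index(targets[0]) + 1
    print(route[start:])      # route_cutting=False -> [(2, 2), (1, 3), (2, 3)], keeps going past the target
    print(route[start:end])   # route_cutting=True  -> [(2, 2), (1, 3)], cut directly after the target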
@@ -33,6 +33,8 @@ class RenderEntity:
    id: int = 0
    aux: Any = None
    real_name: str = 'none'
    probability: float = None  # Default to None if not used
    rotation: int = 0  # Default rotation if not specified


@dataclass
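The two new RenderEntity fields are optional, so existing call sites are unaffected; the plotting code sets them explicitly, for example:

    import numpy as np
    from marl_factory_grid.utils.utility_classes import RenderEntity

    # An arrow icon at grid position (4, 2) with a 0.72 probability label, rotated by 90 degrees.
    arrow = RenderEntity(name='green_arrow', pos=np.array((4, 2)), probability=0.72, rotation=90)
    print(arrow.probability, arrow.rotation)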
@@ -18,7 +18,6 @@ def single_agent_training(config_name):
    # Have consecutive episode for eval in single agent case
    train_cfg["algorithm"]["pile_all_done"] = "all"
    agent.eval_loop(10)
    print(agent.action_probabilities)


def single_agent_eval(config_name, run):
test_run.py (13 changes)
@@ -1,4 +1,5 @@
from pathlib import Path
from pprint import pprint

from tqdm import trange

@@ -7,12 +8,17 @@ from marl_factory_grid.algorithms.static.TSP_item_agent import TSPItemAgent
from marl_factory_grid.algorithms.static.TSP_target_agent import TSPTargetAgent
from marl_factory_grid.environment.factory import Factory

from marl_factory_grid.utils.plotting.plot_single_runs import plot_routes, plot_action_maps

if __name__ == '__main__':
    # Render at each step?

    run_path = Path('study_out')
    render = True
    monitor = True
    record = True

    # Path to config File
    path = Path('marl_factory_grid/configs/simple_crossing.yaml')
    path = Path('marl_factory_grid/configs/test_config.yaml')

    # Env Init
    factory = Factory(path)
@@ -25,6 +31,7 @@ if __name__ == '__main__':
        action_spaces = factory.action_space
        # agents = [TSPDirtAgent(factory, 0), TSPItemAgent(factory, 1), TSPTargetAgent(factory, 2)]
        agents = [TSPTargetAgent(factory, 0), TSPTargetAgent(factory, 1)]
        # agents = [TSPTargetAgent(factory, 0)]
        while not done:
            a = [x.predict() for x in agents]
            obs_type, _, _, done, info = factory.step(a)
@@ -33,3 +40,5 @@ if __name__ == '__main__':
            if done:
                print(f'Episode {episode} done...')
                break

    plot_routes(factory, agents)