mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-05-23 07:16:44 +02:00
Render all spawnpoints that are matched with a target dirt pile + Fixed arrow placement
This commit is contained in:
parent
defbaf6f93
commit
98113ea849
@ -1,3 +1,5 @@
|
|||||||
|
import ast
|
||||||
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from os import PathLike
|
from os import PathLike
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -147,25 +149,44 @@ def plot_action_maps(factory, agents, result_path):
|
|||||||
target_dirt_pos = factory.state.entities[d.DIRT][action_map_index].pos
|
target_dirt_pos = factory.state.entities[d.DIRT][action_map_index].pos
|
||||||
action_entities.append(
|
action_entities.append(
|
||||||
RenderEntity(name='target_dirt', probability=0, pos=swap_coordinates(target_dirt_pos)))
|
RenderEntity(name='target_dirt', probability=0, pos=swap_coordinates(target_dirt_pos)))
|
||||||
action_entities.append(RenderEntity(name='spawn_pos', probability=0, pos=swap_coordinates(
|
|
||||||
factory.state.agent_spawn_positions[agent_index])))
|
|
||||||
|
|
||||||
|
# Render all spawnpoints assigned to current target dirt pile
|
||||||
|
spawnpoints = list(factory.state.agents_conf.values())[agent_index]['positions']
|
||||||
|
all_target_dirts = []
|
||||||
|
if 'DirtPiles' in factory.conf['Entities']:
|
||||||
|
tuples = ast.literal_eval(factory.conf['Entities']['DirtPiles']['coords_or_quantity'])
|
||||||
|
for t in tuples:
|
||||||
|
all_target_dirts.append(t)
|
||||||
|
assigned_spawn_positions = []
|
||||||
|
for j in range(len(spawnpoints) // len(all_target_dirts)):
|
||||||
|
assigned_spawn_positions.append(spawnpoints[j * len(all_target_dirts) + all_target_dirts.index(target_dirt_pos)])
|
||||||
|
for spawn_pos in assigned_spawn_positions:
|
||||||
|
action_entities.append(RenderEntity(name='spawn_pos', probability=0, pos=swap_coordinates(spawn_pos)))
|
||||||
|
|
||||||
|
render_arrows = []
|
||||||
for position, probabilities in probabilities_map.items():
|
for position, probabilities in probabilities_map.items():
|
||||||
if position not in wall_positions:
|
if position not in wall_positions:
|
||||||
if np.any(probabilities) > 0: # Ensure it's not all zeros which would indicate a wall
|
if np.any(probabilities) > 0: # Ensure it's not all zeros which would indicate a wall
|
||||||
sorted_indices = sorted(range(len(probabilities)), key=lambda i: -probabilities[i])
|
sorted_indices = np.argsort(np.argsort(-probabilities))
|
||||||
colors = ['green_arrow', 'yellow_arrow', 'red_arrow', 'grey_arrow']
|
colors = ['green_arrow', 'yellow_arrow', 'red_arrow', 'grey_arrow']
|
||||||
|
render_arrows.append([])
|
||||||
for rank, direction_index in enumerate(sorted_indices):
|
for rank, direction_index in enumerate(sorted_indices):
|
||||||
action = directions[direction_index]
|
action = directions[direction_index]
|
||||||
probability = probabilities[direction_index]
|
probability = probabilities[rank]
|
||||||
arrow_color = colors[rank]
|
arrow_color = colors[direction_index]
|
||||||
|
render_arrows[-1].append((probability, arrow_color, position))
|
||||||
|
|
||||||
|
# Swap west and east
|
||||||
|
for l in render_arrows:
|
||||||
|
l[1], l[3] = l[3], l[1]
|
||||||
|
for l in render_arrows:
|
||||||
|
for rank, (probability, arrow_color, position) in enumerate(l):
|
||||||
if probability > 0:
|
if probability > 0:
|
||||||
action_entity = RenderEntity(
|
action_entity = RenderEntity(
|
||||||
name=arrow_color,
|
name=arrow_color,
|
||||||
pos=position,
|
pos=position,
|
||||||
probability=probability,
|
probability=probability,
|
||||||
rotation=direction_index * 90
|
rotation=rank * 90
|
||||||
)
|
)
|
||||||
action_entities.append(action_entity)
|
action_entities.append(action_entity)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user