fixed coordinate mismatch in route plotting and added assets for actions
@@ -20,7 +20,7 @@ Agents:
|
|||||||
- Destination
|
- Destination
|
||||||
# Avaiable Spawn Positions as list
|
# Avaiable Spawn Positions as list
|
||||||
Positions:
|
Positions:
|
||||||
- (1,2)
|
- (2,1)
|
||||||
# It is okay to collide with other agents, so that
|
# It is okay to collide with other agents, so that
|
||||||
# they end up on the same position
|
# they end up on the same position
|
||||||
is_blocking_pos: false
|
is_blocking_pos: false
|
||||||
@@ -33,7 +33,7 @@ Agents:
|
|||||||
- Other
|
- Other
|
||||||
- Destination
|
- Destination
|
||||||
Positions:
|
Positions:
|
||||||
- (2,1)
|
- (1,2)
|
||||||
is_blocking_pos: false
|
is_blocking_pos: false
|
||||||
|
|
||||||
# Other noteworthy Entitites
|
# Other noteworthy Entitites
|
||||||
@@ -45,9 +45,9 @@ Entities:
|
|||||||
SpawnDestinationsPerAgent:
|
SpawnDestinationsPerAgent:
|
||||||
coords_or_quantity:
|
coords_or_quantity:
|
||||||
Agent_horizontal:
|
Agent_horizontal:
|
||||||
- (3,2)
|
|
||||||
Agent_vertical:
|
|
||||||
- (2,3)
|
- (2,3)
|
||||||
|
Agent_vertical:
|
||||||
|
- (3,2)
|
||||||
# Whether you want to provide a numeric Position observation.
|
# Whether you want to provide a numeric Position observation.
|
||||||
# GlobalPositions:
|
# GlobalPositions:
|
||||||
# normalized: false
|
# normalized: false
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ Rules:
|
|||||||
|
|
||||||
# Done Conditions
|
# Done Conditions
|
||||||
DoneAtMaxStepsReached:
|
DoneAtMaxStepsReached:
|
||||||
max_steps: 500
|
max_steps: 20
|
||||||
|
|
||||||
Tests:
|
Tests:
|
||||||
# MaintainerTest: {}
|
# MaintainerTest: {}
|
||||||
|
|||||||
BIN
marl_factory_grid/utils/plotting/action_assets/cardinal.png
Normal file
|
After Width: | Height: | Size: 1.9 KiB |
BIN
marl_factory_grid/utils/plotting/action_assets/charge_action.png
Normal file
|
After Width: | Height: | Size: 1.4 KiB |
BIN
marl_factory_grid/utils/plotting/action_assets/clean_action.png
Normal file
|
After Width: | Height: | Size: 6.6 KiB |
BIN
marl_factory_grid/utils/plotting/action_assets/default.png
Normal file
|
After Width: | Height: | Size: 1.9 KiB |
|
After Width: | Height: | Size: 4.9 KiB |
BIN
marl_factory_grid/utils/plotting/action_assets/diagonal.png
Normal file
|
After Width: | Height: | Size: 5.7 KiB |
BIN
marl_factory_grid/utils/plotting/action_assets/door_action.png
Normal file
|
After Width: | Height: | Size: 990 B |
|
After Width: | Height: | Size: 3.2 KiB |
BIN
marl_factory_grid/utils/plotting/action_assets/noop.png
Normal file
|
After Width: | Height: | Size: 1.4 KiB |
@@ -71,33 +71,50 @@ rotation_mapping = {
|
|||||||
'east': ('cardinal', 270),
|
'east': ('cardinal', 270),
|
||||||
'south': ('cardinal', 180),
|
'south': ('cardinal', 180),
|
||||||
'west': ('cardinal', 90),
|
'west': ('cardinal', 90),
|
||||||
'northeast': ('diagonal', 0),
|
'north_east': ('diagonal', 0),
|
||||||
'southeast': ('diagonal', 270),
|
'south_east': ('diagonal', 270),
|
||||||
'southwest': ('diagonal', 180),
|
'south_west': ('diagonal', 180),
|
||||||
'northwest': ('diagonal', 90)
|
'north_west': ('diagonal', 90)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def swap_coordinates(positions):
|
||||||
|
if isinstance(positions, tuple) or (isinstance(positions, list) and len(positions) == 2):
|
||||||
|
# Single position, directly return swapped
|
||||||
|
return positions[1], positions[0]
|
||||||
|
elif isinstance(positions, np.ndarray) and positions.ndim == 1 and positions.shape[0] == 2:
|
||||||
|
# Single position in NumPy array
|
||||||
|
return positions[1], positions[0]
|
||||||
|
else:
|
||||||
|
# Assume it's an iterable of positions
|
||||||
|
return [(y, x) for x, y in positions]
|
||||||
|
|
||||||
|
|
||||||
def plot_routes(factory, agents):
|
def plot_routes(factory, agents):
|
||||||
renderer = Renderer(factory.map.level_shape, custom_assets_path={
|
renderer = Renderer(factory.map.level_shape, custom_assets_path={
|
||||||
'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png',
|
'cardinal': 'marl_factory_grid/utils/plotting/action_assets/cardinal.png',
|
||||||
'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png',
|
'diagonal': 'marl_factory_grid/utils/plotting/action_assets/diagonal.png',
|
||||||
'door': 'marl_factory_grid/utils/plotting/action_assets/door.png',
|
'use_door': 'marl_factory_grid/utils/plotting/action_assets/door_action.png',
|
||||||
'wall': 'marl_factory_grid/environment/assets/wall.png'})
|
'wall': 'marl_factory_grid/environment/assets/wall.png',
|
||||||
|
'machine_action': 'marl_factory_grid/utils/plotting/action_assets/machine_action.png',
|
||||||
|
'clean_action': 'marl_factory_grid/utils/plotting/action_assets/clean_action.png',
|
||||||
|
'destination_action': 'marl_factory_grid/utils/plotting/action_assets/destination_action.png',
|
||||||
|
'noop': 'marl_factory_grid/utils/plotting/action_assets/noop.png',
|
||||||
|
'charge_action': 'marl_factory_grid/utils/plotting/action_assets/charge_action.png'})
|
||||||
|
|
||||||
wall_positions = factory.map.walls
|
wall_positions = factory.map.walls
|
||||||
|
swapped_wall_positions = swap_coordinates(wall_positions)
|
||||||
|
|
||||||
|
action_entities = []
|
||||||
for index, agent in enumerate(agents):
|
for index, agent in enumerate(agents):
|
||||||
action_entities = []
|
|
||||||
# Add walls to the action_entities list
|
# Add walls to the action_entities list
|
||||||
for pos in wall_positions:
|
for pos in swapped_wall_positions:
|
||||||
wall_entity = RenderEntity(
|
wall_entity = RenderEntity(
|
||||||
name='wall',
|
name='wall',
|
||||||
probability=1.0,
|
probability=0,
|
||||||
pos=np.array(pos),
|
pos=np.array(pos),
|
||||||
)
|
)
|
||||||
action_entities.append(wall_entity)
|
action_entities.append(wall_entity)
|
||||||
current_position = agent.spawn_position
|
|
||||||
|
|
||||||
if hasattr(agent, 'action_probabilities'):
|
if hasattr(agent, 'action_probabilities'):
|
||||||
# Handle RL agents with action probabilities
|
# Handle RL agents with action probabilities
|
||||||
@@ -106,11 +123,18 @@ def plot_routes(factory, agents):
|
|||||||
# Handle deterministic agents by iterating through all actions in the list
|
# Handle deterministic agents by iterating through all actions in the list
|
||||||
top_actions = [(action, 1.0) for action in agent.action_list]
|
top_actions = [(action, 1.0) for action in agent.action_list]
|
||||||
|
|
||||||
|
current_position = agent.spawn_position
|
||||||
|
current_position = swap_coordinates(current_position)
|
||||||
|
|
||||||
for action, probability in top_actions:
|
for action, probability in top_actions:
|
||||||
base_icon, rotation = rotation_mapping.get(action.lower(), ('north', 0))
|
if action.lower() in rotation_mapping:
|
||||||
icon_name = base_icon
|
base_icon, rotation = rotation_mapping[action.lower()]
|
||||||
new_position = action_to_coords(current_position, action.lower())
|
icon_name = 'cardinal' if 'diagonal' not in base_icon else 'diagonal'
|
||||||
print(f"current position type and value: {type(current_position)}, {new_position}")
|
new_position = action_to_coords(current_position, action.lower())
|
||||||
|
else:
|
||||||
|
icon_name = action.lower()
|
||||||
|
rotation = 0
|
||||||
|
new_position = current_position
|
||||||
|
|
||||||
action_entity = RenderEntity(
|
action_entity = RenderEntity(
|
||||||
name=icon_name,
|
name=icon_name,
|
||||||
@@ -118,10 +142,11 @@ def plot_routes(factory, agents):
|
|||||||
probability=probability,
|
probability=probability,
|
||||||
rotation=rotation
|
rotation=rotation
|
||||||
)
|
)
|
||||||
|
# print(f"curr: {current_position}, new: {new_position}, pos: {action_entity.pos}")
|
||||||
action_entities.append(action_entity)
|
action_entities.append(action_entity)
|
||||||
current_position = new_position
|
current_position = new_position
|
||||||
|
|
||||||
renderer.render_action_icons(action_entities)
|
renderer.render_action_icons(action_entities) # move in/out loop for graph per agent or not
|
||||||
|
|
||||||
|
|
||||||
def action_to_coords(current_position, action):
|
def action_to_coords(current_position, action):
|
||||||
@@ -130,15 +155,14 @@ def action_to_coords(current_position, action):
|
|||||||
'south': (0, 1),
|
'south': (0, 1),
|
||||||
'east': (1, 0),
|
'east': (1, 0),
|
||||||
'west': (-1, 0),
|
'west': (-1, 0),
|
||||||
'northeast': (1, -1),
|
'north_east': (1, -1),
|
||||||
'northwest': (-1, -1),
|
'north_west': (-1, -1),
|
||||||
'southeast': (1, 1),
|
'south_east': (1, 1),
|
||||||
'southwest': (-1, 1)
|
'south_west': (-1, 1)
|
||||||
}
|
}
|
||||||
delta = direction_mapping.get(action)
|
delta = direction_mapping.get(action)
|
||||||
if delta is not None:
|
if delta is not None:
|
||||||
new_position = [current_position[0] + delta[0], current_position[1] + delta[1]]
|
new_position = [current_position[0] + delta[0], current_position[1] + delta[1]]
|
||||||
return new_position
|
return new_position
|
||||||
print(f"No valid action found for {action}.")
|
print(f"No valid movement action found for {action}.")
|
||||||
return current_position
|
return current_position
|
||||||
|
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ SCALE: str = 'scale'
|
|||||||
|
|
||||||
|
|
||||||
class Renderer:
|
class Renderer:
|
||||||
BG_COLOR = (178, 190, 195) # (99, 110, 114)
|
BG_COLOR = (178, 190, 195) # (99, 110, 114)
|
||||||
WHITE = (223, 230, 233) # (200, 200, 200)
|
WHITE = (223, 230, 233) # (200, 200, 200)
|
||||||
AGENT_VIEW_COLOR = (9, 132, 227)
|
AGENT_VIEW_COLOR = (9, 132, 227)
|
||||||
ASSETS = Path(__file__).parent.parent
|
ASSETS = Path(__file__).parent.parent
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ class Renderer:
|
|||||||
self.grid_lines = grid_lines
|
self.grid_lines = grid_lines
|
||||||
self.view_radius = view_radius
|
self.view_radius = view_radius
|
||||||
pygame.init()
|
pygame.init()
|
||||||
self.screen_size = (self.grid_w*cell_size, self.grid_h*cell_size)
|
self.screen_size = (self.grid_w * cell_size, self.grid_h * cell_size)
|
||||||
self.screen = pygame.display.set_mode(self.screen_size)
|
self.screen = pygame.display.set_mode(self.screen_size)
|
||||||
self.clock = pygame.time.Clock()
|
self.clock = pygame.time.Clock()
|
||||||
self.custom_assets_path = custom_assets_path
|
self.custom_assets_path = custom_assets_path
|
||||||
@@ -99,18 +99,18 @@ class Renderer:
|
|||||||
(self.lvl_padded_shape[1] - self.grid_w) // 2
|
(self.lvl_padded_shape[1] - self.grid_w) // 2
|
||||||
|
|
||||||
r, c = entity.pos
|
r, c = entity.pos
|
||||||
r, c = r - offset_r, c-offset_c
|
r, c = r - offset_r, c - offset_c
|
||||||
|
|
||||||
img = self.assets[entity.name.lower()]
|
img = self.assets[entity.name.lower()]
|
||||||
if entity.value_operation == OPACITY:
|
if entity.value_operation == OPACITY:
|
||||||
img.set_alpha(255*entity.value)
|
img.set_alpha(255 * entity.value)
|
||||||
elif entity.value_operation == SCALE:
|
elif entity.value_operation == SCALE:
|
||||||
re = img.get_rect()
|
re = img.get_rect()
|
||||||
img = pygame.transform.smoothscale(
|
img = pygame.transform.smoothscale(
|
||||||
img, (int(entity.value*re.width), int(entity.value*re.height))
|
img, (int(entity.value * re.width), int(entity.value * re.height))
|
||||||
)
|
)
|
||||||
o = self.cell_size//2
|
o = self.cell_size // 2
|
||||||
r_, c_ = r*self.cell_size + o, c*self.cell_size + o
|
r_, c_ = r * self.cell_size + o, c * self.cell_size + o
|
||||||
rect = img.get_rect()
|
rect = img.get_rect()
|
||||||
rect.centerx, rect.centery = c_, r_
|
rect.centerx, rect.centery = c_, r_
|
||||||
return dict(source=img, dest=rect)
|
return dict(source=img, dest=rect)
|
||||||
@@ -182,13 +182,13 @@ class Renderer:
|
|||||||
:rtype: List[dict]
|
:rtype: List[dict]
|
||||||
"""
|
"""
|
||||||
rects = []
|
rects = []
|
||||||
for i, j in product(range(-self.view_radius, self.view_radius+1),
|
for i, j in product(range(-self.view_radius, self.view_radius + 1),
|
||||||
range(-self.view_radius, self.view_radius+1)):
|
range(-self.view_radius, self.view_radius + 1)):
|
||||||
if view is not None:
|
if view is not None:
|
||||||
if bool(view[self.view_radius+j, self.view_radius+i]):
|
if bool(view[self.view_radius + j, self.view_radius + i]):
|
||||||
visibility_rect = bp['dest'].copy()
|
visibility_rect = bp['dest'].copy()
|
||||||
visibility_rect.centerx += i*self.cell_size
|
visibility_rect.centerx += i * self.cell_size
|
||||||
visibility_rect.centery += j*self.cell_size
|
visibility_rect.centery += j * self.cell_size
|
||||||
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
|
shape_surf = pygame.Surface(visibility_rect.size, pygame.SRCALPHA)
|
||||||
pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect())
|
pygame.draw.rect(shape_surf, self.AGENT_VIEW_COLOR, shape_surf.get_rect())
|
||||||
shape_surf.set_alpha(64)
|
shape_surf.set_alpha(64)
|
||||||
@@ -222,7 +222,7 @@ class Renderer:
|
|||||||
RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
|
RenderEntity(agent.state, (agent.pos[0] + 0.12, agent.pos[1]), 0.48, SCALE)
|
||||||
)
|
)
|
||||||
textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
|
textsurface = self.font.render(str(agent.id), False, (0, 0, 0))
|
||||||
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0]-.07*self.cell_size,
|
text_blit = dict(source=textsurface, dest=(agent_blit['dest'].center[0] - .07 * self.cell_size,
|
||||||
agent_blit['dest'].center[1]))
|
agent_blit['dest'].center[1]))
|
||||||
blits += [agent_blit, state_blit, text_blit]
|
blits += [agent_blit, state_blit, text_blit]
|
||||||
|
|
||||||
@@ -265,7 +265,7 @@ class Renderer:
|
|||||||
self.screen.blit(img, img_rect)
|
self.screen.blit(img, img_rect)
|
||||||
|
|
||||||
# Render the probability next to the icon if it exists
|
# Render the probability next to the icon if it exists
|
||||||
if hasattr(action_entity, 'probability'):
|
if hasattr(action_entity, 'probability') and action_entity.probability != 0:
|
||||||
prob_text = font.render(f"{action_entity.probability:.2f}", True, (255, 0, 0))
|
prob_text = font.render(f"{action_entity.probability:.2f}", True, (255, 0, 0))
|
||||||
prob_text_rect = prob_text.get_rect(top=img_rect.bottom, left=img_rect.left)
|
prob_text_rect = prob_text.get_rect(top=img_rect.bottom, left=img_rect.left)
|
||||||
self.screen.blit(prob_text, prob_text_rect)
|
self.screen.blit(prob_text, prob_text_rect)
|
||||||
|
|||||||