added rendering of start and target positions; changed file save location to match the current run in study_out

Chanumask
2024-05-16 13:16:53 +02:00
parent cb990445ce
commit 1a8ca9110b
9 changed files with 40 additions and 23 deletions

View File

@@ -572,7 +572,7 @@ class A2C:
if self.cfg[nms.ENV]["save_and_log"]:
self.create_info_maps(env, used_actions)
self.save_agent_models()
plot_action_maps(env, [self])
plot_action_maps(env, [self], self.results_path)
@torch.inference_mode(True)
def eval_loop(self, n_episodes, render=False):
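
Passing the run's results path through to the plotting call is what routes the rendered images into the right study_out subfolder. A hedged sketch of the flow, with `results_path` standing in for whatever `A2C.results_path` holds (its exact form is set elsewhere in the class and may differ):

    # hypothetical value, for illustration only
    results_path = "a2c_run_2024-05-16"           # the current run's folder name
    plot_action_maps(env, [agent], results_path)
    # -> Renderer.save_screen("multi_action_graph", results_path)
    # -> <repo>/study_out/a2c_run_2024-05-16/multi_action_graph_agent_0.png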

View File

@@ -186,6 +186,11 @@ class SpawnAgents(Rule):
if isinstance(rule, marl_factory_grid.environment.rules.AgentSpawnRule):
spawn_rule = rule.spawn_rule
if not hasattr(state, 'agent_spawn_positions'):
state.agent_spawn_positions = []
else:
state.agent_spawn_positions.clear()
agents = state[c.AGENT]
for agent_name, agent_conf in state.agents_conf.items():
empty_positions = state.entities.empty_positions
@@ -198,11 +203,14 @@ class SpawnAgents(Rule):
if position := self._get_position(spawn_rule, positions, empty_positions, positions_pointer):
assert state.check_pos_validity(position), 'something went wrong...'
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
state.agent_spawn_positions.append(position)
elif positions:
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
f'\n{agent_conf["positions"].copy()}')
else:
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
chosen_position = empty_positions.pop()
agents.add_item(Agent(actions, observations, chosen_position, str_ident=agent_name, **other))
state.agent_spawn_positions.append(chosen_position)
return []
def _get_position(self, spawn_rule, positions, empty_positions, positions_pointer):
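
The spawn bookkeeping above follows a reset-in-place pattern: the list is created once on the state object and cleared (not reassigned) on every later spawn pass, so any component already holding a reference, such as the renderer, keeps seeing the current episode's positions. A self-contained sketch of the pattern (names mirror the diff; `State` here is a stand-in):

    class State:
        pass

    def reset_spawn_log(state):
        # create on first use, clear in place afterwards so existing
        # references to the list stay valid across episodes
        if not hasattr(state, 'agent_spawn_positions'):
            state.agent_spawn_positions = []
        else:
            state.agent_spawn_positions.clear()

    state = State()
    reset_spawn_log(state)
    state.agent_spawn_positions.append((2, 3))   # record agent 0's spawn
    spawn_log = state.agent_spawn_positions      # e.g. a reference held by the renderer
    reset_spawn_log(state)
    assert spawn_log == []                       # same list object, now empty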

View File

@@ -92,7 +92,6 @@ class LevelParser(object):
for symbol in symbols:
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
if np.any(level_array):
# TODO: Get rid of this!
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
self.size, entity_kwargs=e_kwargs)
else:
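
For context, the surrounding loop builds entities from a one-hot mask of the parsed level. A hedged sketch of what `h.one_hot_level` presumably computes, assuming `c.VALUE_OCCUPIED_CELL == 1` (the real helpers may differ in detail):

    import numpy as np

    parsed_level = np.array([['#', '#', '#'],
                             ['#', '-', '#'],
                             ['#', '#', '#']])
    symbol = '#'
    level_array = (parsed_level == symbol).astype(int)        # 1 where the symbol occurs
    if np.any(level_array):
        coordinates = np.argwhere(level_array == 1).tolist()  # [[0, 0], [0, 1], ...]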

Binary file not shown (added; 672 B).

Binary file not shown (added; 291 B).

View File

@@ -14,6 +14,8 @@ from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
from marl_factory_grid.utils.renderer import Renderer
from marl_factory_grid.utils.utility_classes import RenderEntity
from marl_factory_grid.modules.clean_up import constants as d
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
file_key: str = 'monitor', file_ext: str = 'pkl'):
@@ -97,7 +99,7 @@ def plot_routes(factory, agents):
top_actions = sorted(agent.action_probabilities.items(), key=lambda x: -x[1])[:4]
else:
# Handle deterministic agents by iterating through all actions in the list
top_actions = [(action, 1.0) for action in agent.action_list]
top_actions = [(action, 0) for action in agent.action_list]
for action, probability in top_actions:
if action.lower() in rotation_mapping:
@@ -121,7 +123,7 @@ def plot_routes(factory, agents):
renderer.render_single_action_icons(action_entities) # move in/out loop for graph per agent or not
def plot_action_maps(factory, agents):
def plot_action_maps(factory, agents, result_path):
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
assets_path = {
'green_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'green_arrow.png'),
@@ -129,6 +131,8 @@ def plot_action_maps(factory, agents):
'red_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'red_arrow.png'),
'grey_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'grey_arrow.png'),
'wall': os.path.join(base_dir, 'environment', 'assets', 'wall.png'),
'target_dirt': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'target_dirt.png'),
'spawn_pos': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'spawn_pos.png')
}
renderer = Renderer(factory.map.level_shape, cell_size=80, custom_assets_path=assets_path)
@@ -139,8 +143,15 @@ def plot_action_maps(factory, agents):
if hasattr(agent, 'action_probabilities'):
action_probabilities = unpack_action_probabilities(agent.action_probabilities)
for action_map_index, probabilities_map in enumerate(action_probabilities[agent_index]):
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
action_entities = list(wall_entities)
target_dirt_pos = factory.state.entities[d.DIRT][action_map_index].pos
action_entities.append(
RenderEntity(name='target_dirt', probability=0, pos=swap_coordinates(target_dirt_pos)))
action_entities.append(RenderEntity(name='spawn_pos', probability=0, pos=swap_coordinates(
factory.state.agent_spawn_positions[agent_index])))
for position, probabilities in probabilities_map.items():
if position not in wall_positions:
if np.any(probabilities > 0):  # ensure it's not all zeros, which would indicate a wall
@@ -160,7 +171,7 @@ def plot_action_maps(factory, agents):
)
action_entities.append(action_entity)
renderer.render_multi_action_icons(action_entities)
renderer.render_multi_action_icons(action_entities, result_path)
def unpack_action_probabilities(action_probabilities):
@@ -178,12 +189,6 @@ def unpack_action_probabilities(action_probabilities):
return unpacked
def load_action_map(file_path):
with open(file_path, 'r') as file:
action_map = json.load(file)
return action_map
def swap_coordinates(positions):
"""
Swaps x and y coordinates of single positions, lists or arrays
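
A minimal sketch of `swap_coordinates` consistent with its docstring, handling a single (x, y) pair as well as lists/arrays of pairs (the actual implementation may differ):

    import numpy as np

    def swap_coordinates(positions):
        arr = np.asarray(positions)
        if arr.ndim == 1:
            return arr[::-1]       # single position: (x, y) -> (y, x)
        return arr[:, ::-1]        # array of positions, shape (n, 2): swap columns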

View File

@@ -1,3 +1,4 @@
import os
import sys
from pathlib import Path
@@ -272,10 +273,10 @@ class Renderer:
pygame.display.flip() # Update the display with all new blits
self.save_screen("route_graph")
def render_multi_action_icons(self, action_entities):
def render_multi_action_icons(self, action_entities, result_path):
"""
Renders multiple action icons at the same position without overlap and arranges them based on direction, except for
walls which cover the entire grid cell.
Renders multiple action icons at the same position without overlap and arranges them based on direction, except
for walls, spawn and target positions, which cover the entire grid cell.
"""
self.fill_bg()
font = pygame.font.Font(None, 20)
@@ -286,7 +287,7 @@ class Renderer:
position_dict[tuple(entity.pos)].append(entity)
for position, entities in position_dict.items():
entity_size = self.cell_size // 2 # Adjust size to fit multiple entities for non-wall entities
entity_size = self.cell_size // 2
entities.sort(key=lambda x: x.rotation)
for entity in entities:
@@ -296,7 +297,7 @@ class Renderer:
continue
# Full-cell entities (wall, target and spawn markers) are scaled to cover the whole cell
if entity.name == 'wall':
if entity.name in ['wall', 'target_dirt', 'spawn_pos']:
img = pygame.transform.scale(img, (self.cell_size, self.cell_size))
img_rect = img.get_rect(center=(position[0] * self.cell_size + self.cell_size // 2,
position[1] * self.cell_size + self.cell_size // 2))
@@ -326,17 +327,22 @@ class Renderer:
self.screen.blit(prob_text, prob_text_rect)
pygame.display.flip()
self.save_screen("multi_action_graph")
self.save_screen("multi_action_graph", result_path)
def save_screen(self, filename):
def save_screen(self, filename, result_path):
"""
Saves the current screen to a PNG file, appending a counter to ensure uniqueness.
:param filename: The base filename where to save the image.
:param result_path: Path to the output folder for the current run (under study_out/).
"""
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
out_dir = os.path.join(base_dir, 'study_out', result_path)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
unique_filename = f"{filename}_agent_{self.save_counter}.png"
self.save_counter += 1
pygame.image.save(self.screen, unique_filename)
full_path = os.path.join(out_dir, unique_filename)
pygame.image.save(self.screen, full_path)
print(f"Image saved as {unique_filename}")

View File

@@ -18,7 +18,6 @@ def single_agent_training(config_name):
# Have consecutive episodes for eval in the single-agent case
train_cfg["algorithm"]["pile_all_done"] = "all"
agent.eval_loop(10)
print(agent.action_probabilities)
def single_agent_eval(config_name, run):

View File

@@ -41,4 +41,4 @@ if __name__ == '__main__':
print(f'Episode {episode} done...')
break
plot_action_maps(factory, agents)
plot_routes(factory, agents)
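
Note that `plot_action_maps` now takes a third `result_path` argument (see the plotting diff above), so a call site would look like the following, with `run_id` as a hypothetical name for the current run's folder under study_out/:

    run_id = "example_run"                      # hypothetical run folder name
    plot_action_maps(factory, agents, run_id)   # images saved to study_out/example_run/
    plot_routes(factory, agents)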