mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2026-01-15 23:41:39 +01:00
added rendering of start and target pos. changed file save location to match current run in study out
This commit is contained in:
@@ -572,7 +572,7 @@ class A2C:
|
|||||||
if self.cfg[nms.ENV]["save_and_log"]:
|
if self.cfg[nms.ENV]["save_and_log"]:
|
||||||
self.create_info_maps(env, used_actions)
|
self.create_info_maps(env, used_actions)
|
||||||
self.save_agent_models()
|
self.save_agent_models()
|
||||||
plot_action_maps(env, [self])
|
plot_action_maps(env, [self], self.results_path)
|
||||||
|
|
||||||
@torch.inference_mode(True)
|
@torch.inference_mode(True)
|
||||||
def eval_loop(self, n_episodes, render=False):
|
def eval_loop(self, n_episodes, render=False):
|
||||||
|
|||||||
@@ -186,6 +186,11 @@ class SpawnAgents(Rule):
|
|||||||
if isinstance(rule, marl_factory_grid.environment.rules.AgentSpawnRule):
|
if isinstance(rule, marl_factory_grid.environment.rules.AgentSpawnRule):
|
||||||
spawn_rule = rule.spawn_rule
|
spawn_rule = rule.spawn_rule
|
||||||
|
|
||||||
|
if not hasattr(state, 'agent_spawn_positions'):
|
||||||
|
state.agent_spawn_positions = []
|
||||||
|
else:
|
||||||
|
state.agent_spawn_positions.clear()
|
||||||
|
|
||||||
agents = state[c.AGENT]
|
agents = state[c.AGENT]
|
||||||
for agent_name, agent_conf in state.agents_conf.items():
|
for agent_name, agent_conf in state.agents_conf.items():
|
||||||
empty_positions = state.entities.empty_positions
|
empty_positions = state.entities.empty_positions
|
||||||
@@ -198,11 +203,14 @@ class SpawnAgents(Rule):
|
|||||||
if position := self._get_position(spawn_rule, positions, empty_positions, positions_pointer):
|
if position := self._get_position(spawn_rule, positions, empty_positions, positions_pointer):
|
||||||
assert state.check_pos_validity(position), 'smth went wrong....'
|
assert state.check_pos_validity(position), 'smth went wrong....'
|
||||||
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
|
agents.add_item(Agent(actions, observations, position, str_ident=agent_name, **other))
|
||||||
|
state.agent_spawn_positions.append(position)
|
||||||
elif positions:
|
elif positions:
|
||||||
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
raise ValueError(f'It was not possible to spawn an Agent on the available position: '
|
||||||
f'\n{agent_conf["positions"].copy()}')
|
f'\n{agent_conf["positions"].copy()}')
|
||||||
else:
|
else:
|
||||||
agents.add_item(Agent(actions, observations, empty_positions.pop(), str_ident=agent_name, **other))
|
chosen_position = empty_positions.pop()
|
||||||
|
agents.add_item(Agent(actions, observations, chosen_position, str_ident=agent_name, **other))
|
||||||
|
state.agent_spawn_positions.append(chosen_position)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _get_position(self, spawn_rule, positions, empty_positions, positions_pointer):
|
def _get_position(self, spawn_rule, positions, empty_positions, positions_pointer):
|
||||||
|
|||||||
@@ -92,7 +92,6 @@ class LevelParser(object):
|
|||||||
for symbol in symbols:
|
for symbol in symbols:
|
||||||
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
|
level_array = h.one_hot_level(self._parsed_level, symbol=symbol)
|
||||||
if np.any(level_array):
|
if np.any(level_array):
|
||||||
# TODO: Get rid of this!
|
|
||||||
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
e = e_class.from_coordinates(np.argwhere(level_array == c.VALUE_OCCUPIED_CELL).tolist(),
|
||||||
self.size, entity_kwargs=e_kwargs)
|
self.size, entity_kwargs=e_kwargs)
|
||||||
else:
|
else:
|
||||||
|
|||||||
BIN
marl_factory_grid/utils/plotting/action_assets/spawn_pos.png
Normal file
BIN
marl_factory_grid/utils/plotting/action_assets/spawn_pos.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 672 B |
BIN
marl_factory_grid/utils/plotting/action_assets/target_dirt.png
Normal file
BIN
marl_factory_grid/utils/plotting/action_assets/target_dirt.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 291 B |
@@ -14,6 +14,8 @@ from marl_factory_grid.utils.plotting.plotting_utils import prepare_plot
|
|||||||
from marl_factory_grid.utils.renderer import Renderer
|
from marl_factory_grid.utils.renderer import Renderer
|
||||||
from marl_factory_grid.utils.utility_classes import RenderEntity
|
from marl_factory_grid.utils.utility_classes import RenderEntity
|
||||||
|
|
||||||
|
from marl_factory_grid.modules.clean_up import constants as d
|
||||||
|
|
||||||
|
|
||||||
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
|
def plot_single_run(run_path: Union[str, PathLike], use_tex: bool = False, column_keys=None,
|
||||||
file_key: str = 'monitor', file_ext: str = 'pkl'):
|
file_key: str = 'monitor', file_ext: str = 'pkl'):
|
||||||
@@ -97,7 +99,7 @@ def plot_routes(factory, agents):
|
|||||||
top_actions = sorted(agent.action_probabilities.items(), key=lambda x: -x[1])[:4]
|
top_actions = sorted(agent.action_probabilities.items(), key=lambda x: -x[1])[:4]
|
||||||
else:
|
else:
|
||||||
# Handle deterministic agents by iterating through all actions in the list
|
# Handle deterministic agents by iterating through all actions in the list
|
||||||
top_actions = [(action, 1.0) for action in agent.action_list]
|
top_actions = [(action, 0) for action in agent.action_list]
|
||||||
|
|
||||||
for action, probability in top_actions:
|
for action, probability in top_actions:
|
||||||
if action.lower() in rotation_mapping:
|
if action.lower() in rotation_mapping:
|
||||||
@@ -121,7 +123,7 @@ def plot_routes(factory, agents):
|
|||||||
renderer.render_single_action_icons(action_entities) # move in/out loop for graph per agent or not
|
renderer.render_single_action_icons(action_entities) # move in/out loop for graph per agent or not
|
||||||
|
|
||||||
|
|
||||||
def plot_action_maps(factory, agents):
|
def plot_action_maps(factory, agents, result_path):
|
||||||
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
assets_path = {
|
assets_path = {
|
||||||
'green_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'green_arrow.png'),
|
'green_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'green_arrow.png'),
|
||||||
@@ -129,6 +131,8 @@ def plot_action_maps(factory, agents):
|
|||||||
'red_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'red_arrow.png'),
|
'red_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'red_arrow.png'),
|
||||||
'grey_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'grey_arrow.png'),
|
'grey_arrow': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'grey_arrow.png'),
|
||||||
'wall': os.path.join(base_dir, 'environment', 'assets', 'wall.png'),
|
'wall': os.path.join(base_dir, 'environment', 'assets', 'wall.png'),
|
||||||
|
'target_dirt': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'target_dirt.png'),
|
||||||
|
'spawn_pos': os.path.join(base_dir, 'utils', 'plotting', 'action_assets', 'spawn_pos.png')
|
||||||
}
|
}
|
||||||
renderer = Renderer(factory.map.level_shape, cell_size=80, custom_assets_path=assets_path)
|
renderer = Renderer(factory.map.level_shape, cell_size=80, custom_assets_path=assets_path)
|
||||||
|
|
||||||
@@ -139,8 +143,15 @@ def plot_action_maps(factory, agents):
|
|||||||
if hasattr(agent, 'action_probabilities'):
|
if hasattr(agent, 'action_probabilities'):
|
||||||
action_probabilities = unpack_action_probabilities(agent.action_probabilities)
|
action_probabilities = unpack_action_probabilities(agent.action_probabilities)
|
||||||
for action_map_index, probabilities_map in enumerate(action_probabilities[agent_index]):
|
for action_map_index, probabilities_map in enumerate(action_probabilities[agent_index]):
|
||||||
|
|
||||||
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
|
wall_entities = [RenderEntity(name='wall', probability=0, pos=np.array(pos)) for pos in wall_positions]
|
||||||
action_entities = list(wall_entities)
|
action_entities = list(wall_entities)
|
||||||
|
target_dirt_pos = factory.state.entities[d.DIRT][action_map_index].pos
|
||||||
|
action_entities.append(
|
||||||
|
RenderEntity(name='target_dirt', probability=0, pos=swap_coordinates(target_dirt_pos)))
|
||||||
|
action_entities.append(RenderEntity(name='spawn_pos', probability=0, pos=swap_coordinates(
|
||||||
|
factory.state.agent_spawn_positions[agent_index])))
|
||||||
|
|
||||||
for position, probabilities in probabilities_map.items():
|
for position, probabilities in probabilities_map.items():
|
||||||
if position not in wall_positions:
|
if position not in wall_positions:
|
||||||
if np.any(probabilities) > 0: # Ensure it's not all zeros which would indicate a wall
|
if np.any(probabilities) > 0: # Ensure it's not all zeros which would indicate a wall
|
||||||
@@ -160,7 +171,7 @@ def plot_action_maps(factory, agents):
|
|||||||
)
|
)
|
||||||
action_entities.append(action_entity)
|
action_entities.append(action_entity)
|
||||||
|
|
||||||
renderer.render_multi_action_icons(action_entities)
|
renderer.render_multi_action_icons(action_entities, result_path)
|
||||||
|
|
||||||
|
|
||||||
def unpack_action_probabilities(action_probabilities):
|
def unpack_action_probabilities(action_probabilities):
|
||||||
@@ -178,12 +189,6 @@ def unpack_action_probabilities(action_probabilities):
|
|||||||
return unpacked
|
return unpacked
|
||||||
|
|
||||||
|
|
||||||
def load_action_map(file_path):
|
|
||||||
with open(file_path, 'r') as file:
|
|
||||||
action_map = json.load(file)
|
|
||||||
return action_map
|
|
||||||
|
|
||||||
|
|
||||||
def swap_coordinates(positions):
|
def swap_coordinates(positions):
|
||||||
"""
|
"""
|
||||||
Swaps x and y coordinates of single positions, lists or arrays
|
Swaps x and y coordinates of single positions, lists or arrays
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -272,10 +273,10 @@ class Renderer:
|
|||||||
pygame.display.flip() # Update the display with all new blits
|
pygame.display.flip() # Update the display with all new blits
|
||||||
self.save_screen("route_graph")
|
self.save_screen("route_graph")
|
||||||
|
|
||||||
def render_multi_action_icons(self, action_entities):
|
def render_multi_action_icons(self, action_entities, result_path):
|
||||||
"""
|
"""
|
||||||
Renders multiple action icons at the same position without overlap and arranges them based on direction, except for
|
Renders multiple action icons at the same position without overlap and arranges them based on direction, except
|
||||||
walls which cover the entire grid cell.
|
for walls, spawn and target positions, which cover the entire grid cell.
|
||||||
"""
|
"""
|
||||||
self.fill_bg()
|
self.fill_bg()
|
||||||
font = pygame.font.Font(None, 20)
|
font = pygame.font.Font(None, 20)
|
||||||
@@ -286,7 +287,7 @@ class Renderer:
|
|||||||
position_dict[tuple(entity.pos)].append(entity)
|
position_dict[tuple(entity.pos)].append(entity)
|
||||||
|
|
||||||
for position, entities in position_dict.items():
|
for position, entities in position_dict.items():
|
||||||
entity_size = self.cell_size // 2 # Adjust size to fit multiple entities for non-wall entities
|
entity_size = self.cell_size // 2
|
||||||
entities.sort(key=lambda x: x.rotation)
|
entities.sort(key=lambda x: x.rotation)
|
||||||
|
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
@@ -296,7 +297,7 @@ class Renderer:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if the entity is a wall and adjust the size and position accordingly
|
# Check if the entity is a wall and adjust the size and position accordingly
|
||||||
if entity.name == 'wall':
|
if entity.name in ['wall', 'target_dirt', 'spawn_pos']:
|
||||||
img = pygame.transform.scale(img, (self.cell_size, self.cell_size))
|
img = pygame.transform.scale(img, (self.cell_size, self.cell_size))
|
||||||
img_rect = img.get_rect(center=(position[0] * self.cell_size + self.cell_size // 2,
|
img_rect = img.get_rect(center=(position[0] * self.cell_size + self.cell_size // 2,
|
||||||
position[1] * self.cell_size + self.cell_size // 2))
|
position[1] * self.cell_size + self.cell_size // 2))
|
||||||
@@ -326,17 +327,22 @@ class Renderer:
|
|||||||
self.screen.blit(prob_text, prob_text_rect)
|
self.screen.blit(prob_text, prob_text_rect)
|
||||||
|
|
||||||
pygame.display.flip()
|
pygame.display.flip()
|
||||||
self.save_screen("multi_action_graph")
|
self.save_screen("multi_action_graph", result_path)
|
||||||
|
|
||||||
def save_screen(self, filename):
|
def save_screen(self, filename, result_path):
|
||||||
"""
|
"""
|
||||||
Saves the current screen to a PNG file, appending a counter to ensure uniqueness.
|
Saves the current screen to a PNG file, appending a counter to ensure uniqueness.
|
||||||
:param filename: The base filename where to save the image.
|
:param filename: The base filename where to save the image.
|
||||||
:param agent_id: Unique identifier for the agent.
|
:param result_path: path to out folder
|
||||||
"""
|
"""
|
||||||
|
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
out_dir = os.path.join(base_dir, 'study_out', result_path)
|
||||||
|
if not os.path.exists(out_dir):
|
||||||
|
os.makedirs(out_dir)
|
||||||
unique_filename = f"{filename}_agent_{self.save_counter}.png"
|
unique_filename = f"{filename}_agent_{self.save_counter}.png"
|
||||||
self.save_counter += 1
|
self.save_counter += 1
|
||||||
pygame.image.save(self.screen, unique_filename)
|
full_path = os.path.join(out_dir, unique_filename)
|
||||||
|
pygame.image.save(self.screen, full_path)
|
||||||
print(f"Image saved as {unique_filename}")
|
print(f"Image saved as {unique_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ def single_agent_training(config_name):
|
|||||||
# Have consecutive episode for eval in single agent case
|
# Have consecutive episode for eval in single agent case
|
||||||
train_cfg["algorithm"]["pile_all_done"] = "all"
|
train_cfg["algorithm"]["pile_all_done"] = "all"
|
||||||
agent.eval_loop(10)
|
agent.eval_loop(10)
|
||||||
print(agent.action_probabilities)
|
|
||||||
|
|
||||||
|
|
||||||
def single_agent_eval(config_name, run):
|
def single_agent_eval(config_name, run):
|
||||||
|
|||||||
@@ -41,4 +41,4 @@ if __name__ == '__main__':
|
|||||||
print(f'Episode {episode} done...')
|
print(f'Episode {episode} done...')
|
||||||
break
|
break
|
||||||
|
|
||||||
plot_action_maps(factory, agents)
|
plot_routes(factory, agents)
|
||||||
|
|||||||
Reference in New Issue
Block a user