Mirror of https://github.com/illiumst/marl-factory-grid.git, synced 2025-07-12 07:42:41 +02:00
Updated pomdp_r comment + Added some additional comments + Restructured experiment calling + Added Readme and requirements.txt
@@ -10,6 +10,7 @@ from marl_factory_grid.algorithms.rl.constants import Names
 nms = Names
 
 
 def _as_torch(x):
+    """ Helper function to convert different list types to a torch tensor """
     if isinstance(x, np.ndarray):
         return torch.from_numpy(x)
     elif isinstance(x, List):
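
As a quick illustration of the two conversion paths named in this hunk (numpy arrays and plain lists), a minimal sketch with hypothetical inputs:

import numpy as np
import torch

# torch.from_numpy wraps the array's memory without copying;
# a plain Python list is copied into a fresh tensor.
t_from_array = torch.from_numpy(np.array([1.0, 2.0]))
t_from_list = torch.tensor([1.0, 2.0])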
@@ -20,15 +21,16 @@ def _as_torch(x):
 
 
 def transform_observations(env, ordered_dirt_piles, target_pile, cfg, n_agents):
-    """ Requires that agent has observations -DirtPiles and -Self """
-    agent_positions = [env.state.moving_entites[agent_idx].pos for agent_idx in range(n_agents)]
+    """ Function that extracts local observations from global state
+        Requires that agents have observations -DirtPiles and -Self (cf. environment configs) """
+    agents_positions = get_agents_positions(env, n_agents)
     pile_observability_is_all = cfg[nms.ALGORITHM][nms.PILE_OBSERVABILITY] == nms.ALL
     if pile_observability_is_all:
-        trans_obs = [torch.zeros(2+2*len(ordered_dirt_piles[0])) for _ in range(len(agent_positions))]
+        trans_obs = [torch.zeros(2+2*len(ordered_dirt_piles[0])) for _ in range(len(agents_positions))]
     else:
         # Only show current target pile
-        trans_obs = [torch.zeros(4) for _ in range(len(agent_positions))]
-    for i, pos in enumerate(agent_positions):
+        trans_obs = [torch.zeros(4) for _ in range(len(agents_positions))]
+    for i, pos in enumerate(agents_positions):
         agent_x, agent_y = pos[0], pos[1]
         trans_obs[i][0] = agent_x
         trans_obs[i][1] = agent_y
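
For reference, the transformed observation is a flat tensor: the agent's own (x, y) followed by one (x, y) pair per observable dirt pile, so 2 + 2 * n_piles entries under full pile observability and 4 entries (agent plus current target pile) otherwise. A minimal sketch of that layout with hypothetical positions; the pile-coordinate part of the fill loop lies outside this hunk, so its exact order is an assumption here:

import torch

agent_pos = (2, 3)            # hypothetical agent position
piles = [(0, 1), (4, 4)]      # hypothetical dirt pile positions

obs = torch.zeros(2 + 2 * len(piles))   # same sizing as trans_obs above
obs[0], obs[1] = agent_pos              # agent x, agent y
for j, (px, py) in enumerate(piles):    # assumed ordering of pile coordinates
    obs[2 + 2 * j], obs[3 + 2 * j] = px, py
# obs -> tensor([2., 3., 0., 1., 4., 4.])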
@@ -45,6 +47,7 @@ def transform_observations(env, ordered_dirt_piles, target_pile, cfg, n_agents):
 
 
 def get_all_observations(env, cfg, n_agents):
+    """ Helper function that returns all possible agent observations """
     dirt_piles_positions = [env.state.entities[nms.DIRT_PILES][pile_idx].pos for pile_idx in
                             range(len(env.state.entities[nms.DIRT_PILES]))]
     if cfg[nms.ALGORITHM][nms.PILE_OBSERVABILITY] == nms.ALL:
@@ -76,41 +79,48 @@ def get_all_observations(env, cfg, n_agents):
 
 
 def get_dirt_piles_positions(env):
+    """ Get positions of dirt piles on the map """
     return [env.state.entities[nms.DIRT_PILES][pile_idx].pos for pile_idx in range(len(env.state.entities[nms.DIRT_PILES]))]
 
 
+def get_agents_positions(env, n_agents):
+    """ Get positions of agents on the map """
+    return [env.state.moving_entites[agent_idx].pos for agent_idx in range(n_agents)]
+
+
 def get_ordered_dirt_piles(env, cleaned_dirt_piles, cfg, n_agents):
-    """ Each agent can have its individual pile order """
+    """ This function determines in which order the agents should clean the dirt piles
+        Each agent can have its individual pile order """
     ordered_dirt_piles = [[] for _ in range(n_agents)]
-    dirt_pile_positions = get_dirt_piles_positions(env)
-    agent_positions = [env.state.moving_entites[agent_idx].pos for agent_idx in range(n_agents)]
+    dirt_piles_positions = get_dirt_piles_positions(env)
+    agents_positions = get_agents_positions(env, n_agents)
     for agent_idx in range(n_agents):
         if cfg[nms.ALGORITHM][nms.PILE_ORDER] in [nms.FIXED, nms.AGENTS]:
-            ordered_dirt_piles[agent_idx] = dirt_pile_positions
+            ordered_dirt_piles[agent_idx] = dirt_piles_positions
         elif cfg[nms.ALGORITHM][nms.PILE_ORDER] in [nms.SMART, nms.DYNAMIC]:
             # Calculate distances for remaining unvisited dirt piles
             remaining_target_piles = [pos for pos, value in cleaned_dirt_piles[agent_idx].items() if not value]
             pile_distances = {pos:0 for pos in remaining_target_piles}
-            agent_pos = agent_positions[agent_idx]
+            agent_pos = agents_positions[agent_idx]
             for pos in remaining_target_piles:
                 pile_distances[pos] = np.abs(agent_pos[0] - pos[0]) + np.abs(agent_pos[1] - pos[1])
 
             if cfg[nms.ALGORITHM][nms.PILE_ORDER] == nms.SMART:
-                # Check if there is an agent in line with any of the remaining dirt piles
+                # Check if there is an agent on the direct path to any of the remaining dirt piles
                 for pile_pos in remaining_target_piles:
-                    for other_pos in agent_positions:
+                    for other_pos in agents_positions:
                         if other_pos != agent_pos:
                             if agent_pos[0] == other_pos[0] == pile_pos[0] or agent_pos[1] == other_pos[1] == pile_pos[1]:
-                                # Get the line between the agent and the goal
+                                # Get the line between the agent and the target
                                 path = bresenham(agent_pos[0], agent_pos[1], pile_pos[0], pile_pos[1])
 
-                                # Check if the entity lies on the path between the agent and the goal
+                                # Check if the entity lies on the path between the agent and the target
                                 if other_pos in path:
                                     pile_distances[pile_pos] += np.abs(agent_pos[0] - other_pos[0]) + np.abs(agent_pos[1] - other_pos[1])
 
             sorted_pile_distances = dict(sorted(pile_distances.items(), key=lambda item: item[1]))
             # Insert already visited dirt piles
-            ordered_dirt_piles[agent_idx] = [pos for pos in dirt_pile_positions if pos not in remaining_target_piles]
+            ordered_dirt_piles[agent_idx] = [pos for pos in dirt_piles_positions if pos not in remaining_target_piles]
             # Fill up with sorted positions
             for pos in sorted_pile_distances.keys():
                 ordered_dirt_piles[agent_idx].append(pos)
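
The SMART branch above scores each remaining pile by Manhattan distance and adds a detour penalty when another agent stands on the straight line to it. A self-contained sketch of that scoring, with a minimal stand-in for the module's bresenham() (the real implementation is not part of this hunk):

def bresenham_line(x0, y0, x1, y1):
    """ Minimal integer Bresenham; stand-in for the module's bresenham(). """
    points = []
    dx, dy = abs(x1 - x0), -abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx + dy
    while True:
        points.append((x0, y0))
        if (x0, y0) == (x1, y1):
            break
        e2 = 2 * err
        if e2 >= dy:
            err += dy
            x0 += sx
        if e2 <= dx:
            err += dx
            y0 += sy
    return points

def smart_scores(agent_pos, other_agents, piles):
    """ Manhattan distance per pile, plus a detour penalty when another
    agent sits on the direct line to the pile (mirrors the SMART branch). """
    scores = {}
    for pile in piles:
        d = abs(agent_pos[0] - pile[0]) + abs(agent_pos[1] - pile[1])
        for other in other_agents:
            if other == agent_pos:
                continue
            if (agent_pos[0] == other[0] == pile[0]
                    or agent_pos[1] == other[1] == pile[1]):
                if other in bresenham_line(*agent_pos, *pile):
                    d += abs(agent_pos[0] - other[0]) + abs(agent_pos[1] - other[1])
        scores[pile] = d
    return dict(sorted(scores.items(), key=lambda kv: kv[1]))

For example, smart_scores((0, 0), [(0, 2)], [(0, 4), (3, 0)]) ranks (3, 0) (score 3) before (0, 4) (score 4 plus penalty 2), matching the intent of the new comment about agents on the direct path.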
@@ -145,6 +155,7 @@ def bresenham(x0, y0, x1, y1):
 
 
 def update_ordered_dirt_piles(agent_idx, cleaned_dirt_piles, ordered_dirt_piles, env, cfg, n_agents):
+    """ Update the order of the remaining dirt piles """
     # Only update ordered_dirt_pile for agent that reached its target pile
     updated_ordered_dirt_piles = get_ordered_dirt_piles(env, cleaned_dirt_piles, cfg, n_agents)
     for i in range(len(ordered_dirt_piles[agent_idx])):
@@ -152,8 +163,10 @@ def update_ordered_dirt_piles(agent_idx, cleaned_dirt_piles, ordered_dirt_piles,
 
 
 def distribute_indices(env, cfg, n_agents):
+    """ Distribute dirt piles evenly among the agents """
     indices = []
     n_dirt_piles = len(get_dirt_piles_positions(env))
+    agents_positions = get_agents_positions(env, n_agents)
     if n_dirt_piles == 1 or cfg[nms.ALGORITHM][nms.PILE_ORDER] in [nms.FIXED, nms.DYNAMIC, nms.SMART]:
         indices = [[0] for _ in range(n_agents)]
     else:
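
The even split for the remaining PILE_ORDER modes happens in the else branch, which this hunk does not show. A plausible round-robin distribution consistent with the docstring (the function name and scheme below are illustrative, not the repository's actual code):

def split_piles_evenly(n_dirt_piles, n_agents):
    # Hypothetical even split: assign pile indices round-robin.
    indices = [[] for _ in range(n_agents)]
    for pile_idx in range(n_dirt_piles):
        indices[pile_idx % n_agents].append(pile_idx)
    return indices

# split_piles_evenly(5, 2) -> [[0, 2, 4], [1, 3]]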
@@ -171,12 +184,11 @@ def distribute_indices(env, cfg, n_agents):
         # -> Starting with index 0 even piles are auxiliary piles, odd piles are primary piles
         if cfg[nms.ALGORITHM][nms.AUXILIARY_PILES] and nms.DOORS in env.state.entities.keys():
             door_positions = [door.pos for door in env.state.entities[nms.DOORS]]
-            agent_positions = [env.state.moving_entites[agent_idx].pos for agent_idx in range(n_agents)]
             distances = {door_pos:[] for door_pos in door_positions}
 
             # Calculate distance of every agent to every door
             for door_pos in door_positions:
-                for agent_pos in agent_positions:
+                for agent_pos in agents_positions:
                     distances[door_pos].append(np.abs(door_pos[0] - agent_pos[0]) + np.abs(door_pos[1] - agent_pos[1]))
 
             def duplicate_indices(lst, item):
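
duplicate_indices is defined inside distribute_indices and its body lies outside this hunk. Judging from the name and the surrounding tie-breaking over door distances, a plausible implementation returns every index at which an item occurs:

def duplicate_indices(lst, item):
    # Assumed behaviour: all positions of item in lst, e.g. to detect
    # several agents being equally close to the same door.
    return [i for i, x in enumerate(lst) if x == item]

# duplicate_indices([3, 1, 3], 3) -> [0, 2]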
@@ -213,6 +225,7 @@ def distribute_indices(env, cfg, n_agents):
 
 
 def update_target_pile(env, agent_idx, target_pile, indices, cfg):
+    """ Get the next target pile for a given agent """
     if cfg[nms.ALGORITHM][nms.PILE_ORDER] in [nms.FIXED, nms.DYNAMIC, nms.SMART]:
         if target_pile[agent_idx] + 1 < len(get_dirt_piles_positions(env)):
             target_pile[agent_idx] += 1
@@ -223,7 +236,8 @@ def update_target_pile(env, agent_idx, target_pile, indices, cfg):
             target_pile[agent_idx] += 1
 
 
-def door_is_close(env, agent_idx):
+def is_door_close(env, agent_idx):
+    """ Checks whether the agent is close to a door """
     neighbourhood = [y for x in env.state.entities.neighboring_positions(env.state[nms.AGENT][agent_idx].pos)
                      for y in env.state.entities.pos_dict[x] if nms.DOOR in y.name]
     if neighbourhood:
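
The renamed is_door_close collects every entity standing on a cell adjacent to the agent and keeps those whose name contains the door marker; a non-empty list means a door is in reach. A conceptual restatement, assuming neighboring_positions() yields adjacent cells and pos_dict maps a position to the entities occupying it (both taken from the hunk above):

def entities_near(state, pos, name_fragment):
    # Gather entities on neighbouring cells whose name matches.
    return [entity
            for cell in state.entities.neighboring_positions(pos)
            for entity in state.entities.pos_dict[cell]
            if name_fragment in entity.name]

# is_door_close(env, agent_idx) is then roughly
# bool(entities_near(env.state, env.state[nms.AGENT][agent_idx].pos, nms.DOOR))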
@@ -231,6 +245,7 @@ def door_is_close(env, agent_idx):
 
 
 def get_all_cleaned_dirt_piles(dirt_piles_positions, cleaned_dirt_piles, n_agents):
+    """ Returns all dirt piles cleaned by any agent """
     meta_cleaned_dirt_piles = {pos: False for pos in dirt_piles_positions}
     for agent_idx in range(n_agents):
         for (pos, cleaned) in cleaned_dirt_piles[agent_idx].items():
@@ -240,6 +255,7 @@ def get_all_cleaned_dirt_piles(dirt_piles_positions, cleaned_dirt_piles, n_agent
 
 
 def handle_finished_episode(obs, agents, cfg):
+    """ Finish up episode, calculate advantages and perform policy net and value net updates"""
     with torch.inference_mode(False):
         for ag_i, agent in enumerate(agents):
             # Get states, actions, rewards and values from rollout buffer
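
The advantage calculation itself lies beyond this hunk, so the estimator is not visible here. A simple returns-minus-baseline variant consistent with the docstring, as a hedged sketch (gamma and the per-step lists are hypothetical):

import torch

def simple_advantages(rewards, values, gamma=0.99):
    # Discounted returns computed backwards over the episode,
    # then the value baseline is subtracted.
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.insert(0, g)
    returns = torch.tensor(returns)
    return returns - values, returns

adv, ret = simple_advantages([1.0, 0.0, 1.0], torch.tensor([0.5, 0.4, 0.9]))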
@@ -268,6 +284,7 @@ def handle_finished_episode(obs, agents, cfg):
 
 
 def split_into_chunks(data_tuple, cfg):
+    """ Chunks episode data into approximately equal sized chunks to prevent system memory failure from overload """
     result = [data_tuple]
     chunk_size = cfg[nms.ALGORITHM][nms.CHUNK_EPISODE]
     if chunk_size > 0:
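
Most of split_into_chunks sits outside the hunk; the visible part shows that chunking only happens for CHUNK_EPISODE > 0 and that the whole episode is otherwise kept as a single chunk. A minimal sketch of that idea (names are illustrative):

def chunk_episode(data_tuple, chunk_size):
    # Cut each parallel sequence (states, actions, rewards, ...) into
    # pieces of at most chunk_size steps; keep everything if disabled.
    if chunk_size <= 0:
        return [data_tuple]
    length = len(data_tuple[0])
    return [tuple(seq[i:i + chunk_size] for seq in data_tuple)
            for i in range(0, length, chunk_size)]

# chunk_episode(([1, 2, 3], ['a', 'b', 'c']), 2)
# -> [([1, 2], ['a', 'b']), ([3], ['c'])]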
@@ -286,7 +303,8 @@ def split_into_chunks(data_tuple, cfg):
     return result
 
 
-def set_agent_spawnpoint(env, n_agents):
+def set_agents_spawnpoints(env, n_agents):
+    """ Tell environment where the agents should spawn in the next episode """
     for agent_idx in range(n_agents):
         agent_name = list(env.state.agents_conf.keys())[agent_idx]
         current_pos_pointer = env.state.agents_conf[agent_name][nms.POS_POINTER]
@@ -299,6 +317,7 @@ def set_agent_spawnpoint(env, n_agents):
 
 
 def save_configs(results_path, cfg, factory_conf, eval_factory_conf):
+    """ Save configurations for logging purposes """
     with open(f"{results_path}/MARL_config.txt", "w") as txt_file:
         txt_file.write(str(cfg))
     with open(f"{results_path}/train_env_config.txt", "w") as txt_file:
@@ -308,6 +327,7 @@ def save_configs(results_path, cfg, factory_conf, eval_factory_conf):
 
 
 def save_agent_models(results_path, agents):
+    """ Save model parameters after training """
     for idx, agent in enumerate(agents):
         agent.pi.save_model_parameters(results_path)
         agent.vf.save_model_parameters(results_path)