added changes from code submission branch and coin entity

This commit is contained in:
Chanumask
2024-09-06 11:01:42 +02:00
parent 33e40deecf
commit 5476f617c6
42 changed files with 1429 additions and 68 deletions

View File

@@ -0,0 +1 @@
from marl_factory_grid.algorithms.rl.memory import MARLActorCriticMemory

View File

@@ -0,0 +1,297 @@
import os
import torch
from typing import Union, List
import numpy as np
from tqdm import tqdm
from marl_factory_grid.algorithms.rl.base_a2c import PolicyGradient
from marl_factory_grid.algorithms.rl.constants import Names
from marl_factory_grid.algorithms.rl.utils import transform_observations, _as_torch, is_door_close, \
get_coin_piles_positions, update_target_pile, update_ordered_coin_piles, get_all_collected_coin_piles, \
distribute_indices, set_agents_spawnpoints, get_ordered_coin_piles, handle_finished_episode, save_configs, \
save_agent_models, get_all_observations, get_agents_positions
from marl_factory_grid.algorithms.utils import add_env_props
from marl_factory_grid.utils.plotting.plot_single_runs import plot_action_maps, plot_reward_development, \
create_info_maps
nms = Names
ListOrTensor = Union[List, torch.Tensor]
class A2C:
def __init__(self, train_cfg, eval_cfg):
self.results_path = None
self.agents = None
self.act_dim = None
self.obs_dim = None
self.factory = add_env_props(train_cfg)
self.eval_factory = add_env_props(eval_cfg)
self.__training = True
self.train_cfg = train_cfg
self.eval_cfg = eval_cfg
self.cfg = train_cfg
self.n_agents = train_cfg[nms.ENV][nms.N_AGENTS]
self.setup()
self.reward_development = []
self.action_probabilities = {agent_idx: [] for agent_idx in range(self.n_agents)}
def setup(self):
""" Initialize agents and create entry for run results according to configuration """
self.obs_dim = 2 + 2 * len(get_coin_piles_positions(self.factory)) if self.cfg[nms.ALGORITHM][
nms.PILE_OBSERVABILITY] == nms.ALL else 4
self.act_dim = 4 # The 4 movement directions
self.agents = [PolicyGradient(self.factory, agent_id=i, obs_dim=self.obs_dim, act_dim=self.act_dim) for i in
range(self.n_agents)]
if self.cfg[nms.ENV][nms.SAVE_AND_LOG]:
# Define study_out_path and check if it exists
base_dir = os.path.dirname(os.path.abspath(__file__)) # Directory of the script
study_out_path = os.path.join(base_dir, '../../../study_out')
study_out_path = os.path.abspath(study_out_path)
if not os.path.exists(study_out_path):
raise FileNotFoundError(f"The directory {study_out_path} does not exist.")
# Create results folder
runs = os.listdir(study_out_path)
run_numbers = [int(run[3:]) for run in runs if run[:3] == "run"]
next_run_number = max(run_numbers) + 1 if run_numbers else 0
self.results_path = os.path.join(study_out_path, f"run{next_run_number}")
os.mkdir(self.results_path)
# Save settings in results folder
save_configs(self.results_path, self.cfg, self.factory.conf, self.eval_factory.conf)
def set_cfg(self, eval=False):
if eval:
self.cfg = self.eval_cfg
else:
self.cfg = self.train_cfg
def load_agents(self, runs_list):
""" Initialize networks with parameters of already trained agents """
for idx, run in enumerate(runs_list):
run_path = f"./study_out/{run}"
self.agents[idx].pi.load_model_parameters(f"{run_path}/PolicyNet_model_parameters.pth")
self.agents[idx].vf.load_model_parameters(f"{run_path}/ValueNet_model_parameters.pth")
@torch.no_grad()
def train_loop(self):
""" Function for training agents """
env = self.factory
n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
global_steps, episode = 0, 0
indices = distribute_indices(env, self.cfg, self.n_agents)
coin_piles_positions = get_coin_piles_positions(env)
target_pile = [partition[0] for partition in
indices] # list of pointers that point to the current target pile for each agent
collected_coin_piles = [{pos: False for pos in coin_piles_positions} for _ in range(self.n_agents)]
pbar = tqdm(total=max_steps)
while global_steps < max_steps:
_ = env.reset()
if self.cfg[nms.ENV][nms.TRAIN_RENDER]:
env.render()
set_agents_spawnpoints(env, self.n_agents)
ordered_coin_piles = get_ordered_coin_piles(env, collected_coin_piles, self.cfg, self.n_agents)
# Reset the current target pile at the start of each episode if all piles have to be collected in one episode
if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.ALL:
target_pile = [partition[0] for partition in indices]
collected_coin_piles = [{pos: False for pos in coin_piles_positions} for _ in range(self.n_agents)]
# Supply each agent with its local observation
obs = transform_observations(env, ordered_coin_piles, target_pile, self.cfg, self.n_agents)
done, rew_log = [False] * self.n_agents, 0
while not all(done):
action = self.use_door_or_move(env, obs, collected_coin_piles) \
if nms.DOORS in env.state.entities.keys() else self.get_actions(obs)
_, next_obs, reward, done, info = env.step(action)
next_obs = transform_observations(env, ordered_coin_piles, target_pile, self.cfg, self.n_agents)
# Handle the case where an agent is on a field with a coin
reward, done = self.handle_coin(env, collected_coin_piles, ordered_coin_piles, target_pile, indices,
reward, done)
if n_steps != 0 and (global_steps + 1) % n_steps == 0: done = True
done = [done] * self.n_agents if isinstance(done, bool) else done
for ag_i, agent in enumerate(self.agents):
if action[ag_i] in range(self.act_dim):
# Add agent results into respective rollout buffers
agent._episode[-1] = (next_obs[ag_i], action[ag_i], reward[ag_i], agent._episode[-1][-1])
# Visualize state update
if self.cfg[nms.ENV][nms.TRAIN_RENDER]: env.render()
obs = next_obs
if all(done): handle_finished_episode(obs, self.agents, self.cfg)
global_steps += 1
rew_log += sum(reward)
if global_steps >= max_steps: break
self.reward_development.append(rew_log)
episode += 1
pbar.update(global_steps - pbar.n)
pbar.close()
if self.cfg[nms.ENV][nms.SAVE_AND_LOG]:
plot_reward_development(self.reward_development, self.results_path)
create_info_maps(env, get_all_observations(env, self.cfg, self.n_agents),
get_coin_piles_positions(env), self.results_path, self.agents, self.act_dim, self)
save_agent_models(self.results_path, self.agents)
plot_action_maps(env, [self], self.results_path)
@torch.inference_mode(True)
def eval_loop(self, n_episodes):
""" Function for performing inference """
env = self.eval_factory
self.set_cfg(eval=True)
episode, results = 0, []
coin_piles_positions = get_coin_piles_positions(env)
indices = distribute_indices(env, self.cfg, self.n_agents)
target_pile = [partition[0] for partition in
indices] # list of pointers that point to the current target pile for each agent
if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.DISTRIBUTED:
collected_coin_piles = [{coin_piles_positions[idx]: False for idx in indices[i]} for i in
range(self.n_agents)]
else: collected_coin_piles = [{pos: False for pos in coin_piles_positions} for _ in range(self.n_agents)]
while episode < n_episodes:
_ = env.reset()
set_agents_spawnpoints(env, self.n_agents)
if self.cfg[nms.ENV][nms.EVAL_RENDER]:
# Don't render auxiliary piles
if self.cfg[nms.ALGORITHM][nms.AUXILIARY_PILES]:
auxiliary_piles = [pile for idx, pile in enumerate(env.state.entities[nms.COIN_PILES]) if
idx % 2 == 0]
for pile in auxiliary_piles:
pile.set_new_amount(0)
env.render()
env._renderer.fps = 5 # Slow down agent movement
# Reset the current target pile at the start of each episode if all piles have to be collected in one episode
if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] in [nms.ALL, nms.DISTRIBUTED, nms.SHARED]:
target_pile = [partition[0] for partition in indices]
if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.DISTRIBUTED:
collected_coin_piles = [{coin_piles_positions[idx]: False for idx in indices[i]} for i in
range(self.n_agents)]
else: collected_coin_piles = [{pos: False for pos in coin_piles_positions} for _ in range(self.n_agents)]
ordered_coin_piles = get_ordered_coin_piles(env, collected_coin_piles, self.cfg, self.n_agents)
# Supply each agent with its local observation
obs = transform_observations(env, ordered_coin_piles, target_pile, self.cfg, self.n_agents)
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
while not all(done):
action = self.use_door_or_move(env, obs, collected_coin_piles, det=True) \
if nms.DOORS in env.state.entities.keys() else self.execute_policy(obs, env,
collected_coin_piles) # zero exploration
_, next_obs, reward, done, info = env.step(action)
# Handle the case where an agent is on a field with a coin
reward, done = self.handle_coin(env, collected_coin_piles, ordered_coin_piles, target_pile, indices,
reward, done)
# Get transformed next_obs that might have been updated because of handle_coin
next_obs = transform_observations(env, ordered_coin_piles, target_pile, self.cfg, self.n_agents)
done = [done] * self.n_agents if isinstance(done, bool) else done
if self.cfg[nms.ENV][nms.EVAL_RENDER]: env.render()
obs = next_obs
episode += 1
# -------------------------------------- HELPER FUNCTIONS ------------------------------------------------- #
def get_actions(self, observations) -> ListOrTensor:
""" Given local observations, get actions for both agents """
actions = [agent.step(_as_torch(observations[ag_i]).view(-1).to(torch.float32)) for ag_i, agent in
enumerate(self.agents)]
return actions
def execute_policy(self, observations, env, collected_coin_piles) -> ListOrTensor:
""" Execute agent policies deterministically for inference """
actions = [agent.policy(_as_torch(observations[ag_i]).view(-1).to(torch.float32)) for ag_i, agent in
enumerate(self.agents)]
for agent_idx in range(self.n_agents):
if all(collected_coin_piles[agent_idx].values()):
actions[agent_idx] = np.array(next(
action_i for action_i, a in enumerate(env.state[nms.AGENT][agent_idx].actions) if
a.name == nms.NOOP))
return actions
def use_door_or_move(self, env, obs, collected_coin_piles, det=False):
""" Function that handles automatic actions like door opening and forced Noop"""
action = []
for agent_idx, agent in enumerate(self.agents):
agent_obs = _as_torch(obs[agent_idx]).view(-1).to(torch.float32)
# Use Noop operation if agent already reached its target. (Only relevant for two-rooms setting)
if all(collected_coin_piles[agent_idx].values()):
action.append(next(action_i for action_i, a in enumerate(env.state[nms.AGENT][agent_idx].actions) if
a.name == nms.NOOP))
if not det:
# Include agent experience entry manually
agent._episode.append((None, None, None, agent.vf(agent_obs)))
else:
if door := is_door_close(env, agent_idx):
if door.is_closed:
action.append(next(
action_i for action_i, a in enumerate(env.state[nms.AGENT][agent_idx].actions) if
a.name == nms.USE_DOOR))
# Don't include action in agent experience
else:
if det: action.append(int(agent.pi(agent_obs, det=True)[0]))
else: action.append(int(agent.step(agent_obs)))
else:
if det: action.append(int(agent.pi(agent_obs, det=True)[0]))
else: action.append(int(agent.step(agent_obs)))
return action
def handle_coin(self, env, collected_coin_piles, ordered_coin_piles, target_pile, indices, reward, done):
""" Check if agent moved on field with coin. If that is the case collect coin automatically """
agents_positions = get_agents_positions(env, self.n_agents)
coin_piles_positions = get_coin_piles_positions(env)
if any([True for pos in agents_positions if pos in coin_piles_positions]):
# Only simulate collecting the coin
for idx, pos in enumerate(agents_positions):
if pos in collected_coin_piles[idx].keys() and not collected_coin_piles[idx][pos]:
# If coin piles should be collected in a specific order
if ordered_coin_piles[idx]:
if pos == ordered_coin_piles[idx][target_pile[idx]]:
reward[idx] += 50
collected_coin_piles[idx][pos] = True
# Set pointer to next coin pile
update_target_pile(env, idx, target_pile, indices, self.cfg)
update_ordered_coin_piles(idx, collected_coin_piles, ordered_coin_piles, env,
self.cfg, self.n_agents)
if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.SINGLE:
done = True
if all(collected_coin_piles[idx].values()):
# Reset collected_coin_piles indicator
for pos in coin_piles_positions:
collected_coin_piles[idx][pos] = False
else:
reward[idx] += 50
collected_coin_piles[idx][pos] = True
# Indicate that renderer can hide coin pile
coin_at_position = env.state[nms.COIN_PILES].by_pos(pos)
coin_at_position[0].set_new_amount(0)
if self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] in [nms.ALL, nms.DISTRIBUTED]:
if all([all(collected_coin_piles[i].values()) for i in range(self.n_agents)]):
done = True
elif self.cfg[nms.ALGORITHM][nms.PILE_ALL_DONE] == nms.SHARED:
# End episode if both agents together have collected all coin piles
if all(get_all_collected_coin_piles(coin_piles_positions, collected_coin_piles, self.n_agents).values()):
done = True
return reward, done
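The A2C class above is driven entirely by the two config dictionaries passed to its constructor. A minimal usage sketch (not part of this commit; loading via PyYAML and the file paths are assumptions, the project may ship its own config loader):

import yaml

with open("train_cfg.yaml") as f:   # hypothetical paths to configs shaped like the YAML files further down in this commit
    train_cfg = yaml.safe_load(f)
with open("eval_cfg.yaml") as f:
    eval_cfg = yaml.safe_load(f)

a2c = A2C(train_cfg, eval_cfg)      # assumes the A2C class above is importable
a2c.train_loop()                    # fills a2c.reward_development; saves models if save_and_log is set
a2c.eval_loop(n_episodes=5)         # deterministic rollouts on the eval factory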

View File

@@ -0,0 +1,112 @@
import numpy as np
import torch as th
import scipy as sp
import scipy.signal  # make sp.signal available for the lfilter call below
from collections import deque
from torch import nn
cumulate_discount = lambda x, gamma: sp.signal.lfilter([1], [1, - gamma], x[::-1], axis=0)[::-1]
class Net(th.nn.Module):
def __init__(self, shape, activation, lr):
super().__init__()
self.net = th.nn.Sequential(*[layer
for io, a in zip(zip(shape[:-1], shape[1:]),
[activation] * (len(shape) - 2) + [th.nn.Identity])
for layer in [th.nn.Linear(*io), a()]])
self.optimizer = th.optim.Adam(self.net.parameters(), lr=lr)
# Initialize weights uniformly, so that for the policy net all actions have approximately the same
# probability in the beginning
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.uniform_(module.weight, a=-0.1, b=0.1)
if module.bias is not None:
nn.init.uniform_(module.bias, a=-0.1, b=0.1)
def save_model(self, path):
th.save(self.net, f"{path}/{self.__class__.__name__}_model.pth")
def save_model_parameters(self, path):
th.save(self.net.state_dict(), f"{path}/{self.__class__.__name__}_model_parameters.pth")
def load_model_parameters(self, path):
self.net.load_state_dict(th.load(path))
self.net.eval()
class ValueNet(Net):
def __init__(self, obs_dim, hidden_sizes=[64, 64], activation=th.nn.ReLU, lr=1e-3):
super().__init__([obs_dim] + hidden_sizes + [1], activation, lr)
def forward(self, obs): return self.net(obs)
def loss(self, states, returns): return ((returns - self(states)) ** 2).mean()
class PolicyNet(Net):
def __init__(self, obs_dim, act_dim, hidden_sizes=[64, 64], activation=th.nn.Tanh, lr=3e-4):
super().__init__([obs_dim] + hidden_sizes + [act_dim], activation, lr)
self.distribution = lambda obs: th.distributions.Categorical(logits=self.net(obs))
def forward(self, obs, act=None, det=False):
"""Given an observation: Returns policy distribution and probablilty for a given action
or Returns a sampled action and its corresponding probablilty"""
pi = self.distribution(obs)
if act is not None: return pi, pi.log_prob(act)
act = self.net(obs).argmax() if det else pi.sample() # sample from the learned distribution
return act, pi.log_prob(act)
def loss(self, states, actions, advantages):
_, logp = self.forward(states, actions)
loss = -(logp * advantages).mean()
return loss
class PolicyGradient:
""" Autonomous agent using vanilla policy gradient. """
def __init__(self, env, seed=42, gamma=0.99, agent_id=0, act_dim=None, obs_dim=None):
self.env = env
self.gamma = gamma # Setup env and discount
th.manual_seed(seed)
np.random.seed(seed) # Seed Torch, numpy and gym
# Keep track of previous rewards and performed steps to calculate the mean return metric
self._episode, self.ep_returns, self.num_steps = [], deque(maxlen=100), 0
# Get observation and action shapes
if not obs_dim:
obs_size = env.observation_space.shape if len(env.state.entities.by_name("Agents")) == 1 \
else env.observation_space[agent_id].shape # Single agent case vs. multi-agent case
obs_dim = np.prod(obs_size)
if not act_dim:
act_dim = env.action_space[agent_id].n
self.vf = ValueNet(obs_dim) # Setup Value Network (Critic)
self.pi = PolicyNet(obs_dim, act_dim) # Setup Policy Network (Actor)
def step(self, obs):
""" Given an observation, get action and probs from policy and values from critic"""
with th.no_grad():
(a, _), v = self.pi(obs), self.vf(obs)
self._episode.append((None, None, None, v))
return a.numpy()
def policy(self, obs, det=True):
return self.pi(obs, det=det)[0].numpy()
def finish_episode(self):
"""Process self._episode & reset self.env, Returns (s,a,G,V)-Tuple and new inital state"""
s, a, r, v = (np.array(e) for e in zip(*self._episode)) # Get trajectories from rollout
self.ep_returns.append(sum(r))
self._episode = [] # Add episode return to buffer & reset
return s, a, r, v # state, action, reward, value arrays
def train(self, states, actions, returns, advantages): # Update policy weights
self.pi.optimizer.zero_grad()
self.vf.optimizer.zero_grad() # Reset optimizer
states = states.flatten(1, -1) # Reduce dimensionality to rollout_dim x input_dim
policy_loss = self.pi.loss(states, actions, advantages) # Calculate Policy loss
policy_loss.backward()
self.pi.optimizer.step() # Apply Policy loss
value_loss = self.vf.loss(states, returns) # Calculate Value loss
value_loss.backward()
self.vf.optimizer.step() # Apply Value loss
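A small numeric check (not part of the commit) of the cumulate_discount helper defined at the top of this file: it computes the discounted returns G_t = r_t + gamma * G_{t+1} in a single backward filtering pass.

import numpy as np
from scipy.signal import lfilter

cumulate_discount = lambda x, gamma: lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
# With gamma = 0.5 and rewards [1, 0, 2]:
# G_2 = 2, G_1 = 0 + 0.5 * 2 = 1, G_0 = 1 + 0.5 * 1 = 1.5
print(cumulate_discount(np.array([1.0, 0.0, 2.0]), 0.5))  # -> [1.5 1.  2. ]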

View File

@@ -0,0 +1,242 @@
import torch
from typing import Union, List, Dict
import numpy as np
from torch.distributions import Categorical
from marl_factory_grid.algorithms.rl.memory import MARLActorCriticMemory
from marl_factory_grid.algorithms.utils import add_env_props, instantiate_class
from pathlib import Path
import pandas as pd
from collections import deque
class Names:
REWARD = 'reward'
DONE = 'done'
ACTION = 'action'
OBSERVATION = 'observation'
LOGITS = 'logits'
HIDDEN_ACTOR = 'hidden_actor'
HIDDEN_CRITIC = 'hidden_critic'
AGENT = 'agent'
ENV = 'env'
ENV_NAME = 'env_name'
N_AGENTS = 'n_agents'
ALGORITHM = 'algorithm'
MAX_STEPS = 'max_steps'
N_STEPS = 'n_steps'
BUFFER_SIZE = 'buffer_size'
CRITIC = 'critic'
BATCH_SIZE = 'batch_size'
N_ACTIONS = 'n_actions'
TRAIN_RENDER = 'train_render'
EVAL_RENDER = 'eval_render'
nms = Names
ListOrTensor = Union[List, torch.Tensor]
class BaseActorCritic:
def __init__(self, cfg):
self.factory = add_env_props(cfg)
self.__training = True
self.cfg = cfg
self.n_agents = cfg[nms.AGENT][nms.N_AGENTS]
self.reset_memory_after_epoch = True
self.setup()
def setup(self):
self.net = instantiate_class(self.cfg[nms.AGENT])
self.optimizer = torch.optim.RMSprop(self.net.parameters(), lr=3e-4, eps=1e-5)
@classmethod
def _as_torch(cls, x):
if isinstance(x, np.ndarray):
return torch.from_numpy(x)
elif isinstance(x, List):
return torch.tensor(x)
elif isinstance(x, (int, float)):
return torch.tensor([x])
return x
def train(self):
self.__training = True
networks = [self.net] if not isinstance(self.net, List) else self.net
for net in networks:
net.train()
def eval(self):
self.__training = False
networks = [self.net] if not isinstance(self.net, List) else self.net
for net in networks:
net.eval()
def load_state_dict(self, path: Path):
pass
def get_actions(self, out) -> ListOrTensor:
actions = [Categorical(logits=logits).sample().item() for logits in out[nms.LOGITS]]
return actions
def init_hidden(self) -> Dict[str, ListOrTensor]:
pass
def forward(self,
observations: ListOrTensor,
actions: ListOrTensor,
hidden_actor: ListOrTensor,
hidden_critic: ListOrTensor
) -> Dict[str, ListOrTensor]:
pass
@torch.no_grad()
def train_loop(self, checkpointer=None):
env = self.factory
if self.cfg[nms.ENV][nms.TRAIN_RENDER]:
env.render()
n_steps, max_steps = [self.cfg[nms.ALGORITHM][k] for k in [nms.N_STEPS, nms.MAX_STEPS]]
tm = MARLActorCriticMemory(self.n_agents, self.cfg[nms.ALGORITHM].get(nms.BUFFER_SIZE, n_steps))
global_steps, episode, df_results = 0, 0, []
reward_queue = deque(maxlen=2000)
while global_steps < max_steps:
obs = env.reset()
obs = list(obs.values())
last_hiddens = self.init_hidden()
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
done, rew_log = [False] * self.n_agents, 0
if self.reset_memory_after_epoch:
tm.reset()
tm.add(observation=obs, action=last_action,
logits=torch.zeros(self.n_agents, 1, self.cfg[nms.AGENT][nms.N_ACTIONS]),
values=torch.zeros(self.n_agents, 1), reward=reward, done=done, **last_hiddens)
while not all(done):
out = self.forward(obs, last_action, **last_hiddens)
action = self.get_actions(out)
_, next_obs, reward, done, info = env.step(action)
done = [done] * self.n_agents if isinstance(done, bool) else done
if self.cfg[nms.ENV][nms.TRAIN_RENDER]:
env.render()
last_hiddens = dict(hidden_actor=out[nms.HIDDEN_ACTOR],
hidden_critic=out[nms.HIDDEN_CRITIC])
logits = torch.stack([tensor.squeeze(0) for tensor in out.get(nms.LOGITS, None)], dim=0)
values = torch.stack([tensor.squeeze(0) for tensor in out.get(nms.CRITIC, None)], dim=0)
tm.add(observation=obs, action=action, reward=reward, done=done,
logits=logits, values=values,
**last_hiddens)
obs = next_obs
last_action = action
if (global_steps+1) % n_steps == 0 or all(done):
with torch.inference_mode(False):
self.learn(tm)
global_steps += 1
rew_log += sum(reward)
reward_queue.extend(reward)
if checkpointer is not None:
checkpointer.step([
(f'agent#{i}', agent)
for i, agent in enumerate([self.net] if not isinstance(self.net, List) else self.net)
])
if global_steps >= max_steps:
break
if global_steps%100 == 0:
print(f'reward at episode: {episode} = {rew_log}')
episode += 1
df_results.append([episode, rew_log, *reward])
df_results = pd.DataFrame(df_results,
columns=['steps', 'reward', *[f'agent#{i}' for i in range(self.n_agents)]]
)
if checkpointer is not None:
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
return df_results
@torch.inference_mode(True)
def eval_loop(self, n_episodes, render=False):
env = self.factory
if self.cfg[nms.ENV][nms.EVAL_RENDER]:
env.render()
episode, results = 0, []
while episode < n_episodes:
obs = env.reset()
obs = list(obs.values())
last_hiddens = self.init_hidden()
last_action, reward = [-1] * self.n_agents, [0.] * self.n_agents
done, rew_log, eps_rew = [False] * self.n_agents, 0, torch.zeros(self.n_agents)
while not all(done):
out = self.forward(obs, last_action, **last_hiddens)
action = self.get_actions(out)
_, next_obs, reward, done, info = env.step(action)
if self.cfg[nms.ENV][nms.EVAL_RENDER]:
env.render()
if isinstance(done, bool):
done = [done] * obs[0].shape[0]
obs = next_obs
last_action = action
last_hiddens = dict(hidden_actor=out.get(nms.HIDDEN_ACTOR, None),
hidden_critic=out.get(nms.HIDDEN_CRITIC, None)
)
eps_rew += torch.tensor(reward)
results.append(eps_rew.tolist() + [sum(eps_rew).item()] + [episode])
episode += 1
agent_columns = [f'agent#{i}' for i in range(self.cfg[nms.ENV][nms.N_AGENTS])]
results = pd.DataFrame(results, columns=agent_columns + ['sum', 'episode'])
results = pd.melt(results, id_vars=['episode'], value_vars=agent_columns + ['sum'],
value_name='reward', var_name='agent')
return results
@staticmethod
def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
if gae_coef <= 0:
return tds
gae = torch.zeros_like(tds[:, -1])
gaes = []
for t in range(tds.shape[1]-1, -1, -1):
gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
gaes.insert(0, gae)
gaes = torch.stack(gaes, dim=1)
return gaes
def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
out = network(obs, actions, tm.hidden_actor[:, 0].squeeze(0), tm.hidden_critic[:, 0].squeeze(0))
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
critic = out[nms.CRITIC]
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
value_loss = advantages.pow(2).mean(-1) # n_agent
# policy loss
log_ap = torch.log_softmax(logits, -1)
log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
a2c_loss = -(advantages.detach() * log_ap).mean(-1)
# weighted loss
loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss
return loss.mean()
def learn(self, tm: MARLActorCriticMemory, **kwargs):
loss = self.actor_critic(tm, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
# remove next_obs, will be added in next iter
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
self.optimizer.step()
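A toy shape check (not part of the commit) for BaseActorCritic.compute_advantages above: the critic carries one extra time step for v_{t+1}, and both the plain TD errors (gae_coef=0) and the GAE estimates come back as n_agents x time tensors.

import torch

critic = torch.tensor([[0.5, 0.4, 0.3, 0.2]])   # n_agents x (time + 1) value estimates
reward = torch.ones(1, 3)
done = torch.zeros(1, 3)
td = BaseActorCritic.compute_advantages(critic, reward, done, gamma=0.99, gae_coef=0.0)
gae = BaseActorCritic.compute_advantages(critic, reward, done, gamma=0.99, gae_coef=0.95)
print(td.shape, gae.shape)   # torch.Size([1, 3]) torch.Size([1, 3])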

View File

@@ -0,0 +1,34 @@
agent:
classname: marl_factory_grid.algorithms.rl.networks.RecurrentAC
n_agents: 2
obs_emb_size: 96
action_emb_size: 16
hidden_size_actor: 64
hidden_size_critic: 64
use_agent_embedding: False
env:
classname: marl_factory_grid.configs.custom
env_name: "custom/MultiAgentConfigs/dirt_quadrant_train_config"
n_agents: 2
max_steps: 250
pomdp_r: 2
stack_n_frames: 0
individual_rewards: True
train_render: False
eval_render: True
save_and_log: True
record: False
method: marl_factory_grid.algorithms.rl.LoopSEAC
algorithm:
gamma: 0.99
entropy_coef: 0.01
vf_coef: 0.05
n_steps: 0 # Maximum number of steps (n-step TD) to sample before the next value and policy update. Default 0: MC (full episode)
max_steps: 200000
advantage: "Advantage-AC" # Options: "Advantage-AC", "TD-Advantage-AC", "Reinforce"
pile-order: "dynamic" # Use "dynamic" to see emergent phenomenon and "smart" to prevent it
pile-observability: "single" # Options: "single", "all"
pile_all_done: "shared" # Options: "single", "all" ("single" for training, "all" for eval), "shared"
auxiliary_piles: False # Option that is only considered when pile-order = "agents"
chunk-episode: 20000 # Chunk size. (0 = update networks with full episode at once)
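A minimal sketch (not part of this commit) of how a config like the one above might be loaded and read via the Names constants added in this commit; PyYAML and the file path are assumptions:

import yaml
from marl_factory_grid.algorithms.rl.constants import Names as nms

cfg_path = "dirt_quadrant_marl_config.yaml"      # hypothetical local copy of the YAML above
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)
print(cfg[nms.ENV][nms.N_AGENTS])                # 2
print(cfg[nms.ALGORITHM][nms.PILE_ORDER])        # 'dynamic'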

View File

@@ -0,0 +1,35 @@
agent:
classname: marl_factory_grid.algorithms.rl.networks.RecurrentAC
n_agents: 2
obs_emb_size: 96
action_emb_size: 16
hidden_size_actor: 64
hidden_size_critic: 64
use_agent_embedding: False
env:
classname: marl_factory_grid.configs.custom
env_name: "custom/two_rooms_one_door_modified_train_config"
n_agents: 2
max_steps: 250
pomdp_r: 2
stack_n_frames: 0
individual_rewards: True
train_render: False
eval_render: True
save_and_log: True
record: False
method: marl_factory_grid.algorithms.rl.LoopSEAC
algorithm:
gamma: 0.99
entropy_coef: 0.01
vf_coef: 0.05
n_steps: 0 # Maximum number of steps (n-step TD) to sample before the next value and policy update. Default 0: MC (full episode)
max_steps: 260000
advantage: "Advantage-AC" # Options: "Advantage-AC", "TD-Advantage-AC", "Reinforce"
pile-order: "agents" # Options: "fixed", "random", "none", "agents", "dynamic", "smart" (Use "fixed", "random" and "none" for single agent training and the other for multi agent inference)
pile-observability: "single" # Options: "single", "all"
pile_all_done: "distributed" # Options: "single", "all" ("single" for training, "all" and "distributed" for eval)
auxiliary_piles: True # Use True to see the emergent phenomenon and False to prevent it
chunk-episode: 20000 # Chunk size. (0 = update networks with full episode at once)

View File

@@ -0,0 +1,34 @@
agent:
classname: marl_factory_grid.algorithms.rl.networks.RecurrentAC
n_agents: 1
obs_emb_size: 96
action_emb_size: 16
hidden_size_actor: 64
hidden_size_critic: 64
use_agent_embedding: False
env:
classname: marl_factory_grid.configs.custom
env_name: "custom/dirt_quadrant_train_config"
n_agents: 1
max_steps: 250
pomdp_r: 2
stack_n_frames: 0
individual_rewards: True
train_render: False
eval_render: True
save_and_log: True
record: False
method: marl_factory_grid.algorithms.rl.LoopSEAC
algorithm:
gamma: 0.99
entropy_coef: 0.01
vf_coef: 0.05
n_steps: 0 # Maximum number of steps (n-step TD) to sample before the next value and policy update. Default 0: MC (full episode)
max_steps: 240000
advantage: "Advantage-AC" # Options: "Advantage-AC", "TD-Advantage-AC", "Reinforce"
pile-order: "fixed" # Options: "fixed", "random", "none", "agents", "dynamic", "smart" (Use "fixed", "random" and "none" for single agent training and the other for multi agent inference)
pile-observability: "single" # Options: "single", "all"
pile_all_done: "single" # Options: "single", "all" ("single" for training, "all" for eval)
auxiliary_piles: False # Option that is only considered when pile-order = "agents"
chunk-episode: 20000 # Chunk size. (0 = update networks with full episode at once)

View File

@@ -0,0 +1,8 @@
marl_factory_grid>environment>rules.py#SpawnEntity.on_reset()
marl_factory_grid>environment>rewards.py
marl_factory_grid>modules>clean_up>groups.py#DirtPiles.trigger_spawn()
marl_factory_grid>environment>rules.py#AgentSpawnRule
marl_factory_grid>utils>states.py#GameState.__init__()
marl_factory_grid>environment>factory.py>Factory#render
marl_factory_grid>environment>factory.py>Factory#set_recorder
marl_factory_grid>utils>renderer.py>Renderer#render

View File

@@ -0,0 +1,35 @@
agent:
classname: marl_factory_grid.algorithms.rl.networks.RecurrentAC
n_agents: 1
obs_emb_size: 96
action_emb_size: 16
hidden_size_actor: 64
hidden_size_critic: 64
use_agent_embedding: False
env:
classname: marl_factory_grid.configs.custom
env_name: "custom/two_rooms_one_door_modified_train_config"
n_agents: 1
max_steps: 250
pomdp_r: 2
stack_n_frames: 0
individual_rewards: True
train_render: False
eval_render: True
save_and_log: False
record: False
method: marl_factory_grid.algorithms.rl.LoopSEAC
algorithm:
gamma: 0.99
entropy_coef: 0.01
vf_coef: 0.05
n_steps: 0 # Maximum number of steps (n-step TD) to sample before the next value and policy update. Default 0: MC (full episode)
max_steps: 260000
advantage: "Advantage-AC" # Options: "Advantage-AC", "TD-Advantage-AC", "Reinforce"
pile-order: "fixed" # Options: "fixed", "random", "none", "agents", "dynamic", "smart" (Use "fixed", "random" and "none" for single agent training and the other for multi agent inference)
pile-observability: "single" # Options: "single", "all"
pile_all_done: "single" # Options: "single", "all" ("single" for training, "all" for eval)
auxiliary_piles: False # Option that is only considered when pile-order = "agents"
chunk-episode: 20000 # Chunk size. (0 = update networks with full episode at once)

View File

@@ -0,0 +1,37 @@
class Names:
ENV = 'env'
ENV_NAME = 'env_name'
N_AGENTS = 'n_agents'
ALGORITHM = 'algorithm'
MAX_STEPS = 'max_steps'
N_STEPS = 'n_steps'
TRAIN_RENDER = 'train_render'
EVAL_RENDER = 'eval_render'
AGENT = 'Agent'
PILE_OBSERVABILITY = 'pile-observability'
PILE_ORDER = 'pile-order'
ALL = 'all'
FIXED = 'fixed'
AGENTS = 'agents'
DYNAMIC = 'dynamic'
SMART = 'smart'
DIRT_PILES = 'DirtPiles'
COIN_PILES = 'CoinPiles'
AUXILIARY_PILES = "auxiliary_piles"
DOORS = 'Doors'
DOOR = 'Door'
GAMMA = 'gamma'
ADVANTAGE = 'advantage'
REINFORCE = 'reinforce'
ADVANTAGE_AC = "Advantage-AC"
TD_ADVANTAGE_AC = "TD-Advantage-AC"
CHUNK_EPISODE = 'chunk-episode'
POS_POINTER = 'pos_pointer'
POSITIONS = 'positions'
SAVE_AND_LOG = 'save_and_log'
NOOP = 'Noop'
USE_DOOR = 'use_door'
PILE_ALL_DONE = 'pile_all_done'
SINGLE = 'single'
DISTRIBUTED = 'distributed'
SHARED = 'shared'

View File

@@ -0,0 +1,57 @@
import torch
from marl_factory_grid.algorithms.rl.base_ac import BaseActorCritic, nms
from marl_factory_grid.algorithms.utils import instantiate_class
from pathlib import Path
from natsort import natsorted
from marl_factory_grid.algorithms.rl.memory import MARLActorCriticMemory
class LoopIAC(BaseActorCritic):
def __init__(self, cfg):
super(LoopIAC, self).__init__(cfg)
def setup(self):
self.net = [
instantiate_class(self.cfg[nms.AGENT]) for _ in range(self.n_agents)
]
self.optimizer = [
torch.optim.RMSprop(self.net[ag_i].parameters(), lr=3e-4, eps=1e-5) for ag_i in range(self.n_agents)
]
def load_state_dict(self, path: Path):
paths = natsorted(list(path.glob('*.pt')))
for path, net in zip(paths, self.net):
net.load_state_dict(torch.load(path))
@staticmethod
def merge_dicts(ds): # todo could be recursive for more than 1 hierarchy
d = {}
for k in ds[0].keys():
d[k] = [d_i[k] for d_i in ds]
return d
def init_hidden(self):
ha = [net.init_hidden_actor() for net in self.net]
hc = [net.init_hidden_critic() for net in self.net]
return dict(hidden_actor=ha, hidden_critic=hc)
def forward(self, observations, actions, hidden_actor, hidden_critic):
outputs = [
net(
self._as_torch(observations[ag_i]).unsqueeze(0).unsqueeze(0), # agent x time
self._as_torch(actions[ag_i]).unsqueeze(0),
hidden_actor[ag_i],
hidden_critic[ag_i]
) for ag_i, net in enumerate(self.net)
]
return self.merge_dicts(outputs)
def learn(self, tms: MARLActorCriticMemory, **kwargs):
for ag_i in range(self.n_agents):
tm, net = tms(ag_i), self.net[ag_i]
loss = self.actor_critic(tm, net, **self.cfg[nms.ALGORITHM], **kwargs)
self.optimizer[ag_i].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
self.optimizer[ag_i].step()
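A quick check (not part of the commit) of LoopIAC.merge_dicts above: the per-agent output dicts from the individual networks are merged into one dict of lists, which is the shape train_loop expects from forward().

from marl_factory_grid.algorithms.rl.iac import LoopIAC

outs = [dict(logits=1, critic=10), dict(logits=2, critic=20)]
print(LoopIAC.merge_dicts(outs))  # {'logits': [1, 2], 'critic': [10, 20]}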

View File

@@ -0,0 +1,66 @@
from marl_factory_grid.algorithms.rl.base_ac import Names as nms
from marl_factory_grid.algorithms.rl.snac import LoopSNAC
from marl_factory_grid.algorithms.rl.memory import MARLActorCriticMemory
import torch
from torch.distributions import Categorical
from marl_factory_grid.algorithms.utils import instantiate_class
class LoopMAPPO(LoopSNAC):
def __init__(self, *args, **kwargs):
super(LoopMAPPO, self).__init__(*args, **kwargs)
self.reset_memory_after_epoch = False
def setup(self):
self.net = instantiate_class(self.cfg[nms.AGENT])
self.optimizer = torch.optim.Adam(self.net.parameters(), lr=3e-4, eps=1e-5)
def learn(self, tm: MARLActorCriticMemory, **kwargs):
if len(tm) >= self.cfg['algorithm']['buffer_size']:
# only learn when buffer is full
for batch_i in range(self.cfg['algorithm']['n_updates']):
batch = tm.chunk_dataloader(chunk_len=self.cfg['algorithm']['n_steps'],
k=self.cfg['algorithm']['batch_size'])
loss = self.mappo(batch, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
self.optimizer.step()
def monte_carlo_returns(self, rewards, done, gamma):
rewards_ = []
discounted_reward = torch.zeros_like(rewards[:, -1])
for t in range(rewards.shape[1]-1, -1, -1):
discounted_reward = rewards[:, t] + (gamma * (1.0 - done[:, t]) * discounted_reward)
rewards_.insert(0, discounted_reward)
rewards_ = torch.stack(rewards_, dim=1)
return rewards_
def mappo(self, batch, network, gamma, entropy_coef, vf_coef, clip_range, **__):
out = network(batch[nms.OBSERVATION], batch[nms.ACTION], batch[nms.HIDDEN_ACTOR], batch[nms.HIDDEN_CRITIC])
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
old_log_probs = torch.log_softmax(batch[nms.LOGITS], -1)
old_log_probs = torch.gather(old_log_probs, index=batch[nms.ACTION][:, 1:].unsqueeze(-1), dim=-1).squeeze()
# monte carlo returns
mc_returns = self.monte_carlo_returns(batch[nms.REWARD], batch[nms.DONE], gamma)
mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-8) # todo: norm across agent ok?
advantages = mc_returns - out[nms.CRITIC][:, :-1]
# policy loss
log_ap = torch.log_softmax(logits, -1)
log_ap = torch.gather(log_ap, dim=-1, index=batch[nms.ACTION][:, 1:].unsqueeze(-1)).squeeze()
ratio = (log_ap - old_log_probs).exp()
surr1 = ratio * advantages.detach()
surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
policy_loss = -torch.min(surr1, surr2).mean(-1)
# entropy & value loss
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
value_loss = advantages.pow(2).mean(-1) # n_agent
# weighted loss
loss = policy_loss + vf_coef*value_loss - entropy_coef * entropy_loss
return loss.mean()
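A toy illustration (not part of the commit) of the clipped surrogate used in LoopMAPPO.mappo above: once the probability ratio leaves [1 - clip_range, 1 + clip_range], the clipped term caps the objective, so overly large policy updates are discouraged.

import torch

advantages = torch.tensor([1.0, 1.0, -1.0])
ratio = torch.tensor([0.5, 1.5, 1.5])            # pi_new(a|s) / pi_old(a|s)
clip_range = 0.2
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages
print(torch.min(surr1, surr2))                   # tensor([ 0.5000,  1.2000, -1.5000])
print(-torch.min(surr1, surr2).mean())           # tensor(-0.0667), the policy loss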

View File

@@ -0,0 +1,221 @@
import numpy as np
from collections import deque
import torch
from typing import Union
from torch import Tensor
from torch.utils.data import Dataset, ConcatDataset
import random
class ActorCriticMemory(object):
def __init__(self, capacity=10):
self.capacity = capacity
self.reset()
def reset(self):
self.__actions = LazyTensorFiFoQueue(maxlen=self.capacity+1)
self.__hidden_actor = LazyTensorFiFoQueue(maxlen=self.capacity+1)
self.__hidden_critic = LazyTensorFiFoQueue(maxlen=self.capacity+1)
self.__states = LazyTensorFiFoQueue(maxlen=self.capacity+1)
self.__rewards = LazyTensorFiFoQueue(maxlen=self.capacity+1)
self.__dones = LazyTensorFiFoQueue(maxlen=self.capacity+1)
self.__logits = LazyTensorFiFoQueue(maxlen=self.capacity+1)
self.__values = LazyTensorFiFoQueue(maxlen=self.capacity+1)
def __len__(self):
return len(self.__rewards) - 1
@property
def observation(self, sls=slice(0, None)): # add time dimension through stacking
return self.__states[sls].unsqueeze(0) # 1 x time x hidden dim
@property
def hidden_actor(self, sls=slice(0, None)): # 1 x n_layers x dim
return self.__hidden_actor[sls].unsqueeze(0) # 1 x time x n_layers x dim
@property
def hidden_critic(self, sls=slice(0, None)): # 1 x n_layers x dim
return self.__hidden_critic[sls].unsqueeze(0) # 1 x time x n_layers x dim
@property
def reward(self, sls=slice(0, None)):
return self.__rewards[sls].squeeze().unsqueeze(0) # 1 x time
@property
def action(self, sls=slice(0, None)):
return self.__actions[sls].long().squeeze().unsqueeze(0) # 1 x time
@property
def done(self, sls=slice(0, None)):
return self.__dones[sls].float().squeeze().unsqueeze(0) # 1 x time
@property
def logits(self, sls=slice(0, None)): # assumes a trailing 1 for time dimension - common when using output from NN
return self.__logits[sls].squeeze().unsqueeze(0) # 1 x time x actions
@property
def values(self, sls=slice(0, None)):
return self.__values[sls].squeeze().unsqueeze(0) # 1 x time x actions
def add_observation(self, state: Union[Tensor, np.ndarray]):
self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
def add_hidden_actor(self, hidden: Tensor):
# layers x hidden dim
self.__hidden_actor.append(hidden)
def add_hidden_critic(self, hidden: Tensor):
# layers x hidden dim
self.__hidden_critic.append(hidden)
def add_action(self, action: Union[int, Tensor]):
if not isinstance(action, Tensor):
action = torch.tensor(action)
self.__actions.append(action)
def add_reward(self, reward: Union[float, Tensor]):
if not isinstance(reward, Tensor):
reward = torch.tensor(reward)
self.__rewards.append(reward)
def add_done(self, done: bool):
if not isinstance(done, Tensor):
done = torch.tensor(done)
self.__dones.append(done)
def add_logits(self, logits: Tensor):
self.__logits.append(logits)
def add_values(self, values: Tensor):
self.__values.append(values)
def add(self, **kwargs):
for k, v in kwargs.items():
func = getattr(ActorCriticMemory, f'add_{k}')
func(self, v)
class MARLActorCriticMemory(object):
def __init__(self, n_agents, capacity):
self.n_agents = n_agents
self.memories = [
ActorCriticMemory(capacity) for _ in range(n_agents)
]
def __call__(self, agent_i):
return self.memories[agent_i]
def __len__(self):
return len(self.memories[0]) # todo add assertion check!
def reset(self):
for mem in self.memories:
mem.reset()
def add(self, **kwargs):
for agent_i in range(self.n_agents):
for k, v in kwargs.items():
func = getattr(ActorCriticMemory, f'add_{k}')
func(self.memories[agent_i], v[agent_i])
def __getattr__(self, attr):
all_attrs = [getattr(mem, attr) for mem in self.memories]
return torch.cat(all_attrs, 0) # agent x time ...
def chunk_dataloader(self, chunk_len, k):
datasets = [ExperienceChunks(mem, chunk_len, k) for mem in self.memories]
dataset = ConcatDataset(datasets)
data = [dataset[i] for i in range(len(dataset))]
data = custom_collate_fn(data)
return data
def custom_collate_fn(batch):
elem = batch[0]
return {key: torch.cat([d[key] for d in batch], dim=0) for key in elem}
class ExperienceChunks(Dataset):
def __init__(self, memory, chunk_len, k):
assert chunk_len <= len(memory), 'chunk_len cannot be longer than the size of the memory'
self.memory = memory
self.chunk_len = chunk_len
self.k = k
@property
def whitelist(self):
whitelist = torch.ones(len(self.memory) - self.chunk_len)
for d in self.memory.done.squeeze().nonzero().flatten():
whitelist[max((0, d-self.chunk_len-1)):d+2] = 0
whitelist[0] = 0
return whitelist.tolist()
def sample(self, start=1):
cl = self.chunk_len
sample = dict(observation=self.memory.observation[:, start:start+cl+1],
action=self.memory.action[:, start-1:start+cl],
hidden_actor=self.memory.hidden_actor[:, start-1],
hidden_critic=self.memory.hidden_critic[:, start-1],
reward=self.memory.reward[:, start:start + cl],
done=self.memory.done[:, start:start + cl],
logits=self.memory.logits[:, start:start + cl],
values=self.memory.values[:, start:start + cl])
return sample
def __len__(self):
return self.k
def __getitem__(self, i):
idx = random.choices(range(0, len(self.memory) - self.chunk_len), weights=self.whitelist, k=1)
return self.sample(idx[0])
class LazyTensorFiFoQueue:
def __init__(self, maxlen):
self.maxlen = maxlen
self.reset()
def reset(self):
self.__lazy_queue = deque(maxlen=self.maxlen)
self.shape = None
self.queue = None
def shape_init(self, tensor: Tensor):
self.shape = torch.Size([self.maxlen, *tensor.shape])
def build_tensor_queue(self):
if len(self.__lazy_queue) > 0:
block = torch.stack(list(self.__lazy_queue), dim=0)
l = block.shape[0]
if self.queue is None:
self.queue = block
elif self.true_len() <= self.maxlen:
self.queue = torch.cat((self.queue, block), dim=0)
else:
self.queue = torch.cat((self.queue[l:], block), dim=0)
self.__lazy_queue.clear()
def append(self, data):
if self.shape is None:
self.shape_init(data)
self.__lazy_queue.append(data)
if len(self.__lazy_queue) >= self.maxlen:
self.build_tensor_queue()
def true_len(self):
return len(self.__lazy_queue) + (0 if self.queue is None else self.queue.shape[0])
def __len__(self):
return min((self.true_len(), self.maxlen))
def __str__(self):
return f'LazyTensorFiFoQueue\tmaxlen: {self.maxlen}, shape: {self.shape}, ' \
f'len: {len(self)}, true_len: {self.true_len()}, elements in lazy queue: {len(self.__lazy_queue)}'
def __getitem__(self, item_or_slice):
self.build_tensor_queue()
return self.queue[item_or_slice]
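A small usage sketch (not part of the commit) for the LazyTensorFiFoQueue above: tensors are buffered lazily and only stacked into one contiguous tensor when the buffer fills or the queue is indexed.

import torch

q = LazyTensorFiFoQueue(maxlen=3)
for i in range(5):
    q.append(torch.tensor([float(i)]))
print(len(q))   # 3 -> only the last `maxlen` elements are kept
print(q[:])     # tensor([[2.], [3.], [4.]])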

View File

@@ -0,0 +1,103 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class RecurrentAC(nn.Module):
def __init__(self, observation_size, n_actions, obs_emb_size,
action_emb_size, hidden_size_actor, hidden_size_critic,
n_agents, use_agent_embedding=True):
super(RecurrentAC, self).__init__()
observation_size = np.prod(observation_size)
self.n_layers = 1
self.n_actions = n_actions
self.use_agent_embedding = use_agent_embedding
self.hidden_size_actor = hidden_size_actor
self.hidden_size_critic = hidden_size_critic
self.action_emb_size = action_emb_size
self.obs_proj = nn.Linear(observation_size, obs_emb_size)
self.action_emb = nn.Embedding(n_actions+1, action_emb_size, padding_idx=0)
self.agent_emb = nn.Embedding(n_agents, action_emb_size)
mix_in_size = obs_emb_size+action_emb_size if not use_agent_embedding else obs_emb_size+n_agents*action_emb_size
self.mix = nn.Sequential(nn.Tanh(),
nn.Linear(mix_in_size, obs_emb_size),
nn.Tanh(),
nn.Linear(obs_emb_size, obs_emb_size)
)
self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
self.action_head = nn.Sequential(
nn.Linear(hidden_size_actor, hidden_size_actor),
nn.Tanh(),
nn.Linear(hidden_size_actor, n_actions)
)
# spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
self.critic_head = nn.Sequential(
nn.Linear(hidden_size_critic, hidden_size_critic),
nn.Tanh(),
nn.Linear(hidden_size_critic, 1)
)
#self.action_head[-1].weight.data.uniform_(-3e-3, 3e-3)
#self.action_head[-1].bias.data.uniform_(-3e-3, 3e-3)
def init_hidden_actor(self):
return torch.zeros(1, self.n_layers, self.hidden_size_actor)
def init_hidden_critic(self):
return torch.zeros(1, self.n_layers, self.hidden_size_critic)
def forward(self, observations, actions, hidden_actor=None, hidden_critic=None):
n_agents, t, *_ = observations.shape
obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
action_emb = self.action_emb(actions+1) # shift by one due to padding idx
if not self.use_agent_embedding:
x_t = torch.cat((obs_emb, action_emb), -1)
else:
agent_emb = self.agent_emb(
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
)
x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)
mixed_x_t = self.mix(x_t)
output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
output_c, _ = self.gru_critic(input=mixed_x_t, hx=hidden_critic.swapaxes(1, 0))
logits = self.action_head(output_p)
critic = self.critic_head(output_c).squeeze(-1)
return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)
class RecurrentACL2(RecurrentAC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.action_head = nn.Sequential(
nn.Linear(self.hidden_size_actor, self.hidden_size_actor),
nn.Tanh(),
NormalizedLinear(self.hidden_size_actor, self.n_actions, trainable_magnitude=True)
)
class NormalizedLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int,
device=None, dtype=None, trainable_magnitude=False):
super(NormalizedLinear, self).__init__(in_features, out_features, False, device, dtype)
self.d_sqrt = in_features**0.5
self.trainable_magnitude = trainable_magnitude
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
def forward(self, in_array):
normalized_input = F.normalize(in_array, dim=-1, p=2, eps=1e-5)
normalized_weight = F.normalize(self.weight, dim=-1, p=2, eps=1e-5)
return F.linear(normalized_input, normalized_weight) * self.d_sqrt * self.scale
class L2Norm(nn.Module):
def __init__(self, in_features, trainable_magnitude=False):
super(L2Norm, self).__init__()
self.d_sqrt = in_features**0.5
self.scale = nn.Parameter(torch.tensor([1.]), requires_grad=trainable_magnitude)
def forward(self, x):
return F.normalize(x, dim=-1, p=2, eps=1e-5) * self.d_sqrt * self.scale
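A dummy forward pass (not part of the commit, shapes chosen arbitrarily) through the RecurrentAC module above, assuming the class is importable; it shows the expected tensor layout: observations and actions are agents x time, hidden states are agents x layers x hidden.

import torch

n_agents, t, obs_size, n_actions = 2, 5, 36, 10
net = RecurrentAC(observation_size=obs_size, n_actions=n_actions, obs_emb_size=96,
                  action_emb_size=16, hidden_size_actor=64, hidden_size_critic=64,
                  n_agents=n_agents, use_agent_embedding=False)
obs = torch.randn(n_agents, t, obs_size)
actions = torch.randint(0, n_actions, (n_agents, t))
h_a = torch.cat([net.init_hidden_actor()] * n_agents, 0)    # agents x layers x hidden
h_c = torch.cat([net.init_hidden_critic()] * n_agents, 0)
out = net(obs, actions, h_a, h_c)
print(out['logits'].shape, out['critic'].shape)             # (2, 5, 10) and (2, 5)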

View File

@@ -0,0 +1,55 @@
import torch
from torch.distributions import Categorical
from marl_factory_grid.algorithms.rl.iac import LoopIAC
from marl_factory_grid.algorithms.rl.base_ac import nms
from marl_factory_grid.algorithms.rl.memory import MARLActorCriticMemory
class LoopSEAC(LoopIAC):
def __init__(self, cfg):
super(LoopSEAC, self).__init__(cfg)
def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
outputs = [net(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]) for net in networks]
with torch.inference_mode(True):
true_action_logp = torch.stack([
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
.gather(index=actions[ag_i, 1:, None], dim=-1)
for ag_i, out in enumerate(outputs)
], 0).squeeze()
losses = []
for ag_i, out in enumerate(outputs):
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
critic = out[nms.CRITIC]
entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
# policy loss
log_ap = torch.log_softmax(logits, -1)
log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
# importance weights
iw = (log_ap - true_action_logp).exp().detach() # importance_weights
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
# weighted loss
loss = (a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss).mean()
losses.append(loss)
return losses
def learn(self, tms: MARLActorCriticMemory, **kwargs):
losses = self.actor_critic(tms, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
for ag_i, loss in enumerate(losses):
self.optimizer[ag_i].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
self.optimizer[ag_i].step()
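A tiny numeric illustration (not part of the commit) of the SEAC importance weight computed in actor_critic above: when the learner's log-probability matches the behaviour agent's, the weight is 1 and the shared-experience terms reduce to the plain actor-critic loss.

import torch

log_ap = torch.log(torch.tensor([0.5, 0.25]))           # learner's log-probs of the taken actions
true_action_logp = torch.log(torch.tensor([0.5, 0.5]))  # behaviour agent's log-probs
iw = (log_ap - true_action_logp).exp().detach()
print(iw)  # tensor([1.0000, 0.5000])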

View File

@@ -0,0 +1,33 @@
from marl_factory_grid.algorithms.rl.base_ac import BaseActorCritic
from marl_factory_grid.algorithms.rl.base_ac import nms
import torch
from torch.distributions import Categorical
from pathlib import Path
class LoopSNAC(BaseActorCritic):
def __init__(self, cfg):
super().__init__(cfg)
def load_state_dict(self, path: Path):
path2weights = list(path.glob('*.pt'))
assert len(path2weights) == 1, f'Expected a single set of weights but got {len(path2weights)}'
self.net.load_state_dict(torch.load(path2weights[0]))
def init_hidden(self):
hidden_actor = self.net.init_hidden_actor()
hidden_critic = self.net.init_hidden_critic()
return dict(hidden_actor=torch.cat([hidden_actor] * self.n_agents, 0),
hidden_critic=torch.cat([hidden_critic] * self.n_agents, 0)
)
def get_actions(self, out):
actions = Categorical(logits=out[nms.LOGITS]).sample().squeeze()
return actions
def forward(self, observations, actions, hidden_actor, hidden_critic):
out = self.net(self._as_torch(observations).unsqueeze(1),
self._as_torch(actions).unsqueeze(1),
hidden_actor, hidden_critic
)
return out

View File

@@ -0,0 +1,337 @@
import copy
from typing import List
import numpy as np
import torch
from marl_factory_grid.algorithms.rl.constants import Names as nms
from marl_factory_grid.algorithms.rl.base_a2c import cumulate_discount
def _as_torch(x):
""" Helper function to convert different list types to a torch tensor """
if isinstance(x, np.ndarray):
return torch.from_numpy(x)
elif isinstance(x, List):
return torch.tensor(x)
elif isinstance(x, (int, float)):
return torch.tensor([x])
return x
def transform_observations(env, ordered_coins, target_coin, cfg, n_agents):
""" Function that extracts local observations from global state
Requires that agents have observations -CoinPiles and -Self (cf. environment configs) """
agents_positions = get_agents_positions(env, n_agents)
coin_observability_is_all = cfg[nms.ALGORITHM][nms.PILE_OBSERVABILITY] == nms.ALL
if coin_observability_is_all:
trans_obs = [torch.zeros(2 + 2 * len(ordered_coins[0])) for _ in range(len(agents_positions))]
else:
# Only show current target pile
trans_obs = [torch.zeros(4) for _ in range(len(agents_positions))]
for i, pos in enumerate(agents_positions):
agent_x, agent_y = pos[0], pos[1]
trans_obs[i][0] = agent_x
trans_obs[i][1] = agent_y
idx = 2
if coin_observability_is_all:
for coin_pos in ordered_coins[i]:
trans_obs[i][idx] = coin_pos[0]
trans_obs[i][idx + 1] = coin_pos[1]
idx += 2
else:
trans_obs[i][2] = ordered_coins[i][target_coin[i]][0]
trans_obs[i][3] = ordered_coins[i][target_coin[i]][1]
return trans_obs
def get_all_observations(env, cfg, n_agents):
""" Helper function that returns all possible agent observations """
coins_positions = [env.state.entities[nms.COIN_PILES][pile_idx].pos for pile_idx in
range(len(env.state.entities[nms.COIN_PILES]))]
if cfg[nms.ALGORITHM][nms.PILE_OBSERVABILITY] == nms.ALL:
obs = [torch.zeros(2 + 2 * len(coins_positions))]
observations = [[]]
# Fill in pile positions
idx = 2
for pile_pos in coins_positions:
obs[0][idx] = pile_pos[0]
obs[0][idx + 1] = pile_pos[1]
idx += 2
else:
# Have multiple observation layers of the map, one for each coin pile
obs = [torch.zeros(4) for _ in range(n_agents) for _ in coins_positions]
observations = [[] for _ in coins_positions]
for idx, pile_pos in enumerate(coins_positions):
obs[idx][2] = pile_pos[0]
obs[idx][3] = pile_pos[1]
valid_agent_positions = env.state.entities.floorlist
for idx, pos in enumerate(valid_agent_positions):
for obs_layer in range(len(obs)):
observation = copy.deepcopy(obs[obs_layer])
observation[0] = pos[0]
observation[1] = pos[1]
observations[obs_layer].append(observation)
return observations
def get_coin_piles_positions(env):
""" Get positions of coin piles on the map """
return [env.state.entities[nms.COIN_PILES][pile_idx].pos for pile_idx in
range(len(env.state.entities[nms.COIN_PILES]))]
def get_agents_positions(env, n_agents):
""" Get positions of agents on the map """
return [env.state.moving_entites[agent_idx].pos for agent_idx in range(n_agents)]
def get_ordered_coin_piles(env, collected_coins, cfg, n_agents):
""" This function determines in which order the agents should collect the coin piles
Each agent can have its individual pile order """
ordered_coin_piles = [[] for _ in range(n_agents)]
coin_piles_positions = get_coin_piles_positions(env)
agents_positions = get_agents_positions(env, n_agents)
for agent_idx in range(n_agents):
if cfg[nms.ALGORITHM][nms.PILE_ORDER] in [nms.FIXED, nms.AGENTS]:
ordered_coin_piles[agent_idx] = coin_piles_positions
elif cfg[nms.ALGORITHM][nms.PILE_ORDER] in [nms.SMART, nms.DYNAMIC]:
# Calculate distances for remaining unvisited coin piles
remaining_target_piles = [pos for pos, value in collected_coins[agent_idx].items() if not value]
pile_distances = {pos: 0 for pos in remaining_target_piles}
agent_pos = agents_positions[agent_idx]
for pos in remaining_target_piles:
pile_distances[pos] = np.abs(agent_pos[0] - pos[0]) + np.abs(agent_pos[1] - pos[1])
if cfg[nms.ALGORITHM][nms.PILE_ORDER] == nms.SMART:
# Check if there is an agent on the direct path to any of the remaining coin piles
for pile_pos in remaining_target_piles:
for other_pos in agents_positions:
if other_pos != agent_pos:
if agent_pos[0] == other_pos[0] == pile_pos[0] or agent_pos[1] == other_pos[1] == pile_pos[
1]:
# Get the line between the agent and the target
path = bresenham(agent_pos[0], agent_pos[1], pile_pos[0], pile_pos[1])
# Check if the entity lies on the path between the agent and the target
if other_pos in path:
pile_distances[pile_pos] += np.abs(agent_pos[0] - other_pos[0]) + np.abs(
agent_pos[1] - other_pos[1])
sorted_pile_distances = dict(sorted(pile_distances.items(), key=lambda item: item[1]))
# Insert already visited coin piles
ordered_coin_piles[agent_idx] = [pos for pos in coin_piles_positions if pos not in remaining_target_piles]
# Fill up with sorted positions
for pos in sorted_pile_distances.keys():
ordered_coin_piles[agent_idx].append(pos)
else:
print("Not a valid pile order option.")
exit()
return ordered_coin_piles
def bresenham(x0, y0, x1, y1):
"""Bresenham's line algorithm to get the coordinates of a line between two points."""
dx = np.abs(x1 - x0)
dy = np.abs(y1 - y0)
sx = 1 if x0 < x1 else -1
sy = 1 if y0 < y1 else -1
err = dx - dy
coordinates = []
while True:
coordinates.append((x0, y0))
if x0 == x1 and y0 == y1:
break
e2 = 2 * err
if e2 > -dy:
err -= dy
x0 += sx
if e2 < dx:
err += dx
y0 += sy
return coordinates
def update_ordered_coin_piles(agent_idx, collected_coin_piles, ordered_coin_piles, env, cfg, n_agents):
""" Update the order of the remaining coin piles """
# Only update ordered_coin_piles for the agent that reached its target pile
updated_ordered_coin_piles = get_ordered_coin_piles(env, collected_coin_piles, cfg, n_agents)
for i in range(len(ordered_coin_piles[agent_idx])):
ordered_coin_piles[agent_idx][i] = updated_ordered_coin_piles[agent_idx][i]
def distribute_indices(env, cfg, n_agents):
""" Distribute coin piles evenly among the agents """
indices = []
n_coin_piles = len(get_coin_piles_positions(env))
agents_positions = get_agents_positions(env, n_agents)
if n_coin_piles == 1 or cfg[nms.ALGORITHM][nms.PILE_ORDER] in [nms.FIXED, nms.DYNAMIC, nms.SMART]:
indices = [[0] for _ in range(n_agents)]
else:
base_count = n_coin_piles // n_agents
remainder = n_coin_piles % n_agents
start_index = 0
for i in range(n_agents):
# Add an extra index to the first 'remainder' objects
end_index = start_index + base_count + (1 if i < remainder else 0)
indices.append(list(range(start_index, end_index)))
start_index = end_index
# Static form: auxiliary pile, primary pile, auxiliary pile, ...
# -> Starting with index 0 even piles are auxiliary piles, odd piles are primary piles
if cfg[nms.ALGORITHM][nms.AUXILIARY_PILES] and nms.DOORS in env.state.entities.keys():
door_positions = [door.pos for door in env.state.entities[nms.DOORS]]
distances = {door_pos: [] for door_pos in door_positions}
# Calculate distance of every agent to every door
for door_pos in door_positions:
for agent_pos in agents_positions:
distances[door_pos].append(np.abs(door_pos[0] - agent_pos[0]) + np.abs(door_pos[1] - agent_pos[1]))
def duplicate_indices(lst, item):
return [i for i, x in enumerate(lst) if x == item]
# Get agent indices of agents with same distance to door
affected_agents = {door_pos: {} for door_pos in door_positions}
for door_pos in distances.keys():
dist = distances[door_pos]
dist_set = set(dist)
for d in dist_set:
affected_agents[door_pos][str(d)] = duplicate_indices(dist, d)
updated_indices = []
for door_pos, agent_distances in affected_agents.items():
if len(agent_distances) == 0:
# Remove auxiliary piles for all agents
# (In config, we defined every pile with an even numbered index to be an auxiliary pile)
updated_indices = [[ele for ele in lst if ele % 2 != 0] for lst in indices]
else:
for distance, agent_indices in agent_distances.items():
# For each distance group, pick one random agent to keep the auxiliary pile
# selected_agent = np.random.choice(agent_indices)
selected_agent = 0
for agent_idx in agent_indices:
if agent_idx == selected_agent:
updated_indices.append(indices[agent_idx])
else:
updated_indices.append([ele for ele in indices[agent_idx] if ele % 2 != 0])
indices = updated_indices
return indices
def update_target_pile(env, agent_idx, target_pile, indices, cfg):
""" Get the next target pile for a given agent """
if cfg[nms.ALGORITHM][nms.PILE_ORDER] in [nms.FIXED, nms.DYNAMIC, nms.SMART]:
if target_pile[agent_idx] + 1 < len(get_coin_piles_positions(env)):
target_pile[agent_idx] += 1
else:
target_pile[agent_idx] = 0
else:
if target_pile[agent_idx] + 1 in indices[agent_idx]:
target_pile[agent_idx] += 1
def is_door_close(env, agent_idx):
""" Checks whether the agent is close to a door """
neighbourhood = [y for x in env.state.entities.neighboring_positions(env.state[nms.AGENT][agent_idx].pos)
for y in env.state.entities.pos_dict[x] if nms.DOOR in y.name]
if neighbourhood:
return neighbourhood[0]
def get_all_collected_coin_piles(coin_piles_positions, collected_coin_piles, n_agents):
""" Returns all coin piles collected by any agent """
meta_collected_coin_piles = {pos: False for pos in coin_piles_positions}
for agent_idx in range(n_agents):
for (pos, collected) in collected_coin_piles[agent_idx].items():
if collected:
meta_collected_coin_piles[pos] = True
return meta_collected_coin_piles
def handle_finished_episode(obs, agents, cfg):
""" Finish up episode, calculate advantages and perform policy net and value net updates"""
with torch.inference_mode(False):
for ag_i, agent in enumerate(agents):
# Get states, actions, rewards and values from rollout buffer
data = agent.finish_episode()
# Chunk the episode data so that very long episodes do not exhaust memory
chunks = split_into_chunks(data, cfg)
for (s, a, R, V) in chunks:
# Calculate discounted return and advantage
G = cumulate_discount(R, cfg[nms.ALGORITHM][nms.GAMMA])
if cfg[nms.ALGORITHM][nms.ADVANTAGE] == nms.REINFORCE:
A = G
elif cfg[nms.ALGORITHM][nms.ADVANTAGE] == nms.ADVANTAGE_AC:
A = G - V # Actor-Critic Advantages
elif cfg[nms.ALGORITHM][nms.ADVANTAGE] == nms.TD_ADVANTAGE_AC:
with torch.no_grad():
A = R + cfg[nms.ALGORITHM][nms.GAMMA] * np.append(V[1:], agent.vf(
_as_torch(obs[ag_i]).view(-1).to(
torch.float32)).numpy()) - V # TD Actor-Critic Advantages
else:
print("Not a valid advantage option.")
exit()
rollout = (torch.tensor(x.copy()).to(torch.float32) for x in (s, a, G, A))
# Update policy and value net of agent with experience from rollout buffer
agent.train(*rollout)
def split_into_chunks(data_tuple, cfg):
""" Chunks episode data into approximately equal sized chunks to prevent system memory failure from overload """
result = [data_tuple]
chunk_size = cfg[nms.ALGORITHM][nms.CHUNK_EPISODE]
if chunk_size > 0:
# Get the maximum length of the lists in the tuple to handle different lengths
max_length = max(len(lst) for lst in data_tuple)
# Prepare a list to store the result
result = []
# Split each list into chunks and add them to the result
for i in range(0, max_length, chunk_size):
# Create a sublist containing the ith chunk from each list
sublist = [lst[i:i + chunk_size] for lst in data_tuple if i < len(lst)]
result.append(sublist)
return result
def set_agents_spawnpoints(env, n_agents):
""" Tell environment where the agents should spawn in the next episode """
for agent_idx in range(n_agents):
agent_name = list(env.state.agents_conf.keys())[agent_idx]
current_pos_pointer = env.state.agents_conf[agent_name][nms.POS_POINTER]
# Making the reset dependent on the number of spawnpoints and not the number of coin piles allows
# for having multiple subsequent spawnpoints with the same target pile
if current_pos_pointer == len(env.state.agents_conf[agent_name][nms.POSITIONS]) - 1:
env.state.agents_conf[agent_name][nms.POS_POINTER] = 0
else:
env.state.agents_conf[agent_name][nms.POS_POINTER] += 1
def save_configs(results_path, cfg, factory_conf, eval_factory_conf):
""" Save configurations for logging purposes """
with open(f"{results_path}/MARL_config.txt", "w") as txt_file:
txt_file.write(str(cfg))
with open(f"{results_path}/train_env_config.txt", "w") as txt_file:
txt_file.write(str(factory_conf))
with open(f"{results_path}/eval_env_config.txt", "w") as txt_file:
txt_file.write(str(eval_factory_conf))
def save_agent_models(results_path, agents):
""" Save model parameters after training """
for idx, agent in enumerate(agents):
agent.pi.save_model_parameters(results_path)
agent.vf.save_model_parameters(results_path)
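A quick sanity check (not part of the commit) for the bresenham helper above: the returned coordinates are exactly the grid cells on the straight line, which is how get_ordered_coin_piles decides whether another agent blocks the direct path to a pile.

path = bresenham(0, 0, 3, 3)
print(path)             # [(0, 0), (1, 1), (2, 2), (3, 3)]
print((2, 2) in path)   # True -> an agent at (2, 2) would count as blocking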