Added first working MAPPO implementation

Robert Müller
2022-01-28 11:07:25 +01:00
parent ffc47752a7
commit b09c461754
11 changed files with 194 additions and 61 deletions

View File

@ -1,4 +1,6 @@
from algorithms.marl.base_ac import BaseActorCritic
from algorithms.marl.iac import LoopIAC
from algorithms.marl.snac import LoopSNAC
from algorithms.marl.seac import LoopSEAC
from algorithms.marl.mappo import LoopMAPPO
from algorithms.marl.memory import MARLActorCriticMemory

View File

@ -1,5 +1,6 @@
import torch
from typing import Union, List
import copy
import numpy as np
from torch.distributions import Categorical
from algorithms.marl.memory import MARLActorCriticMemory
@ -59,7 +60,7 @@ class BaseActorCritic:
actions: ListOrTensor,
hidden_actor: ListOrTensor,
hidden_critic: ListOrTensor
):
) -> dict[str, ListOrTensor]:
pass
@ -67,8 +68,9 @@ class BaseActorCritic:
def train_loop(self, checkpointer=None):
env = instantiate_class(self.cfg['env'])
n_steps, max_steps = [self.cfg['algorithm'][k] for k in ['n_steps', 'max_steps']]
global_steps = 0
global_steps, episode, df_results = 0, 0, []
reward_queue = deque(maxlen=2000)
memory_queue = deque(maxlen=self.cfg['algorithm'].get('keep_n_segments', 1))
while global_steps < max_steps:
tm = MARLActorCriticMemory(self.n_agents)
obs = env.reset()
@ -85,7 +87,8 @@ class BaseActorCritic:
next_obs = next_obs
if isinstance(done, bool): done = [done] * self.n_agents
tm.add(observation=obs, action=action, reward=reward, done=done)
tm.add(observation=obs, action=action, reward=reward, done=done,
logits=out.get('logits', None), values=out.get('critic', None))
obs = next_obs
last_action = action
last_hiddens = dict(hidden_actor=out.get('hidden_actor', None),
@ -94,9 +97,11 @@ class BaseActorCritic:
if len(tm) >= n_steps or all(done):
tm.add(observation=next_obs)
memory_queue.append(copy.deepcopy(tm))
if self.__training:
with torch.inference_mode(False):
self.learn(tm)
tm_ = tm if memory_queue.maxlen <= 1 else list(memory_queue)
self.learn(tm_)
tm.reset()
tm.add(action=last_action, **last_hiddens)
global_steps += 1
@ -110,7 +115,13 @@ class BaseActorCritic:
])
if global_steps >= max_steps: break
print(f'reward at step: {global_steps} = {rew_log}')
print(f'reward at episode: {episode} = {rew_log}')
episode += 1
df_results.append([global_steps, rew_log])
df_results = pd.DataFrame(df_results, columns=['steps', 'reward'])
if checkpointer is not None:
df_results.to_csv(checkpointer.path / 'results.csv', index=False)
return df_results
@torch.inference_mode(True)
def eval_loop(self, n_episodes, render=False):
@ -143,10 +154,21 @@ class BaseActorCritic:
return results
@staticmethod
def compute_advantages(critic, reward, done, gamma):
return (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
def compute_advantages(critic, reward, done, gamma, gae_coef=0.0):
tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, **kwargs):
if gae_coef <= 0:
return tds
gae = torch.zeros_like(tds[:, -1])
gaes = []
for t in range(tds.shape[1]-1, -1, -1):
gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
gaes.insert(0, gae)
gaes = torch.stack(gaes, dim=1)
return gaes
def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
out = network(obs, actions, tm.hidden_actor, tm.hidden_critic)
@ -154,7 +176,7 @@ class BaseActorCritic:
critic = out['critic']
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
advantages = self.compute_advantages(critic, reward, done, gamma)
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
value_loss = advantages.pow(2).mean(-1) # n_agent
# policy loss
@ -163,7 +185,6 @@ class BaseActorCritic:
a2c_loss = -(advantages.detach() * log_ap).mean(-1)
# weighted loss
loss = a2c_loss + vf_coef*value_loss - entropy_coef * entropy_loss
return loss.mean()
def learn(self, tm: MARLActorCriticMemory, **kwargs):
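Note on the gae_coef branch added to compute_advantages: it implements Generalized Advantage Estimation on top of the one-step TD errors, with gae_coef <= 0 falling back to plain TD advantages. Below is a minimal, self-contained sketch of the same recursion on toy tensors; the function name gae_advantages and the sizes are illustrative only, not part of the repository.

import torch

def gae_advantages(critic, reward, done, gamma=0.99, gae_coef=0.95):
    # critic carries one extra (bootstrap) timestep, as in compute_advantages above
    tds = (reward + gamma * (1.0 - done) * critic[:, 1:].detach()) - critic[:, :-1]
    gae = torch.zeros_like(tds[:, -1])
    gaes = []
    for t in range(tds.shape[1] - 1, -1, -1):
        gae = tds[:, t] + gamma * gae_coef * (1.0 - done[:, t]) * gae
        gaes.insert(0, gae)
    return torch.stack(gaes, dim=1)

critic = torch.randn(2, 6)   # 2 agents, 5 steps + 1 bootstrap value
reward = torch.randn(2, 5)
done = torch.zeros(2, 5)
print(gae_advantages(critic, reward, done).shape)  # torch.Size([2, 5])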

View File

@ -21,7 +21,6 @@ class LoopIAC(BaseActorCritic):
def load_state_dict(self, path: Path):
paths = natsorted(list(path.glob('*.pt')))
print(list(paths))
for path, net in zip(paths, self.net):
net.load_state_dict(torch.load(path))

algorithms/marl/mappo.py (new file, 78 lines added)
View File

@ -0,0 +1,78 @@
from algorithms.marl import LoopSNAC
from algorithms.marl.memory import MARLActorCriticMemory
from typing import List
import random
import torch
from torch.distributions import Categorical
class LoopMAPPO(LoopSNAC):
def __init__(self, *args, **kwargs):
super(LoopMAPPO, self).__init__(*args, **kwargs)
def build_batch(self, tm: List[MARLActorCriticMemory]):
sample = random.choices(tm, k=self.cfg['algorithm']['batch_size']-1)
sample.append(tm[-1]) # always use latest segment in batch
obs = torch.cat([s.observation for s in sample], 0)
actions = torch.cat([s.action for s in sample], 0)
hidden_actor = torch.cat([s.hidden_actor for s in sample], 0)
hidden_critic = torch.cat([s.hidden_critic for s in sample], 0)
logits = torch.cat([s.logits for s in sample], 0)
values = torch.cat([s.values for s in sample], 0)
reward = torch.cat([s.reward for s in sample], 0)
done = torch.cat([s.done for s in sample], 0)
log_props = torch.log_softmax(logits, -1)
log_props = torch.gather(log_props, index=actions[:, 1:].unsqueeze(-1), dim=-1).squeeze()
return obs, actions, hidden_actor, hidden_critic, log_props, values, reward, done
def learn(self, tm: List[MARLActorCriticMemory], **kwargs):
if len(tm) >= self.cfg['algorithm']['keep_n_segments']:
# only learn when buffer is full
for batch_i in range(self.cfg['algorithm']['n_updates']):
loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs)
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
self.optimizer.step()
def monte_carlo_returns(self, rewards, done, gamma):
rewards_ = []
discounted_reward = torch.zeros_like(rewards[:, -1])
for t in range(rewards.shape[1]-1, -1, -1):
discounted_reward = rewards[:, t] + (gamma * (1.0 - done[:, t]) * discounted_reward)
rewards_.insert(0, discounted_reward)
rewards_ = torch.stack(rewards_, dim=1)
return rewards_
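Quick sanity check of the backward recursion in monte_carlo_returns: with gamma = 0.5, rewards [1, 1, 1] and no terminal flags, the discounted returns come out as [1.75, 1.5, 1.0]. A tiny standalone version (names and numbers are illustrative only):

import torch

rewards = torch.tensor([[1.0, 1.0, 1.0]])     # 1 agent x 3 timesteps
done = torch.zeros_like(rewards)
gamma = 0.5

returns, discounted = [], torch.zeros_like(rewards[:, -1])
for t in range(rewards.shape[1] - 1, -1, -1):
    discounted = rewards[:, t] + gamma * (1.0 - done[:, t]) * discounted
    returns.insert(0, discounted)
print(torch.stack(returns, dim=1))            # tensor([[1.7500, 1.5000, 1.0000]])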
def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, clip_range, gae_coef=0.0, **kwargs):
obs, actions, hidden_actor, hidden_critic, old_log_probs, old_critic, reward, done = self.build_batch(tm)
out = network(obs, actions, hidden_actor, hidden_critic)
logits = out['logits'][:, :-1] # last one only needed for v_{t+1}
critic = out['critic']
# monte carlo returns
mc_returns = self.monte_carlo_returns(reward, done, gamma)
# monte_carlo_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-7) todo: norm across agents?
advantages = mc_returns - critic[:, :-1]
# policy loss
log_ap = torch.log_softmax(logits, -1)
log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
ratio = (log_ap - old_log_probs).exp()
surr1 = ratio * advantages.detach()
surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
policy_loss = -torch.min(surr1, surr2).mean(-1)
# entropy & value loss
entropy_loss = Categorical(logits=logits).entropy().mean(-1)
value_loss = advantages.pow(2).mean(-1) # n_agent
# weighted loss
loss = policy_loss + vf_coef*value_loss - entropy_coef * entropy_loss
return loss.mean()
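The policy term in actor_critic above is the standard PPO clipped surrogate: the probability ratio between the current policy and the stored rollout policy is clamped to [1 - clip_range, 1 + clip_range] and the pessimistic minimum of the two surrogates is taken. A minimal sketch on random tensors; the variable names here are illustrative, not part of the repository:

import torch

clip_range = 0.2
advantages = torch.randn(3, 8)                         # agents x timesteps
old_log_probs = torch.randn(3, 8)                      # as stored at rollout time
log_probs = old_log_probs + 0.05 * torch.randn(3, 8)   # current policy, slightly shifted

ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages.detach()
surr2 = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages.detach()
policy_loss = -torch.min(surr1, surr2).mean()
print(policy_loss)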

View File

@ -13,14 +13,16 @@ class ActorCriticMemory(object):
self.__actions = []
self.__rewards = []
self.__dones = []
self.__hiddens_actor = []
self.__hiddens_critic = []
self.__logits = []
self.__values = []
def __len__(self):
return len(self.__states)
@property
def observation(self):
def observation(self): # add time dimension through stacking
return torch.stack(self.__states, 0).unsqueeze(0) # 1 x timesteps x hidden dim
@property
@ -47,6 +49,14 @@ class ActorCriticMemory(object):
def done(self):
return torch.tensor(self.__dones).float().unsqueeze(0) # 1 x timesteps
@property
def logits(self): # assumes a trailing 1 for time dimension - common when using output from NN
return torch.cat(self.__logits, 0).unsqueeze(0) # 1 x timesteps x actions
@property
def values(self):
return torch.cat(self.__values, 0).unsqueeze(0) # 1 x timesteps (critic values)
def add_observation(self, state: Union[Tensor, np.ndarray]):
self.__states.append(state if isinstance(state, Tensor) else torch.from_numpy(state))
@ -69,6 +79,12 @@ class ActorCriticMemory(object):
def add_done(self, done: bool):
self.__dones.append(done)
def add_logits(self, logits: Tensor):
self.__logits.append(logits)
def add_values(self, values: Tensor):
self.__values.append(values)
def add(self, **kwargs):
for k, v in kwargs.items():
func = getattr(ActorCriticMemory, f'add_{k}')
@ -129,3 +145,14 @@ class MARLActorCriticMemory(object):
all_hc = [mem.hidden_critic for mem in self.memories]
return torch.cat(all_hc, 0) # agents x layers x timesteps x hidden dim
@property
def logits(self):
all_lgts = [mem.logits for mem in self.memories]
return torch.cat(all_lgts, 0) # agents x timesteps x actions
@property
def values(self):
all_vals = [mem.values for mem in self.memories]
return torch.cat(all_vals, 0) # agents x timesteps (critic values)
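The shape comments above follow from how the per-agent buffers are combined: each ActorCriticMemory stacks its entries with a leading batch dimension of 1, and MARLActorCriticMemory concatenates the agents along that dimension. A small illustrative check with made-up sizes:

import torch

n_agents, timesteps, n_actions = 3, 5, 4
per_agent_logits = [torch.randn(1, timesteps, n_actions) for _ in range(n_agents)]  # one buffer per agent
logits = torch.cat(per_agent_logits, 0)
print(logits.shape)  # torch.Size([3, 5, 4]) -> agents x timesteps x actions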

View File

@ -12,6 +12,7 @@ class RecurrentAC(nn.Module):
super(RecurrentAC, self).__init__()
observation_size = np.prod(observation_size)
self.n_layers = 1
self.n_actions = n_actions
self.use_agent_embedding = use_agent_embedding
self.hidden_size_actor = hidden_size_actor
self.hidden_size_critic = hidden_size_critic
@ -25,13 +26,14 @@ class RecurrentAC(nn.Module):
nn.Tanh(),
nn.Linear(obs_emb_size, obs_emb_size)
)
self.gru_actor = nn.GRU(obs_emb_size, hidden_size_actor, batch_first=True, num_layers=self.n_layers)
self.gru_critic = nn.GRU(obs_emb_size, hidden_size_critic, batch_first=True, num_layers=self.n_layers)
self.action_head = nn.Sequential(
spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
nn.Linear(hidden_size_actor, hidden_size_actor),
nn.Tanh(),
nn.Linear(hidden_size_actor, n_actions)
)
# spectral_norm(nn.Linear(hidden_size_actor, hidden_size_actor)),
self.critic_head = nn.Sequential(
nn.Linear(hidden_size_critic, hidden_size_critic),
nn.Tanh(),
@ -50,12 +52,14 @@ class RecurrentAC(nn.Module):
n_agents, t, *_ = observations.shape
obs_emb = self.obs_proj(observations.view(n_agents, t, -1).float())
action_emb = self.action_emb(actions+1) # shift by one due to padding idx
agent_emb = self.agent_emb(
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)]*t, 1)
)
x_t = torch.cat((obs_emb, action_emb), -1) \
if not self.use_agent_embedding else torch.cat((obs_emb, agent_emb, action_emb), -1)
if not self.use_agent_embedding:
x_t = torch.cat((obs_emb, action_emb), -1)
else:
agent_emb = self.agent_emb(
torch.cat([torch.arange(0, n_agents, 1).view(-1, 1)] * t, 1)
)
x_t = torch.cat((obs_emb, agent_emb, action_emb), -1)
mixed_x_t = self.mix(x_t)
output_p, _ = self.gru_actor(input=mixed_x_t, hx=hidden_actor.swapaxes(1, 0))
@ -66,6 +70,15 @@ class RecurrentAC(nn.Module):
return dict(logits=logits, critic=critic, hidden_actor=output_p, hidden_critic=output_c)
class RecurrentACL2(RecurrentAC):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.action_head = nn.Sequential(
nn.Linear(self.hidden_size_actor, self.hidden_size_actor),
nn.Tanh(),
NormalizedLinear(self.hidden_size_actor, self.n_actions, trainable_magnitude=True)
)
class NormalizedLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int,
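One detail in the recurrent forward pass worth spelling out: with batch_first=True an nn.GRU takes its input as (batch, timesteps, features) but still expects the hidden state as (num_layers, batch, hidden), which is what the hidden_actor.swapaxes(1, 0) call accounts for. A small shape sketch with made-up sizes:

import torch
import torch.nn as nn

n_agents, t, emb, hidden, n_layers = 3, 5, 16, 32, 1
gru = nn.GRU(emb, hidden, batch_first=True, num_layers=n_layers)
x = torch.randn(n_agents, t, emb)              # like mixed_x_t: agents act as the batch dim
h0 = torch.zeros(n_agents, n_layers, hidden)   # stored agents-first, as in the memory
out, hn = gru(x, h0.swapaxes(1, 0))            # hx must be (layers, agents, hidden)
print(out.shape, hn.shape)                     # torch.Size([3, 5, 32]) torch.Size([1, 3, 32])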

View File

@ -8,7 +8,7 @@ class LoopSEAC(LoopIAC):
def __init__(self, cfg):
super(LoopSEAC, self).__init__(cfg)
def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, **kwargs):
def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
outputs = [net(obs, actions, tm.hidden_actor, tm.hidden_critic) for net in networks]
@ -26,7 +26,7 @@ class LoopSEAC(LoopIAC):
critic = out['critic']
entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
advantages = self.compute_advantages(critic, reward, done, gamma)
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
# policy loss
log_ap = torch.log_softmax(logits, -1)