from algorithms.marl import LoopSNAC
from algorithms.marl.memory import MARLActorCriticMemory
from typing import List
import random
import torch
from torch.distributions import Categorical
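

# LoopMAPPO: a MAPPO-style variant of the LoopSNAC training loop. Segments of
# recorded experience (MARLActorCriticMemory) are sampled into batches, Monte
# Carlo returns serve as critic targets, and the policy is updated with a PPO
# clipped-surrogate loss plus value and entropy terms.
# Keys expected under cfg['algorithm'] (as read below): batch_size,
# keep_n_segments, n_updates, gamma, entropy_coef, vf_coef, clip_range,
# and an optional gae_coef (accepted but not used in this file).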
class LoopMAPPO(LoopSNAC):
    def __init__(self, *args, **kwargs):
        super(LoopMAPPO, self).__init__(*args, **kwargs)

    def build_batch(self, tm: List[MARLActorCriticMemory]):
        # Assemble a training batch: draw batch_size - 1 segments at random and
        # always include the most recent one.
        sample = random.choices(tm, k=self.cfg['algorithm']['batch_size'] - 1)
        sample.append(tm[-1])  # always use latest segment in batch

        obs = torch.cat([s.observation for s in sample], 0)
        actions = torch.cat([s.action for s in sample], 0)
        hidden_actor = torch.cat([s.hidden_actor for s in sample], 0)
        hidden_critic = torch.cat([s.hidden_critic for s in sample], 0)
        logits = torch.cat([s.logits for s in sample], 0)
        values = torch.cat([s.values for s in sample], 0)
        reward = torch.cat([s.reward for s in sample], 0)
        done = torch.cat([s.done for s in sample], 0)

        # log-probabilities of the actions actually taken (old policy)
        log_probs = torch.log_softmax(logits, -1)
        log_probs = torch.gather(log_probs, index=actions[:, 1:].unsqueeze(-1), dim=-1).squeeze()

        return obs, actions, hidden_actor, hidden_critic, log_probs, values, reward, done

    def learn(self, tm: List[MARLActorCriticMemory], **kwargs):
        if len(tm) >= self.cfg['algorithm']['keep_n_segments']:
            # only learn when the segment buffer is full
            for batch_i in range(self.cfg['algorithm']['n_updates']):
                loss = self.actor_critic(tm, self.net, **self.cfg['algorithm'], **kwargs)
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 0.5)
                self.optimizer.step()

    def monte_carlo_returns(self, rewards, done, gamma):
        # Discounted Monte Carlo returns, computed backwards in time:
        #   R_t = r_t + gamma * (1 - done_t) * R_{t+1}
        rewards_ = []
        discounted_reward = torch.zeros_like(rewards[:, -1])
        for t in range(rewards.shape[1] - 1, -1, -1):
            discounted_reward = rewards[:, t] + gamma * (1.0 - done[:, t]) * discounted_reward
            rewards_.insert(0, discounted_reward)
        rewards_ = torch.stack(rewards_, dim=1)
        return rewards_

    def actor_critic(self, tm, network, gamma, entropy_coef, vf_coef, clip_range, gae_coef=0.0, **kwargs):
        obs, actions, hidden_actor, hidden_critic, old_log_probs, old_critic, reward, done = self.build_batch(tm)

        out = network(obs, actions, hidden_actor, hidden_critic)
        logits = out['logits'][:, :-1]  # last step is only needed for v_{t+1}
        critic = out['critic']

        # Monte Carlo returns as critic targets / advantage baseline
        mc_returns = self.monte_carlo_returns(reward, done, gamma)
        # mc_returns = (mc_returns - mc_returns.mean()) / (mc_returns.std() + 1e-7)  # todo: norm across agents?
        advantages = mc_returns - critic[:, :-1]

        # policy loss (PPO clipped surrogate objective)
        log_ap = torch.log_softmax(logits, -1)
        log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
        ratio = (log_ap - old_log_probs).exp()
        surr1 = ratio * advantages.detach()
        surr2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages.detach()
        policy_loss = -torch.min(surr1, surr2).mean(-1)

        # entropy bonus & value loss
        entropy_loss = Categorical(logits=logits).entropy().mean(-1)
        value_loss = advantages.pow(2).mean(-1)  # per agent

        # weighted total loss
        loss = policy_loss + vf_coef * value_loss - entropy_coef * entropy_loss
        return loss.mean()
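

# For reference, the loss assembled in actor_critic corresponds to the standard
# PPO objective (a sketch; eps is clip_range, c_v is vf_coef, c_e is entropy_coef):
#   ratio_t  = exp(log pi(a_t | s_t) - log pi_old(a_t | s_t))
#   L_policy = -E_t[ min(ratio_t * A_t, clip(ratio_t, 1 - eps, 1 + eps) * A_t) ]
#   L_value  =  E_t[ (R_t - V(s_t))^2 ]          with A_t = R_t - V(s_t)
#   loss     =  L_policy + c_v * L_value - c_e * H[pi]
# where R_t are the Monte Carlo returns and H[pi] is the policy entropy.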