import torch
from torch.distributions import Categorical
from algorithms.marl.iac import LoopIAC
from algorithms.marl.memory import MARLActorCriticMemory
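# LoopSEAC presumably implements a Shared-Experience-Actor-Critic (SEAC) style update on
# top of the independent actor-critic loop (LoopIAC): each agent's policy and value
# function are trained on the trajectories of all agents, corrected by the importance
# weight of the learning agent's policy over the behaviour agent's policy.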
class LoopSEAC(LoopIAC):
    def __init__(self, cfg):
        super(LoopSEAC, self).__init__(cfg)
    def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
        obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
        outputs = [net(obs, actions, tm.hidden_actor, tm.hidden_critic) for net in networks]
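        # Log-probabilities of the actions each agent actually took, evaluated under that
        # agent's own (behaviour) policy. Computed without gradient tracking: they are only
        # needed as the denominator of the importance weights below.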
        with torch.inference_mode(True):
            true_action_logp = torch.stack([
                torch.log_softmax(out['logits'][ag_i, :-1], -1)
                .gather(index=actions[ag_i, 1:, None], dim=-1)
                for ag_i, out in enumerate(outputs)
            ], 0).squeeze()
        losses = []

        for ag_i, out in enumerate(outputs):
            logits = out['logits'][:, :-1]  # last one only needed for v_{t+1}
            critic = out['critic']
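            # Entropy bonus of policy ag_i on its own trajectory; advantages are computed
            # from agent ag_i's critic over every agent's trajectory (compute_advantages is
            # provided by the base class; gae_coef presumably controls GAE smoothing).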
            entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
            advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
            # policy loss
            log_ap = torch.log_softmax(logits, -1)
            log_ap = torch.gather(log_ap, dim=-1, index=actions[:, 1:].unsqueeze(-1)).squeeze()
            # Importance weights: ratio of agent ag_i's policy to the policy of the agent
            # that generated each trajectory. Detached, so they only rescale the gradients
            # and are not differentiated through.
            iw = (log_ap - true_action_logp).exp().detach()
            a2c_loss = (-iw * log_ap * advantages.detach()).mean(-1)

            value_loss = (iw * advantages.pow(2)).mean(-1)  # shape: [n_agents]
            # weighted loss
            loss = (a2c_loss + vf_coef * value_loss - entropy_coef * entropy_loss).mean()
            losses.append(loss)

        return losses
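    # One loss per agent: each agent has its own network and optimizer (provided by the
    # LoopIAC base), and gradients are clipped to a max norm of 0.5 before every step.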
    def learn(self, tms: MARLActorCriticMemory, **kwargs):
        losses = self.actor_critic(tms, self.net, **self.cfg['algorithm'], **kwargs)
        for ag_i, loss in enumerate(losses):
            self.optimizer[ag_i].zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
            self.optimizer[ag_i].step()