add more efficient (lazy) experience queue implementation based on tensor, adjusted marl algorithms

This commit is contained in:
Robert Müller
2022-02-03 13:14:48 +01:00
parent b09c461754
commit a9a4274370
8 changed files with 243 additions and 165 deletions

View File

@ -1,6 +1,7 @@
import torch
from torch.distributions import Categorical
from algorithms.marl.iac import LoopIAC
from algorithms.marl.base_ac import nms
from algorithms.marl.memory import MARLActorCriticMemory
@ -9,12 +10,12 @@ class LoopSEAC(LoopIAC):
super(LoopSEAC, self).__init__(cfg)
def actor_critic(self, tm, networks, gamma, entropy_coef, vf_coef, gae_coef=0.0, **kwargs):
obs, actions, done, reward = tm.observation, tm.action, tm.done, tm.reward
outputs = [net(obs, actions, tm.hidden_actor, tm.hidden_critic) for net in networks]
obs, actions, done, reward = tm.observation, tm.action, tm.done[:, 1:], tm.reward[:, 1:]
outputs = [net(obs, actions, tm.hidden_actor[:, 0], tm.hidden_critic[:, 0]) for net in networks]
with torch.inference_mode(True):
true_action_logp = torch.stack([
torch.log_softmax(out['logits'][ag_i, :-1], -1)
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
.gather(index=actions[ag_i, 1:, None], dim=-1)
for ag_i, out in enumerate(outputs)
], 0).squeeze()
@ -22,8 +23,8 @@ class LoopSEAC(LoopIAC):
losses = []
for ag_i, out in enumerate(outputs):
logits = out['logits'][:, :-1] # last one only needed for v_{t+1}
critic = out['critic']
logits = out[nms.LOGITS][:, :-1] # last one only needed for v_{t+1}
critic = out[nms.CRITIC]
entropy_loss = Categorical(logits=logits[ag_i]).entropy().mean()
advantages = self.compute_advantages(critic, reward, done, gamma, gae_coef)
@ -47,7 +48,7 @@ class LoopSEAC(LoopIAC):
return losses
def learn(self, tms: MARLActorCriticMemory, **kwargs):
losses = self.actor_critic(tms, self.net, **self.cfg['algorithm'], **kwargs)
losses = self.actor_critic(tms, self.net, **self.cfg[nms.ALGORITHM], **kwargs)
for ag_i, loss in enumerate(losses):
self.optimizer[ag_i].zero_grad()
loss.backward()