Resolved some warnings and style issues

This commit is contained in:
Steffen Illium
2023-11-10 09:29:54 +01:00
parent a9462a8b6f
commit 6711a0976b
64 changed files with 331 additions and 361 deletions

View File

@ -16,7 +16,7 @@ class LoopSEAC(LoopIAC):
with torch.inference_mode(True):
true_action_logp = torch.stack([
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
.gather(index=actions[ag_i, 1:, None], dim=-1)
.gather(index=actions[ag_i, 1:, None], dim=-1)
for ag_i, out in enumerate(outputs)
], 0).squeeze()
@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC):
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
# weighted loss
@ -53,4 +52,4 @@ class LoopSEAC(LoopIAC):
self.optimizer[ag_i].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
self.optimizer[ag_i].step()
self.optimizer[ag_i].step()