mirror of
https://github.com/illiumst/marl-factory-grid.git
synced 2025-07-05 17:11:35 +02:00
Resolved some warnings and style issues
This commit is contained in:
@ -16,7 +16,7 @@ class LoopSEAC(LoopIAC):
|
||||
with torch.inference_mode(True):
|
||||
true_action_logp = torch.stack([
|
||||
torch.log_softmax(out[nms.LOGITS][ag_i, :-1], -1)
|
||||
.gather(index=actions[ag_i, 1:, None], dim=-1)
|
||||
.gather(index=actions[ag_i, 1:, None], dim=-1)
|
||||
for ag_i, out in enumerate(outputs)
|
||||
], 0).squeeze()
|
||||
|
||||
@ -38,7 +38,6 @@ class LoopSEAC(LoopIAC):
|
||||
|
||||
a2c_loss = (-iw*log_ap * advantages.detach()).mean(-1)
|
||||
|
||||
|
||||
value_loss = (iw*advantages.pow(2)).mean(-1) # n_agent
|
||||
|
||||
# weighted loss
|
||||
@ -53,4 +52,4 @@ class LoopSEAC(LoopIAC):
|
||||
self.optimizer[ag_i].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.net[ag_i].parameters(), 0.5)
|
||||
self.optimizer[ag_i].step()
|
||||
self.optimizer[ag_i].step()
|
||||
|
Reference in New Issue
Block a user