Mirror of https://github.com/illiumst/marl-factory-grid.git (synced 2025-06-18 18:52:52 +02:00)
refactored algorithms
algorithms/vdn_learner.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import torch

from algorithms.q_learner import QLearner


class VDNLearner(QLearner):
    def __init__(self, *args, **kwargs):
        super(VDNLearner, self).__init__(*args, **kwargs)
        assert self.n_agents >= 2, 'VDN requires more than one agent, use QLearner instead'

    def train(self):
        if len(self.buffer) < self.batch_size:
            return
        for _ in range(self.n_grad_steps):
            experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps)
            # Value decomposition: the joint Q-value is the sum of the per-agent Q-values.
            pred_q, target_q_raw = torch.zeros((self.batch_size, 1)), torch.zeros((self.batch_size, 1))
            for agent_i in range(self.n_agents):
                q_values, next_q_values_raw = self._training_routine(experience.observation[:, agent_i],
                                                                     experience.next_observation[:, agent_i],
                                                                     experience.action[:, agent_i].unsqueeze(-1))
                pred_q += q_values
                target_q_raw += next_q_values_raw
            # Single TD target on the summed next-state Q-values, masked for terminal transitions.
            target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
            # Squared TD error plus a regularisation term on the predicted joint Q.
            loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
            self._backprop_loss(loss)
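For context, a minimal self-contained sketch of the value-decomposition idea that train() implements: each agent keeps its own Q-network, the chosen-action Q-values are summed into a joint Q, and a single TD loss is computed on that sum. All names, shapes, and networks below are illustrative assumptions, not taken from this repository; only torch is assumed.

# Sketch of the VDN decomposition; everything here is hypothetical toy setup.
import torch
import torch.nn as nn

n_agents, obs_dim, n_actions, batch = 2, 4, 3, 8
gamma = 0.99

# One independent Q-network per agent (toy linear networks for illustration).
q_nets = [nn.Linear(obs_dim, n_actions) for _ in range(n_agents)]

obs = torch.randn(batch, n_agents, obs_dim)
next_obs = torch.randn(batch, n_agents, obs_dim)
actions = torch.randint(0, n_actions, (batch, n_agents))
reward = torch.randn(batch, 1)
done = torch.zeros(batch, 1)

pred_q = torch.zeros(batch, 1)
target_q_raw = torch.zeros(batch, 1)
for i, net in enumerate(q_nets):
    q_all = net(obs[:, i])                                   # (batch, n_actions)
    pred_q += q_all.gather(-1, actions[:, i].unsqueeze(-1))  # chosen-action Q, summed over agents
    target_q_raw += net(next_obs[:, i]).max(-1, keepdim=True).values.detach()

# Joint TD target on the summed Q-values: the core VDN idea.
target_q = reward + (1 - done) * gamma * target_q_raw
loss = torch.pow(pred_q - target_q, 2).mean()
loss.backward()

Because the gradient of the summed joint Q distributes over the per-agent terms, each agent's network is trained from the single team-level TD error, which is what the agent loop above achieves without any extra mixing network.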