import torch

from algorithms.q_learner import QLearner


class QTRANLearner(QLearner):
    def __init__(self, *args, weight_opt=1., weight_nopt=1., **kwargs):
        super(QTRANLearner, self).__init__(*args, **kwargs)
        assert self.n_agents >= 2, 'QTRANLearner requires more than one agent, use QLearner instead'
        # Weights for the QTRAN opt / nopt loss terms (not yet wired into train()).
        self.weight_opt = weight_opt
        self.weight_nopt = weight_nopt

    def _training_routine(self, obs, next_obs, action):
        # TODO: remove - inherited from QLearner, only kept here while QTRAN is being implemented.
        current_q_values = self.q_net(obs)
        current_q_values = torch.gather(current_q_values, dim=-1, index=action)
        next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach()
        return current_q_values, next_q_values_raw

    def local_qs(self, observations, actions):
        # Sum of the agents' individual action-values (Q'_jt in QTRAN terms).
        # Note: the shared q_net is assumed to return (q_values, features) per agent here.
        individual_qs = []
        features = []
        for agent_i in range(self.n_agents):
            q_values_agent_i, features_agent_i = self.q_net(observations[:, agent_i])  # individual action-value network
            q_values_agent_i = torch.gather(q_values_agent_i, dim=-1, index=actions[:, agent_i].unsqueeze(-1))
            individual_qs.append(q_values_agent_i)
            features.append(features_agent_i)
        Q_jt = torch.stack(individual_qs, 0).sum(0)    # sum of the gathered per-agent q-values
        feature_sum = torch.stack(features, 0).sum(0)  # (n_agents x hdim) -> hdim; meant for the joint networks, unused so far
        return Q_jt
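
    # Hedged sketch (not part of the original file): the feature sum computed in local_qs is
    # presumably meant as input to separate joint heads. Something along these lines would
    # produce the Q_jt / V_jt placeholders used below in train(); 'joint_q_head' and
    # 'joint_v_head' do not exist yet and are named purely for illustration.
    # def joint_values(self, feature_sum):
    #     Q_jt = self.joint_q_head(feature_sum)  # joint action-value, conditioned on the summed features
    #     V_jt = self.joint_v_head(feature_sum)  # joint state-value
    #     return Q_jt, V_jt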

    def train(self):
        if len(self.buffer) < self.batch_size:
            return
        for _ in range(self.n_grad_steps):
            experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps)
            Q_jt_prime = self.local_qs(experience.observation, experience.action)  # sum of individual q-values
            Q_jt = None  # TODO: joint action-value network, still to be implemented
            V_jt = None  # TODO: joint state-value network, still to be implemented
            pred_q, target_q_raw = torch.zeros((self.batch_size, 1)), torch.zeros((self.batch_size, 1))
            for agent_i in range(self.n_agents):
                q_values, next_q_values_raw = self._training_routine(experience.observation[:, agent_i],
                                                                     experience.next_observation[:, agent_i],
                                                                     experience.action[:, agent_i].unsqueeze(-1))
                pred_q += q_values
                target_q_raw += next_q_values_raw
            target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
            # Same TD objective as the base QLearner; the QTRAN opt/nopt terms are not wired in yet.
            loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
            self._backprop_loss(loss)
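

# ---------------------------------------------------------------------------
# Hedged sketch, not part of the original class: one way the QTRAN opt / nopt
# regularizers weighted by weight_opt / weight_nopt could be assembled once
# joint Q_jt and V_jt networks exist (cf. Son et al., "QTRAN", ICML 2019).
# All argument names below are assumptions for illustration only.
# ---------------------------------------------------------------------------
def qtran_regularizers(q_prime_taken, q_jt_taken, q_prime_greedy, q_jt_greedy, v_jt):
    # L_opt: for the greedy joint action, Q'_jt - Q_jt + V_jt = 0 is enforced exactly.
    l_opt = torch.mean(torch.pow(q_prime_greedy - q_jt_greedy.detach() + v_jt, 2))
    # L_nopt: for the sampled joint action, only Q'_jt - Q_jt + V_jt >= 0 is enforced
    # (one-sided penalty via clamping at zero).
    l_nopt = torch.mean(torch.pow(torch.clamp(q_prime_taken - q_jt_taken.detach() + v_jt, max=0.0), 2))
    return l_opt, l_nopt


# Inside train(), the total loss would then roughly become
#     loss = l_td + self.weight_opt * l_opt + self.weight_nopt * l_nopt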