From 42f0dde05643937dd766fd068cb075b48c59a5d3 Mon Sep 17 00:00:00 2001 From: romue Date: Wed, 23 Jun 2021 10:56:18 +0200 Subject: [PATCH] refactored algorithms --- algorithms/common.py | 102 ++++++++++++++ algorithms/dqn.py | 277 -------------------------------------- algorithms/m_q_learner.py | 53 ++++++++ algorithms/q_learner.py | 144 ++++++++++++++++++++ algorithms/vdn_learner.py | 23 ++++ 5 files changed, 322 insertions(+), 277 deletions(-) create mode 100644 algorithms/common.py delete mode 100644 algorithms/dqn.py create mode 100644 algorithms/m_q_learner.py create mode 100644 algorithms/q_learner.py create mode 100644 algorithms/vdn_learner.py diff --git a/algorithms/common.py b/algorithms/common.py new file mode 100644 index 0000000..2c2f678 --- /dev/null +++ b/algorithms/common.py @@ -0,0 +1,102 @@ +from typing import NamedTuple, Union +from collections import deque, OrderedDict +import numpy as np +import random +import torch +import torch.nn as nn + + +class BaseLearner: + def __init__(self, env, n_agents, lr): + self.env = env + self.n_agents = n_agents + self.lr = lr + self.device = 'cpu' + + def to(self, device): + self.device = device + for attr, value in self.__dict__.items(): + if isinstance(value, nn.Module): + value = value.to(self.device) + return self + + +class Experience(NamedTuple): + # can be use for a single (s_t, a, r s_{t+1}) tuple + # or for a batch of tuples + observation: np.ndarray + next_observation: np.ndarray + action: np.ndarray + reward: Union[float, np.ndarray] + done : Union[bool, np.ndarray] + + +class BaseBuffer: + def __init__(self, size: int): + self.size = size + self.experience = deque(maxlen=size) + + def __len__(self): + return len(self.experience) + + def add(self, experience): + self.experience.append(experience) + + def sample(self, k, cer=4): + sample = random.choices(self.experience, k=k-cer) + for i in range(cer): sample += [self.experience[-i]] + observations = torch.stack([torch.from_numpy(e.observation) for e in sample], 0).float() + next_observations = torch.stack([torch.from_numpy(e.next_observation) for e in sample], 0).float() + actions = torch.tensor([e.action for e in sample]).long() + rewards = torch.tensor([e.reward for e in sample]).float().view(-1, 1) + dones = torch.tensor([e.done for e in sample]).float().view(-1, 1) + return Experience(observations, next_observations, actions, rewards, dones) + + +def soft_update(local_model, target_model, tau): + # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb + for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): + target_param.data.copy_(tau*local_param.data + (1.-tau)*target_param.data) + + +def mlp_maker(dims, flatten=False, activation='elu', activation_last='identity'): + activations = {'elu': nn.ELU, 'relu': nn.ReLU, + 'leaky_relu': nn.LeakyReLU, 'tanh': nn.Tanh, + 'gelu': nn.GELU, 'identity': nn.Identity} + layers = [('Flatten', nn.Flatten())] if flatten else [] + for i in range(1, len(dims)): + layers.append((f'Layer #{i - 1}: Linear', nn.Linear(dims[i - 1], dims[i]))) + activation_str = activation if i != len(dims)-1 else activation_last + layers.append((f'Layer #{i - 1}: {activation_str.capitalize()}', activations[activation_str]())) + return nn.Sequential(OrderedDict(layers)) + + +class BaseDQN(nn.Module): + def __init__(self, dims=[3*5*5, 64, 64, 9]): + super(BaseDQN, self).__init__() + self.net = mlp_maker(dims, flatten=True) + + @torch.no_grad() + def act(self, x) -> np.ndarray: + action = self.forward(x).max(-1)[1].numpy() + return action + + def forward(self, x): + return self.net(x) + + +class BaseDDQN(BaseDQN): + def __init__(self, + backbone_dims=[3*5*5, 64, 64], + value_dims=[64, 1], + advantage_dims=[64, 9]): + super(BaseDDQN, self).__init__(backbone_dims) + self.net = mlp_maker(backbone_dims, flatten=True) + self.value_head = mlp_maker(value_dims) + self.advantage_head = mlp_maker(advantage_dims) + + def forward(self, x): + features = self.net(x) + advantages = self.advantage_head(features) + values = self.value_head(features) + return values + (advantages - advantages.mean()) diff --git a/algorithms/dqn.py b/algorithms/dqn.py deleted file mode 100644 index 7ce2d20..0000000 --- a/algorithms/dqn.py +++ /dev/null @@ -1,277 +0,0 @@ -from typing import NamedTuple, Union -from collections import deque, OrderedDict -import numpy as np -import random -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Experience(NamedTuple): - # can be use for a single (s_t, a, r s_{t+1}) tuple - # or for a batch of tuples - observation: np.ndarray - next_observation: np.ndarray - action: np.ndarray - reward: Union[float, np.ndarray] - done : Union[bool, np.ndarray] - - -class BaseBuffer: - def __init__(self, size: int): - self.size = size - self.experience = deque(maxlen=size) - - def __len__(self): - return len(self.experience) - - def add(self, experience): - self.experience.append(experience) - - def sample(self, k, cer=4): - sample = random.choices(self.experience, k=k-cer) - for i in range(cer): sample += [self.experience[-i]] - observations = torch.stack([torch.from_numpy(e.observation) for e in sample], 0).float() - next_observations = torch.stack([torch.from_numpy(e.next_observation) for e in sample], 0).float() - actions = torch.tensor([e.action for e in sample]).long() - rewards = torch.tensor([e.reward for e in sample]).float().view(-1, 1) - dones = torch.tensor([e.done for e in sample]).float().view(-1, 1) - return Experience(observations, next_observations, actions, rewards, dones) - - - - -def soft_update(local_model, target_model, tau): - # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb - for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): - target_param.data.copy_(tau*local_param.data + (1.-tau)*target_param.data) - - -def mlp_maker(dims, flatten=False, activation='elu', activation_last='identity'): - activations = {'elu': nn.ELU, 'relu': nn.ReLU, - 'leaky_relu': nn.LeakyReLU, 'tanh': nn.Tanh, - 'gelu': nn.GELU, 'identity': nn.Identity} - layers = [('Flatten', nn.Flatten())] if flatten else [] - for i in range(1, len(dims)): - layers.append((f'Layer #{i - 1}: Linear', nn.Linear(dims[i - 1], dims[i]))) - activation_str = activation if i != len(dims)-1 else activation_last - layers.append((f'Layer #{i - 1}: {activation_str.capitalize()}', activations[activation_str]())) - return nn.Sequential(OrderedDict(layers)) - - -class BaseDQN(nn.Module): - def __init__(self, dims=[3*5*5, 64, 64, 9]): - super(BaseDQN, self).__init__() - self.net = mlp_maker(dims, flatten=True) - - def act(self, x) -> np.ndarray: - with torch.no_grad(): - action = self.forward(x).max(-1)[1].numpy() - return action - - def forward(self, x): - return self.net(x) - - -class BaseDDQN(BaseDQN): - def __init__(self, - backbone_dims=[3*5*5, 64, 64], - value_dims=[64,1], - advantage_dims=[64,9]): - super(BaseDDQN, self).__init__(backbone_dims) - self.net = mlp_maker(backbone_dims, flatten=True) - self.value_head = mlp_maker(value_dims) - self.advantage_head = mlp_maker(advantage_dims) - - def forward(self, x): - features = self.net(x) - advantages = self.advantage_head(features) - values = self.value_head(features) - return values + (advantages - advantages.mean()) - - -class BaseQlearner: - def __init__(self, q_net, target_q_net, env, buffer_size, target_update, eps_end, n_agents=1, - gamma=0.99, train_every_n_steps=4, n_grad_steps=1, tau=1.0, max_grad_norm=10, - exploration_fraction=0.2, batch_size=64, lr=1e-4, reg_weight=0.0): - self.q_net = q_net - self.target_q_net = target_q_net - self.target_q_net.eval() - soft_update(self.q_net, self.target_q_net, tau=1.0) - self.env = env - self.buffer = BaseBuffer(buffer_size) - self.target_update = target_update - self.eps = 1. - self.eps_end = eps_end - self.exploration_fraction = exploration_fraction - self.batch_size = batch_size - self.gamma = gamma - self.train_every_n_steps = train_every_n_steps - self.n_grad_steps = n_grad_steps - self.lr = lr - self.tau = tau - self.reg_weight = reg_weight - self.n_agents = n_agents - self.device = 'cpu' - self.optimizer = torch.optim.AdamW(self.q_net.parameters(), lr=self.lr) - self.max_grad_norm = max_grad_norm - self.running_reward = deque(maxlen=5) - self.running_loss = deque(maxlen=5) - self._n_updates = 0 - - def to(self, device): - self.device = device - return self - - @staticmethod - def weights_init(module, activation='leaky_relu'): - if isinstance(module, (nn.Linear, nn.Conv2d)): - nn.init.orthogonal_(module.weight, gain=torch.nn.init.calculate_gain(activation)) - if module.bias is not None: - module.bias.data.fill_(0.0) - - def anneal_eps(self, step, n_steps): - fraction = min(float(step) / int(self.exploration_fraction*n_steps), 1.0) - self.eps = 1 + fraction * (self.eps_end - 1) - - def get_action(self, obs) -> Union[int, np.ndarray]: - o = torch.from_numpy(obs).unsqueeze(0) if self.n_agents <= 1 else torch.from_numpy(obs) - if np.random.rand() > self.eps: - action = self.q_net.act(o.float()) - else: - action = np.array([self.env.action_space.sample() for _ in range(self.n_agents)]) - return action - - def learn(self, n_steps): - step = 0 - while step < n_steps: - obs, done = self.env.reset(), False - total_reward = 0 - while not done: - - action = self.get_action(obs) - - next_obs, reward, done, info = self.env.step(action if not len(action) == 1 else action[0]) - - experience = Experience(observation=obs, next_observation=next_obs, action=action, reward=reward, done=done) # do we really need to copy? - self.buffer.add(experience) - # end of step routine - obs = next_obs - step += 1 - total_reward += reward - self.anneal_eps(step, n_steps) - - if step % self.train_every_n_steps == 0: - self.train() - self._n_updates += 1 - if step % self.target_update == 0: - print('UPDATE') - soft_update(self.q_net, self.target_q_net, tau=self.tau) - - self.running_reward.append(total_reward) - if step % 10 == 0: - print(f'Step: {step} ({(step/n_steps)*100:.2f}%)\tRunning reward: {sum(list(self.running_reward))/len(self.running_reward):.2f}\t' - f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss))/len(self.running_loss):.4f}\tUpdates:{self._n_updates}') - - def _training_routine(self, obs, next_obs, action): - current_q_values = self.q_net(obs) - current_q_values = torch.gather(current_q_values, dim=-1, index=action) - next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach() - return current_q_values, next_q_values_raw - - def _backprop_loss(self, loss): - # log loss - self.running_loss.append(loss.item()) - # Optimize the model - self.optimizer.zero_grad() - loss.backward() - torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm) - self.optimizer.step() - - def train(self): - if len(self.buffer) < self.batch_size: return - for _ in range(self.n_grad_steps): - - experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps) - if self.n_agents <= 1: - pred_q, target_q_raw = self._training_routine(experience.observation, - experience.next_observation, - experience.action) - - else: - pred_q, target_q_raw = torch.zeros((self.batch_size, 1)), torch.zeros((self.batch_size, 1)) - for agent_i in range(self.n_agents): - q_values, next_q_values_raw = self._training_routine(experience.observation[:, agent_i], - experience.next_observation[:, agent_i], - experience.action[:, agent_i].unsqueeze(-1)) - pred_q += q_values - target_q_raw += next_q_values_raw - target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw - loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2)) - self._backprop_loss(loss) - - -class MunchhausenQLearner(BaseQlearner): - def __init__(self, *args, temperature=0.03, alpha=0.9, clip_l0=-1.0, **kwargs): - super(MunchhausenQLearner, self).__init__(*args, **kwargs) - assert self.n_agents == 1, 'M-DQN currently only supports single agent training' - self.temperature = temperature - self.alpha = alpha - self.clip0 = clip_l0 - - def tau_ln_pi(self, qs): - # computes log(softmax(qs/temperature)) - # Custom log-sum-exp trick from page 18 to compute the log-policy terms - v_k = qs.max(-1)[0].unsqueeze(-1) - advantage = qs - v_k - logsum = torch.logsumexp(advantage / self.temperature, -1).unsqueeze(-1) - tau_ln_pi = advantage - self.temperature * logsum - return tau_ln_pi - - def train(self): - if len(self.buffer) < self.batch_size: return - for _ in range(self.n_grad_steps): - - experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps) - - q_target_next = self.target_q_net(experience.next_observation).detach() - tau_log_pi_next = self.tau_ln_pi(q_target_next) - - q_k_targets = self.target_q_net(experience.observation).detach() - log_pi = self.tau_ln_pi(q_k_targets) - - pi_target = F.softmax(q_target_next / self.temperature, dim=-1) - q_target = (self.gamma * (pi_target * (q_target_next - tau_log_pi_next) * (1 - experience.done)).sum(-1)).unsqueeze(-1) - - munchausen_addon = log_pi.gather(-1, experience.action) - - munchausen_reward = (experience.reward + self.alpha * torch.clamp(munchausen_addon, min=self.clip0, max=0)) - - # Compute Q targets for current states - m_q_target = munchausen_reward + q_target - - # Get expected Q values from local model - q_k = self.q_net(experience.observation) - pred_q = q_k.gather(-1, experience.action) - - # Compute loss - loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2)) - self._backprop_loss(loss) - - -if __name__ == '__main__': - from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties - - N_AGENTS = 1 - - dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30, - max_local_amount=5, spawn_frequency=1, max_spawn_ratio=0.05) - move_props = MovementProperties(allow_diagonal_movement=True, - allow_square_movement=True, - allow_no_op=False) - env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True) - - dqn, target_dqn = BaseDDQN(), BaseDDQN() - learner = MunchhausenQLearner(dqn, target_dqn, env, 40000, target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10, - train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64) - learner.learn(100000) diff --git a/algorithms/m_q_learner.py b/algorithms/m_q_learner.py new file mode 100644 index 0000000..402c68c --- /dev/null +++ b/algorithms/m_q_learner.py @@ -0,0 +1,53 @@ +import torch +import torch.nn.functional as F +from algorithms.q_learner import QLearner + + +class MQLearner(QLearner): + # Munchhausen Q-Learning + def __init__(self, *args, temperature=0.03, alpha=0.9, clip_l0=-1.0, **kwargs): + super(MQLearner, self).__init__(*args, **kwargs) + assert self.n_agents == 1, 'M-DQN currently only supports single agent training' + self.temperature = temperature + self.alpha = alpha + self.clip0 = clip_l0 + + def tau_ln_pi(self, qs): + # computes log(softmax(qs/temperature)) + # Custom log-sum-exp trick from page 18 to compute the log-policy terms + v_k = qs.max(-1)[0].unsqueeze(-1) + advantage = qs - v_k + logsum = torch.logsumexp(advantage / self.temperature, -1).unsqueeze(-1) + tau_ln_pi = advantage - self.temperature * logsum + return tau_ln_pi + + def train(self): + if len(self.buffer) < self.batch_size: return + for _ in range(self.n_grad_steps): + + experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps) + + with torch.no_grad(): + q_target_next = self.target_q_net(experience.next_observation) + tau_log_pi_next = self.tau_ln_pi(q_target_next) + + q_k_targets = self.target_q_net(experience.observation) + log_pi = self.tau_ln_pi(q_k_targets) + + pi_target = F.softmax(q_target_next / self.temperature, dim=-1) + q_target = (self.gamma * (pi_target * (q_target_next - tau_log_pi_next) * (1 - experience.done)).sum(-1)).unsqueeze(-1) + + munchausen_addon = log_pi.gather(-1, experience.action) + + munchausen_reward = (experience.reward + self.alpha * torch.clamp(munchausen_addon, min=self.clip0, max=0)) + + # Compute Q targets for current states + m_q_target = munchausen_reward + q_target + + # Get expected Q values from local model + q_k = self.q_net(experience.observation) + pred_q = q_k.gather(-1, experience.action) + + # Compute loss + loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2)) + self._backprop_loss(loss) \ No newline at end of file diff --git a/algorithms/q_learner.py b/algorithms/q_learner.py new file mode 100644 index 0000000..d6ee864 --- /dev/null +++ b/algorithms/q_learner.py @@ -0,0 +1,144 @@ +from typing import Union +import gym +import torch +import torch.nn as nn +import numpy as np +from collections import deque +from pathlib import Path +import yaml +from algorithms.common import BaseLearner, BaseBuffer, soft_update, Experience + + +class QLearner(BaseLearner): + def __init__(self, q_net, target_q_net, env, buffer_size=1e5, target_update=3000, eps_end=0.05, n_agents=1, + gamma=0.99, train_every_n_steps=4, n_grad_steps=1, tau=1.0, max_grad_norm=10, weight_decay=1e-2, + exploration_fraction=0.2, batch_size=64, lr=1e-4, reg_weight=0.0, eps_start=1): + super(QLearner, self).__init__(env, n_agents, lr) + self.q_net = q_net + self.target_q_net = target_q_net + self.target_q_net.eval() + soft_update(self.q_net, self.target_q_net, tau=1.0) + self.buffer = BaseBuffer(buffer_size) + self.target_update = target_update + self.eps = eps_start + self.eps_start = eps_start + self.eps_end = eps_end + self.exploration_fraction = exploration_fraction + self.batch_size = batch_size + self.gamma = gamma + self.train_every_n_steps = train_every_n_steps + self.n_grad_steps = n_grad_steps + self.tau = tau + self.reg_weight = reg_weight + self.weight_decay = weight_decay + self.optimizer = torch.optim.AdamW(self.q_net.parameters(), + lr=self.lr, + weight_decay=self.weight_decay) + self.max_grad_norm = max_grad_norm + self.running_reward = deque(maxlen=5) + self.running_loss = deque(maxlen=5) + self.n_updates = 0 + + def save(self, path): + path = Path(path) # no-op if already instance of Path + path.mkdir(parents=True, exist_ok=True) + hparams = {k: v for k, v in self.__dict__.items() if not(isinstance(v, BaseBuffer) or + isinstance(v, torch.optim.Optimizer) or + isinstance(v, gym.Env) or + isinstance(v, nn.Module)) + } + hparams.update({'class': self.__class__.__name__}) + with (path / 'hparams.yaml').open('w') as outfile: + yaml.dump(hparams, outfile) + torch.save(self.q_net, path / 'q_net.pt') + + def anneal_eps(self, step, n_steps): + fraction = min(float(step) / int(self.exploration_fraction*n_steps), 1.0) + self.eps = 1 + fraction * (self.eps_end - 1) + + def get_action(self, obs) -> Union[int, np.ndarray]: + o = torch.from_numpy(obs).unsqueeze(0) if self.n_agents <= 1 else torch.from_numpy(obs) + if np.random.rand() > self.eps: + action = self.q_net.act(o.float()) + else: + action = np.array([self.env.action_space.sample() for _ in range(self.n_agents)]) + return action + + def learn(self, n_steps): + step = 0 + while step < n_steps: + obs, done = self.env.reset(), False + total_reward = 0 + while not done: + + action = self.get_action(obs) + + next_obs, reward, done, info = self.env.step(action if not len(action) == 1 else action[0]) + + experience = Experience(observation=obs, next_observation=next_obs, action=action, reward=reward, done=done) # do we really need to copy? + self.buffer.add(experience) + # end of step routine + obs = next_obs + step += 1 + total_reward += reward + self.anneal_eps(step, n_steps) + + if step % self.train_every_n_steps == 0: + self.train() + self.n_updates += 1 + if step % self.target_update == 0: + print('UPDATE') + soft_update(self.q_net, self.target_q_net, tau=self.tau) + + self.running_reward.append(total_reward) + if step % 10 == 0: + print(f'Step: {step} ({(step/n_steps)*100:.2f}%)\tRunning reward: {sum(list(self.running_reward))/len(self.running_reward):.2f}\t' + f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss))/len(self.running_loss):.4f}\tUpdates:{self.n_updates}') + + def _training_routine(self, obs, next_obs, action): + current_q_values = self.q_net(obs) + current_q_values = torch.gather(current_q_values, dim=-1, index=action) + next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach() + return current_q_values, next_q_values_raw + + def _backprop_loss(self, loss): + # log loss + self.running_loss.append(loss.item()) + # Optimize the model + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm) + self.optimizer.step() + + def train(self): + if len(self.buffer) < self.batch_size: return + for _ in range(self.n_grad_steps): + experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps) + pred_q, target_q_raw = self._training_routine(experience.observation, + experience.next_observation, + experience.action) + target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw + loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2)) + self._backprop_loss(loss) + + + +if __name__ == '__main__': + from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties + from algorithms.common import BaseDDQN + from algorithms.vdn_learner import VDNLearner + + N_AGENTS = 1 + + dirt_props = DirtProperties(clean_amount=3, gain_amount=0.2, max_global_amount=30, + max_local_amount=5, spawn_frequency=1, max_spawn_ratio=0.05) + move_props = MovementProperties(allow_diagonal_movement=True, + allow_square_movement=True, + allow_no_op=False) + env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True) + + dqn, target_dqn = BaseDDQN(), BaseDDQN() + learner = QLearner(dqn, target_dqn, env, 40000, target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10, + train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64) + #learner.save(Path(__file__).parent / 'test' / 'testexperiment1337') + learner.learn(100000) diff --git a/algorithms/vdn_learner.py b/algorithms/vdn_learner.py new file mode 100644 index 0000000..f50c6ab --- /dev/null +++ b/algorithms/vdn_learner.py @@ -0,0 +1,23 @@ +import torch +from algorithms.q_learner import QLearner + + +class VDNLearner(QLearner): + def __init__(self, *args, **kwargs): + super(VDNLearner, self).__init__(*args, **kwargs) + assert self.n_agents >= 2, 'VDN requires more than one agent, use QLearner instead' + + def train(self): + if len(self.buffer) < self.batch_size: return + for _ in range(self.n_grad_steps): + experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps) + pred_q, target_q_raw = torch.zeros((self.batch_size, 1)), torch.zeros((self.batch_size, 1)) + for agent_i in range(self.n_agents): + q_values, next_q_values_raw = self._training_routine(experience.observation[:, agent_i], + experience.next_observation[:, agent_i], + experience.action[:, agent_i].unsqueeze(-1)) + pred_q += q_values + target_q_raw += next_q_values_raw + target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw + loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2)) + self._backprop_loss(loss) \ No newline at end of file