From 543c2987e024e8253d2958c2865cdf151d2969e9 Mon Sep 17 00:00:00 2001
From: romue
Date: Mon, 21 Jun 2021 14:05:20 +0200
Subject: [PATCH] add Munchausen DQN refactoring

---
 algorithms/_base.py | 55 ++++++++++++++-------------------------------
 1 file changed, 17 insertions(+), 38 deletions(-)

diff --git a/algorithms/_base.py b/algorithms/_base.py
index b9d7d83..520c690 100644
--- a/algorithms/_base.py
+++ b/algorithms/_base.py
@@ -1,22 +1,20 @@
-from typing import NamedTuple, Union, Iterable
-from collections import namedtuple, deque
+from typing import NamedTuple, Union
+from collections import deque
 import numpy as np
 import random
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from stable_baselines3.common.utils import polyak_update
-from stable_baselines3.common.buffers import ReplayBuffer
-import copy
 
 
 class Experience(NamedTuple):
+    # can be used for a single (s_t, a, r, s_{t+1}) tuple
+    # or for a batch of such tuples
     observation: np.ndarray
     next_observation: np.ndarray
     action: np.ndarray
     reward: Union[float, np.ndarray]
     done: Union[bool, np.ndarray]
-    priority: np.ndarray = 1
 
 
 class BaseBuffer:
@@ -65,9 +63,6 @@ class BaseDQN(nn.Module):
         values = self.value_head(features)
         return values + (advantages - advantages.mean())
 
-    def random_action(self):
-        return random.randrange(0, 5)
-
 
 def soft_update(local_model, target_model, tau):
     # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
@@ -165,6 +160,14 @@ class BaseQlearner:
         next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach()
         return current_q_values, next_q_values_raw
 
+    def _backprop_loss(self, loss):
+        # log loss
+        self.running_loss.append(loss.item())
+        # Optimize the model
+        self.optimizer.zero_grad()
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
+        self.optimizer.step()
 
     def train(self):
         if len(self.buffer) < self.batch_size: return
@@ -188,14 +191,7 @@ class BaseQlearner:
             target_q_raw += next_q_values_raw
             target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
             loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
-
-            # log loss
-            self.running_loss.append(loss.item())
-            # Optimize the model
-            self.optimizer.zero_grad()
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
-            self.optimizer.step()
+            self._backprop_loss(loss)
 
 
 class MDQN(BaseQlearner):
@@ -207,7 +203,8 @@ class MDQN(BaseQlearner):
         self.clip0 = clip_l0
 
     def tau_ln_pi(self, qs):
-        # Custom log-sum-exp trick from page 18 to compute the e log-policy terms
+        # computes log(softmax(qs/temperature))
+        # Custom log-sum-exp trick from page 18 to compute the log-policy terms
         v_k = qs.max(-1)[0].unsqueeze(-1)
         advantage = qs - v_k
         logsum = torch.logsumexp(advantage / self.temperature, -1).unsqueeze(-1)
@@ -242,21 +239,11 @@ class MDQN(BaseQlearner):
 
             # Compute loss
             loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))
-
-            # log loss
-            self.running_loss.append(loss.item())
-            # Optimize the model
-            self.optimizer.zero_grad()
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
-            self.optimizer.step()
+            self._backprop_loss(loss)
 
 
 if __name__ == '__main__':
     from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties
-    from algorithms.reg_dqn import RegDQN
-    from stable_baselines3.common.vec_env import DummyVecEnv
-
 
     N_AGENTS = 1
 
@@ -266,16 +253,8 @@ if __name__ == '__main__':
                                         allow_square_movement=True, allow_no_op=False)
     env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True)
 
 
-    #env = DummyVecEnv([lambda: env])
-    from stable_baselines3.dqn import DQN
-
-    #dqn = RegDQN('MlpPolicy', env, verbose=True, buffer_size = 40000, learning_starts = 0, batch_size = 64,learning_rate=0.0008,
-    #             target_update_interval = 3500, exploration_fraction = 0.25, exploration_final_eps = 0.05,
-    #             train_freq=4, gradient_steps=1, reg_weight=0.05, seed=69)
-    #dqn.learn(100000)
-
     dqn, target_dqn = BaseDQN(), BaseDQN()
-    learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0008, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
+    learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
                    train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64)
     learner.learn(100000)
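
Note (not part of the patch): the hunk above shows only the first lines of tau_ln_pi. Assuming it returns the temperature-scaled log-policy, i.e. advantage - temperature * logsum as in the linked BY571 M-DQN notebook, the "custom log-sum-exp trick" is just a numerically stable way to compute tau * log_softmax(qs / tau). A minimal standalone sketch (tau_log_policy and the sample tensors are illustrative, not repository code):

    import torch

    def tau_log_policy(qs: torch.Tensor, temperature: float) -> torch.Tensor:
        # Stable tau * log softmax(qs / tau): subtract the per-row max (v_k)
        # before exponentiating, mirroring MDQN.tau_ln_pi above.
        v_k = qs.max(-1, keepdim=True)[0]
        advantage = qs - v_k
        logsum = torch.logsumexp(advantage / temperature, -1, keepdim=True)
        return advantage - temperature * logsum

    # sanity check against PyTorch's built-in log_softmax
    qs = torch.randn(4, 5)   # batch of 4 states, 5 actions
    tau = 0.1
    reference = tau * torch.log_softmax(qs / tau, dim=-1)
    assert torch.allclose(tau_log_policy(qs, tau), reference, atol=1e-5)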