add Munchausen DQN refactoring

romue 2021-06-21 14:05:20 +02:00
parent bd63c603ee
commit 543c2987e0


@@ -1,22 +1,20 @@
-from typing import NamedTuple, Union, Iterable
-from collections import namedtuple, deque
+from typing import NamedTuple, Union
+from collections import deque
 import numpy as np
 import random
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from stable_baselines3.common.utils import polyak_update
-from stable_baselines3.common.buffers import ReplayBuffer
-import copy
 
 
 class Experience(NamedTuple):
+    # can be used for a single (s_t, a, r, s_{t+1}) tuple
+    # or for a batch of tuples
     observation: np.ndarray
     next_observation: np.ndarray
     action: np.ndarray
     reward: Union[float, np.ndarray]
     done: Union[bool, np.ndarray]
-    priority: np.ndarray = 1
 
 
 class BaseBuffer:
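The Experience tuple above is documented as holding either a single transition or a whole batch. A minimal usage sketch of both cases; the shapes and values below are made up for illustration and not taken from this diff:

    import numpy as np
    from typing import NamedTuple, Union

    class Experience(NamedTuple):
        # can be used for a single (s_t, a, r, s_{t+1}) tuple or for a batch of tuples
        observation: np.ndarray
        next_observation: np.ndarray
        action: np.ndarray
        reward: Union[float, np.ndarray]
        done: Union[bool, np.ndarray]

    # single transition
    single = Experience(observation=np.zeros((3, 5, 5)), next_observation=np.zeros((3, 5, 5)),
                        action=np.array(2), reward=-0.1, done=False)

    # batch: stack each field across a list of sampled transitions
    samples = [single] * 4
    batch = Experience(*map(np.stack, zip(*samples)))
    assert batch.observation.shape == (4, 3, 5, 5)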
@@ -65,9 +63,6 @@ class BaseDQN(nn.Module):
         values = self.value_head(features)
         return values + (advantages - advantages.mean())
 
-    def random_action(self):
-        return random.randrange(0, 5)
-
 
 def soft_update(local_model, target_model, tau):
     # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
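Only the signature of soft_update is visible in this hunk. Judging from the referenced M-DQN notebook it is the usual Polyak average of the target parameters; a minimal sketch, assuming both arguments are plain nn.Module networks:

    def soft_update(local_model, target_model, tau):
        # theta_target <- tau * theta_local + (1 - tau) * theta_target
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)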
@@ -165,6 +160,14 @@ class BaseQlearner:
         next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach()
         return current_q_values, next_q_values_raw
+
+    def _backprop_loss(self, loss):
+        # log loss
+        self.running_loss.append(loss.item())
+        # Optimize the model
+        self.optimizer.zero_grad()
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
+        self.optimizer.step()
 
     def train(self):
         if len(self.buffer) < self.batch_size: return
@@ -188,14 +191,7 @@ class BaseQlearner:
                 target_q_raw += next_q_values_raw
             target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
             loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
-
-            # log loss
-            self.running_loss.append(loss.item())
-            # Optimize the model
-            self.optimizer.zero_grad()
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
-            self.optimizer.step()
+            self._backprop_loss(loss)
 
 
 class MDQN(BaseQlearner):
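For reference, the base learner's update in the hunk above forms a one-step TD target and adds the reg_weight * pred_q penalty to the squared error. A self-contained numerical sketch with dummy tensors (gamma and reg_weight values taken from the __main__ config further down):

    import torch

    gamma, reg_weight = 0.99, 0.1
    pred_q = torch.tensor([[1.2], [0.3]])              # Q(s, a) of the taken actions
    next_q_values_raw = torch.tensor([[1.5], [0.9]])   # max_a' Q_target(s', a'), already detached
    reward = torch.tensor([[-0.1], [1.0]])
    done = torch.tensor([[0.], [1.]])                  # terminal transitions get no bootstrap term

    target_q = reward + (1 - done) * gamma * next_q_values_raw
    loss = torch.mean(reg_weight * pred_q + torch.pow(pred_q - target_q, 2))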
@@ -207,7 +203,8 @@ class MDQN(BaseQlearner):
         self.clip0 = clip_l0
 
     def tau_ln_pi(self, qs):
-        # Custom log-sum-exp trick from page 18 to compute the e log-policy terms
+        # computes log(softmax(qs/temperature))
+        # Custom log-sum-exp trick from page 18 to compute the log-policy terms
         v_k = qs.max(-1)[0].unsqueeze(-1)
         advantage = qs - v_k
         logsum = torch.logsumexp(advantage / self.temperature, -1).unsqueeze(-1)
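The tau_ln_pi body is cut off after the logsumexp line. Assuming it returns the temperature-scaled log-policy, as the method name suggests, the whole computation is tau * log_softmax(qs / tau) written in a numerically stable way. A standalone sketch (temperature passed as an argument instead of read from self), checked against F.log_softmax:

    import torch
    import torch.nn.functional as F

    def tau_ln_pi(qs, temperature):
        # subtract the max Q-value first so the exponentials cannot overflow
        v_k = qs.max(-1)[0].unsqueeze(-1)
        advantage = qs - v_k
        logsum = torch.logsumexp(advantage / temperature, -1).unsqueeze(-1)
        return advantage - temperature * logsum  # = temperature * log(softmax(qs / temperature))

    qs, tau = torch.randn(4, 5), 0.03
    assert torch.allclose(tau_ln_pi(qs, tau), tau * F.log_softmax(qs / tau, dim=-1), atol=1e-4)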
@@ -242,21 +239,11 @@ class MDQN(BaseQlearner):
             # Compute loss
             loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))
-
-            # log loss
-            self.running_loss.append(loss.item())
-            # Optimize the model
-            self.optimizer.zero_grad()
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
-            self.optimizer.step()
+            self._backprop_loss(loss)
 
 
 if __name__ == '__main__':
     from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties
-    from algorithms.reg_dqn import RegDQN
-    from stable_baselines3.common.vec_env import DummyVecEnv
 
     N_AGENTS = 1
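The m_q_target fed into the MDQN loss above is not visible in this hunk. In the Munchausen-DQN formulation it augments the reward with the clipped, scaled log-policy of the taken action and bootstraps on the soft (entropy-regularised) value of the next state. A hedged sketch of how such a target is typically assembled; the function name, argument names, and the alpha/temperature/clip defaults are illustrative and not taken from this diff:

    import torch

    def munchausen_target(reward, done, q_target_s, q_target_next, action,
                          gamma=0.99, alpha=0.9, temperature=0.03, clip_l0=-1.0):
        # action: LongTensor of shape (batch, 1) holding the taken actions
        def tau_ln_pi(qs):
            # same stable log-policy trick as in MDQN.tau_ln_pi above
            adv = qs - qs.max(-1, keepdim=True)[0]
            return adv - temperature * torch.logsumexp(adv / temperature, -1, keepdim=True)

        # Munchausen bonus: alpha * clip(tau * ln pi(a|s), l0, 0) for the action actually taken
        bonus = alpha * torch.clamp(tau_ln_pi(q_target_s).gather(-1, action), min=clip_l0, max=0)

        # soft next-state value: sum_a' pi(a'|s') * (Q_target(s', a') - tau * ln pi(a'|s'))
        ln_pi_next = tau_ln_pi(q_target_next)
        pi_next = torch.softmax(q_target_next / temperature, -1)
        soft_v_next = (pi_next * (q_target_next - ln_pi_next)).sum(-1, keepdim=True)

        return reward + bonus + (1 - done) * gamma * soft_v_next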
@@ -266,16 +253,8 @@ if __name__ == '__main__':
                                       allow_square_movement=True,
                                       allow_no_op=False)
     env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True)
-    #env = DummyVecEnv([lambda: env])
-
-    from stable_baselines3.dqn import DQN
-    #dqn = RegDQN('MlpPolicy', env, verbose=True, buffer_size = 40000, learning_starts = 0, batch_size = 64, learning_rate=0.0008,
-    #             target_update_interval = 3500, exploration_fraction = 0.25, exploration_final_eps = 0.05,
-    #             train_freq=4, gradient_steps=1, reg_weight=0.05, seed=69)
-    #dqn.learn(100000)
 
     dqn, target_dqn = BaseDQN(), BaseDQN()
-    learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0008, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
+    learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
                    train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64)
     learner.learn(100000)