add Munchausen DQN refactoring

This commit is contained in:
romue 2021-06-21 14:05:20 +02:00
parent bd63c603ee
commit 543c2987e0


@@ -1,22 +1,20 @@
from typing import NamedTuple, Union, Iterable
from collections import namedtuple, deque
from typing import NamedTuple, Union
from collections import deque
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.common.buffers import ReplayBuffer
import copy
class Experience(NamedTuple):
# can be used for a single (s_t, a, r, s_{t+1}) tuple
# or for a batch of tuples
observation: np.ndarray
next_observation: np.ndarray
action: np.ndarray
reward: Union[float, np.ndarray]
done: Union[bool, np.ndarray]
priority: np.ndarray = 1
class BaseBuffer:
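The hunk above trims the imports to what Experience and BaseBuffer actually need (deque stays, namedtuple and Iterable go). The body of BaseBuffer is collapsed in this view; below is a minimal deque-backed sketch of a buffer compatible with the Experience layout shown here. The class and method names (SketchBuffer, add, sample) are assumptions for illustration, not this repository's code.

```python
# Minimal sketch of a deque-backed replay buffer compatible with the
# Experience NamedTuple above. `SketchBuffer`, `add` and `sample` are
# illustrative names; the actual BaseBuffer body is not shown in this hunk.
import random
from collections import deque

import numpy as np


class SketchBuffer:
    def __init__(self, size: int):
        self.experiences = deque(maxlen=size)

    def __len__(self):
        return len(self.experiences)

    def add(self, experience: Experience):
        self.experiences.append(experience)

    def sample(self, batch_size: int) -> Experience:
        batch = random.sample(self.experiences, batch_size)
        # stack the per-step tuples into batched arrays
        return Experience(
            observation=np.stack([e.observation for e in batch]),
            next_observation=np.stack([e.next_observation for e in batch]),
            action=np.stack([e.action for e in batch]),
            reward=np.array([e.reward for e in batch]),
            done=np.array([e.done for e in batch]),
        )
```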
@@ -65,9 +63,6 @@ class BaseDQN(nn.Module):
values = self.value_head(features)
return values + (advantages - advantages.mean())
def random_action(self):
return random.randrange(0, 5)
def soft_update(local_model, target_model, tau):
# taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
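soft_update keeps the target network trailing the online network; the polyak_update import added above provides the same operation via stable_baselines3. A minimal sketch of the update rule assumed here, theta_target <- tau * theta_local + (1 - tau) * theta_target, following the convention of the referenced Munchausen-RL notebook:

```python
# Minimal sketch of the Polyak/soft update this helper performs, assuming the
# convention theta_target <- tau * theta_local + (1 - tau) * theta_target.
import torch
import torch.nn as nn


def soft_update_sketch(local_model: nn.Module, target_model: nn.Module, tau: float):
    with torch.no_grad():
        for local_param, target_param in zip(local_model.parameters(),
                                             target_model.parameters()):
            target_param.data.copy_(tau * local_param.data
                                    + (1.0 - tau) * target_param.data)
```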
@@ -165,6 +160,14 @@ class BaseQlearner:
next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach()
return current_q_values, next_q_values_raw
def _backprop_loss(self, loss):
# log loss
self.running_loss.append(loss.item())
# Optimize the model
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
self.optimizer.step()
def train(self):
if len(self.buffer) < self.batch_size: return
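The new _backprop_loss helper centralizes the loss logging and the clipped optimizer step that BaseQlearner.train and MDQN.train previously duplicated. A standalone sketch of that pattern; the toy model, optimizer and max_grad_norm below are illustrative stand-ins, not objects from this repository:

```python
# Standalone sketch of the clip-then-step pattern that _backprop_loss factors
# out of both train() methods (toy model and optimizer for illustration only).
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
max_grad_norm = 10

loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
optimizer.step()
```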
@@ -188,14 +191,7 @@ class BaseQlearner:
target_q_raw += next_q_values_raw
target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
# log loss
self.running_loss.append(loss.item())
# Optimize the model
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
self.optimizer.step()
self._backprop_loss(loss)
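For reference, the loss in BaseQlearner.train above pairs the squared TD error against a bootstrapped max-Q target with an additional reg_weight * Q(s, a) penalty. A tiny numeric sketch with random tensors standing in for the network outputs; shapes and values are assumptions:

```python
# Tiny numeric sketch of the regularized TD loss computed above.
import torch

gamma, reg_weight = 0.99, 0.1
pred_q = torch.randn(64, 1)             # Q(s, a) for the taken actions
next_q_values_raw = torch.randn(64, 1)  # max_a' Q_target(s', a')
reward = torch.randn(64, 1)
done = torch.randint(0, 2, (64, 1)).float()

target_q = reward + (1 - done) * gamma * next_q_values_raw
loss = torch.mean(reg_weight * pred_q + torch.pow(pred_q - target_q, 2))
```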
class MDQN(BaseQlearner):
@@ -207,7 +203,8 @@ class MDQN(BaseQlearner):
self.clip0 = clip_l0
def tau_ln_pi(self, qs):
# Custom log-sum-exp trick from page 18 to compute the e log-policy terms
# computes log(softmax(qs/temperature))
# Custom log-sum-exp trick from page 18 to compute the log-policy terms
v_k = qs.max(-1)[0].unsqueeze(-1)
advantage = qs - v_k
logsum = torch.logsumexp(advantage / self.temperature, -1).unsqueeze(-1)
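The shifted log-sum-exp above is a numerically stable way to get the temperature-scaled log-policy; it should agree with temperature * log_softmax(qs / temperature). A short check of that equivalence, assuming tau_ln_pi ends with `return advantage - self.temperature * logsum` (the return line sits outside this hunk):

```python
# Check that the shifted log-sum-exp reproduces tau * log(softmax(qs / tau)).
import torch
import torch.nn.functional as F

temperature = 0.03
qs = torch.randn(64, 5)

v_k = qs.max(-1)[0].unsqueeze(-1)
advantage = qs - v_k
logsum = torch.logsumexp(advantage / temperature, -1).unsqueeze(-1)
tau_ln_pi = advantage - temperature * logsum          # assumed return value

reference = temperature * F.log_softmax(qs / temperature, dim=-1)
assert torch.allclose(tau_ln_pi, reference, atol=1e-4)
```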
@@ -242,21 +239,11 @@ class MDQN(BaseQlearner):
# Compute loss
loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))
# log loss
self.running_loss.append(loss.item())
# Optimize the model
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
self.optimizer.step()
self._backprop_loss(loss)
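For context, a hedged sketch of the Munchausen target that m_q_target above is assumed to assemble, following the M-DQN formulation: the reward is augmented with the clipped, temperature-scaled log-policy of the taken action, and the bootstrap term is the discounted soft value of the next state. The alpha coefficient and all tensor values below are illustrative; the lines building m_q_target sit outside this hunk.

```python
# Hedged sketch of the Munchausen target (values and alpha are illustrative).
import torch

gamma, temperature, alpha, clip_l0 = 0.99, 0.03, 0.9, -1.0
batch, n_actions = 64, 5

q_target_s = torch.randn(batch, n_actions)        # Q_target(s, .)
q_target_s_next = torch.randn(batch, n_actions)   # Q_target(s', .)
action = torch.randint(0, n_actions, (batch, 1))
reward = torch.randn(batch, 1)
done = torch.randint(0, 2, (batch, 1)).float()


def tau_ln_pi(qs):
    # temperature-scaled log-policy, tau * log(softmax(qs / temperature))
    return temperature * torch.log_softmax(qs / temperature, dim=-1)


# reward augmented with the clipped log-policy of the taken action
log_pi_a = tau_ln_pi(q_target_s).gather(-1, action)
munchausen_reward = reward + alpha * torch.clamp(log_pi_a, min=clip_l0, max=0)

# soft expected value of the next state under the target policy
pi_next = torch.softmax(q_target_s_next / temperature, dim=-1)
soft_value_next = (pi_next * (q_target_s_next - tau_ln_pi(q_target_s_next))).sum(-1, keepdim=True)

m_q_target = munchausen_reward + gamma * (1 - done) * soft_value_next
```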
if __name__ == '__main__':
from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties
from algorithms.reg_dqn import RegDQN
from stable_baselines3.common.vec_env import DummyVecEnv
N_AGENTS = 1
@@ -266,16 +253,8 @@ if __name__ == '__main__':
allow_square_movement=True,
allow_no_op=False)
env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True)
#env = DummyVecEnv([lambda: env])
from stable_baselines3.dqn import DQN
#dqn = RegDQN('MlpPolicy', env, verbose=True, buffer_size = 40000, learning_starts = 0, batch_size = 64,learning_rate=0.0008,
# target_update_interval = 3500, exploration_fraction = 0.25, exploration_final_eps = 0.05,
# train_freq=4, gradient_steps=1, reg_weight=0.05, seed=69)
#dqn.learn(100000)
dqn, target_dqn = BaseDQN(), BaseDQN()
learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0008, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64)
learner.learn(100000)