add Munchausen DQN refactoring
This commit is contained in:
parent
bd63c603ee
commit
543c2987e0
@@ -1,22 +1,20 @@
from typing import NamedTuple, Union, Iterable
from collections import namedtuple, deque
from typing import NamedTuple, Union
from collections import deque
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.common.buffers import ReplayBuffer
import copy


class Experience(NamedTuple):
    # can be used for a single (s_t, a, r, s_{t+1}) tuple
    # or for a batch of tuples
    observation: np.ndarray
    next_observation: np.ndarray
    action: np.ndarray
    reward: Union[float, np.ndarray]
    done: Union[bool, np.ndarray]
    priority: np.ndarray = 1
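A quick illustration (editor's sketch, not part of this commit): the same Experience tuple can carry one transition or a stacked batch.

# Editor's sketch, not part of the commit:
single = Experience(observation=np.zeros(4), next_observation=np.ones(4),
                    action=np.array(2), reward=1.0, done=False)
batch = Experience(observation=np.stack([single.observation] * 32),
                   next_observation=np.stack([single.next_observation] * 32),
                   action=np.full((32, 1), 2), reward=np.ones((32, 1)),
                   done=np.zeros((32, 1), dtype=bool))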


class BaseBuffer:
@@ -65,9 +63,6 @@ class BaseDQN(nn.Module):
        values = self.value_head(features)
        return values + (advantages - advantages.mean())
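This is the dueling-DQN aggregation. For reference (editor's note, not from the diff), the textbook form subtracts the advantage mean over the action dimension:

# Editor's sketch, not part of the commit: dueling aggregation
# Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
values_ref = torch.zeros(32, 1)        # V(s): shape (batch, 1)
advantages_ref = torch.randn(32, 5)    # A(s, a): shape (batch, n_actions)
q_ref = values_ref + (advantages_ref - advantages_ref.mean(dim=-1, keepdim=True))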

    def random_action(self):
        return random.randrange(0, 5)


def soft_update(local_model, target_model, tau):
    # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
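The hunk cuts off the body of soft_update; assuming it follows the referenced notebook, it is a plain Polyak average (the imported polyak_update from stable_baselines3 provides the same operation). A sketch under that assumption:

# Editor's sketch, not part of the commit: Polyak/soft update,
# theta_target <- tau * theta_local + (1 - tau) * theta_target
def soft_update_sketch(local_model, target_model, tau):
    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
        target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)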
@@ -165,6 +160,14 @@ class BaseQlearner:
        next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach()
        return current_q_values, next_q_values_raw

    def _backprop_loss(self, loss):
        # log loss
        self.running_loss.append(loss.item())
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()

    def train(self):
        if len(self.buffer) < self.batch_size: return
@@ -188,14 +191,7 @@ class BaseQlearner:
        target_q_raw += next_q_values_raw
        target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
        loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - target_q, 2))

        # log loss
        self.running_loss.append(loss.item())
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self._backprop_loss(loss)


class MDQN(BaseQlearner):
@@ -207,7 +203,8 @@ class MDQN(BaseQlearner):
        self.clip0 = clip_l0

    def tau_ln_pi(self, qs):
        # Custom log-sum-exp trick from page 18 to compute the e log-policy terms
        # computes log(softmax(qs/temperature))
        # Custom log-sum-exp trick from page 18 to compute the log-policy terms
        v_k = qs.max(-1)[0].unsqueeze(-1)
        advantage = qs - v_k
        logsum = torch.logsumexp(advantage / self.temperature, -1).unsqueeze(-1)
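The hunk ends before the method's return; assuming it follows the referenced M-DQN notebook, the last step combines these terms into tau * log pi. A self-contained sketch of the full computation:

# Editor's sketch, not part of the commit: stabilised log-softmax,
# tau * log pi(a|s) = q(s, a) - v_k - tau * logsumexp((q(s, .) - v_k) / tau)
def tau_ln_pi_sketch(qs, temperature):
    v_k = qs.max(-1)[0].unsqueeze(-1)
    advantage = qs - v_k
    logsum = torch.logsumexp(advantage / temperature, -1).unsqueeze(-1)
    return advantage - temperature * logsum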
@@ -242,21 +239,11 @@ class MDQN(BaseQlearner):

        # Compute loss
        loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))

        # log loss
        self.running_loss.append(loss.item())
        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()

        self._backprop_loss(loss)
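For context, m_q_target in the loss above is the Munchausen target; an illustrative, self-contained sketch of how it is typically assembled (names, shapes, and the (1 - done) masking are assumptions, not taken from this diff):

# Editor's sketch, not part of the commit: Munchausen-DQN target,
# r + alpha * clip(tau * log pi(a|s), l0, 0)
#   + gamma * (1 - done) * sum_a' pi(a'|s') * (q(s', a') - tau * log pi(a'|s'))
def munchausen_target_sketch(reward, done, q_t, q_tp1, action, alpha, tau, gamma, clip_l0=-1.0):
    log_pi_t = torch.log_softmax(q_t / tau, dim=-1)
    munchausen_bonus = (tau * log_pi_t).gather(-1, action).clamp(min=clip_l0, max=0.0)
    log_pi_tp1 = torch.log_softmax(q_tp1 / tau, dim=-1)
    soft_value = (log_pi_tp1.exp() * (q_tp1 - tau * log_pi_tp1)).sum(dim=-1, keepdim=True)
    return reward + alpha * munchausen_bonus + gamma * (1.0 - done) * soft_value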


if __name__ == '__main__':
    from environments.factory.simple_factory import SimpleFactory, DirtProperties, MovementProperties
    from algorithms.reg_dqn import RegDQN
    from stable_baselines3.common.vec_env import DummyVecEnv

    N_AGENTS = 1
@@ -266,16 +253,8 @@ if __name__ == '__main__':
                                        allow_square_movement=True,
                                        allow_no_op=False)
    env = SimpleFactory(dirt_properties=dirt_props, movement_properties=move_props, n_agents=N_AGENTS, pomdp_radius=2, max_steps=400, omit_agent_slice_in_obs=False, combin_agent_slices_in_obs=True)
    #env = DummyVecEnv([lambda: env])
    from stable_baselines3.dqn import DQN

    #dqn = RegDQN('MlpPolicy', env, verbose=True, buffer_size = 40000, learning_starts = 0, batch_size = 64,learning_rate=0.0008,
    #             target_update_interval = 3500, exploration_fraction = 0.25, exploration_final_eps = 0.05,
    #             train_freq=4, gradient_steps=1, reg_weight=0.05, seed=69)
    #dqn.learn(100000)


    dqn, target_dqn = BaseDQN(), BaseDQN()
    learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0008, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
    learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0007, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
                   train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64)
    learner.learn(100000)