From e541e3427036bf5a21b1c8c5800f6bfff9dd8581 Mon Sep 17 00:00:00 2001 From: romue Date: Fri, 18 Jun 2021 13:55:38 +0200 Subject: [PATCH] add CER sampling and Munchhausen DQN --- algorithms/_base.py | 112 ++++++++++++++++++++++++++++++++------------ 1 file changed, 81 insertions(+), 31 deletions(-) diff --git a/algorithms/_base.py b/algorithms/_base.py index 8ca0aef..bbb2b4a 100644 --- a/algorithms/_base.py +++ b/algorithms/_base.py @@ -1,4 +1,4 @@ -from typing import NamedTuple, Union +from typing import NamedTuple, Union, Iterable from collections import namedtuple, deque import numpy as np import random @@ -30,8 +30,9 @@ class BaseBuffer: def add(self, experience): self.experience.append(experience) - def sample(self, k): - sample = random.choices(self.experience, k=k) + def sample(self, k, cer=4): + sample = random.choices(self.experience, k=k-cer) + for i in range(cer): sample += [self.experience[-i]] observations = torch.stack([torch.from_numpy(e.observation) for e in sample], 0).float() next_observations = torch.stack([torch.from_numpy(e.next_observation) for e in sample], 0).float() actions = torch.tensor([e.action for e in sample]).long() @@ -40,18 +41,6 @@ class BaseBuffer: return Experience(observations, next_observations, actions, rewards, dones) -class PERBuffer(BaseBuffer): - def __init__(self, size, alpha=0.2): - super(PERBuffer, self).__init__(size) - self.alpha = alpha - - def sample(self, k): - pr = [abs(e.priority)**self.alpha for e in self.experience] - pr = np.array(pr) / sum(pr) - idxs = random.choices(range(len(self)), weights=pr, k=k) - pass - - class BaseDQN(nn.Module): def __init__(self): super(BaseDQN, self).__init__() @@ -80,14 +69,21 @@ class BaseDQN(nn.Module): return random.randrange(0, 5) +def soft_update(local_model, target_model, tau): + # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb + for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): + target_param.data.copy_(tau*local_param.data + (1.-tau)*target_param.data) + + class BaseQlearner: def __init__(self, q_net, target_q_net, env, buffer, target_update, eps_end, n_agents=1, - gamma=0.99, train_every_n_steps=4, n_grad_steps=1, + gamma=0.99, train_every_n_steps=4, n_grad_steps=1, tau=1.0, max_grad_norm=10, exploration_fraction=0.2, batch_size=64, lr=1e-4, reg_weight=0.0): self.q_net = q_net self.target_q_net = target_q_net - self.q_net.apply(self.weights_init) + #self.q_net.apply(self.weights_init) self.target_q_net.eval() + soft_update(self.q_net, self.target_q_net, tau=1.0) self.env = env self.buffer = buffer self.target_update = target_update @@ -99,10 +95,12 @@ class BaseQlearner: self.train_every_n_steps = train_every_n_steps self.n_grad_steps = n_grad_steps self.lr = lr + self.tau = tau self.reg_weight = reg_weight self.n_agents = n_agents self.device = 'cpu' self.optimizer = torch.optim.AdamW(self.q_net.parameters(), lr=self.lr) + self.max_grad_norm = max_grad_norm self.running_reward = deque(maxlen=5) self.running_loss = deque(maxlen=5) self._n_updates = 0 @@ -112,7 +110,7 @@ class BaseQlearner: return self @staticmethod - def weights_init(module, activation='relu'): + def weights_init(module, activation='leaky_relu'): if isinstance(module, (nn.Linear, nn.Conv2d)): nn.init.xavier_normal_(module.weight, gain=torch.nn.init.calculate_gain(activation)) if module.bias is not None: @@ -154,35 +152,38 @@ class BaseQlearner: self._n_updates += 1 if step % self.target_update == 0: print('UPDATE') - polyak_update(self.q_net.parameters(), self.target_q_net.parameters(), 1) - + soft_update(self.q_net, self.target_q_net, tau=self.tau) self.running_reward.append(total_reward) if step % 10 == 0: print(f'Step: {step} ({(step/n_steps)*100:.2f}%)\tRunning reward: {sum(list(self.running_reward))/len(self.running_reward):.2f}\t' f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss))/len(self.running_loss):.4f}\tUpdates:{self._n_updates}') - def _training_routine(self, obs, next_obs, action): + def _training_routine(self, obs, next_obs, action, reward): current_q_values = self.q_net(obs) current_q_values = torch.gather(current_q_values, dim=-1, index=action) next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach() return current_q_values, next_q_values_raw + def train(self): if len(self.buffer) < self.batch_size: return for _ in range(self.n_grad_steps): - experience = self.buffer.sample(self.batch_size) - #print(experience.observation.shape, experience.next_observation.shape, experience.action.shape, experience.reward.shape, experience.done.shape) + experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps) + if self.n_agents <= 1: - pred_q, target_q_raw = self._training_routine(experience.observation, experience.next_observation, experience.action) + pred_q, target_q_raw = self._training_routine(experience.observation, + experience.next_observation, + experience.action, + experience.reward) else: - pred_q, target_q_raw = torch.zeros((self.batch_size, 1)), torch.zeros((self.batch_size, 1)) + pred_q, target_q_raw, reward = [torch.zeros((self.batch_size, 1))]*3 for agent_i in range(self.n_agents): q_values, next_q_values_raw = self._training_routine(experience.observation[:, agent_i], - experience.next_observation[:, agent_i], - experience.action[:, agent_i].unsqueeze(-1) - ) + experience.next_observation[:, agent_i], + experience.action[:, agent_i].unsqueeze(-1), + experience.reward) pred_q += q_values target_q_raw += next_q_values_raw target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw @@ -193,7 +194,56 @@ class BaseQlearner: # Optimize the model self.optimizer.zero_grad() loss.backward() - torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 10) + torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm) + self.optimizer.step() + + +class MDQN(BaseQlearner): + def __init__(self, *args, temperature=0.03, alpha=0.9, clip_l0=-1.0, **kwargs): + super(MDQN, self).__init__(*args, **kwargs) + assert self.n_agents == 1, 'M-DQN currently only supports single agent training' + self.temperature = temperature + self.alpha = alpha + self.clip0 = clip_l0 + + def train(self): + if len(self.buffer) < self.batch_size: return + for _ in range(self.n_grad_steps): + + experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps) + + q_target_next = self.target_q_net(experience.next_observation).detach() + advantages_next = (q_target_next - q_target_next.max(-1)[0].unsqueeze(-1)) + logsum = torch.logsumexp(advantages_next / self.temperature, -1).unsqueeze(-1) + tau_log_pi_next = advantages_next - self.temperature * logsum + + pi_target = F.softmax(q_target_next / self.temperature, dim=-1) + q_target = (self.gamma * (pi_target * (q_target_next - tau_log_pi_next) * (1 - experience.done)).sum(-1)).unsqueeze(-1) + + q_k_targets = self.target_q_net(experience.observation).detach() + v_k_target = q_k_targets.max(-1)[0].unsqueeze(-1) + logsum = torch.logsumexp((q_k_targets - v_k_target) / self.temperature, -1).unsqueeze(-1) + log_pi = q_k_targets - v_k_target - self.temperature * logsum + munchausen_addon = log_pi.gather(-1, experience.action) + + munchausen_reward = (experience.reward + self.alpha * torch.clamp(munchausen_addon, min=self.clip0, max=0)) + + # Compute Q targets for current states + m_q_target = munchausen_reward + q_target + + # Get expected Q values from local model + q_k = self.q_net(experience.observation) + pred_q = q_k.gather(-1, experience.action) + + # Compute loss + loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2)) + + # log loss + self.running_loss.append(loss.item()) + # Optimize the model + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm) self.optimizer.step() @@ -221,6 +271,6 @@ if __name__ == '__main__': dqn, target_dqn = BaseDQN(), BaseDQN() - learner = BaseQlearner(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0008, gamma=0.99, n_agents=N_AGENTS, - train_every_n_steps=4, eps_end=0.05, n_grad_steps=1, reg_weight=0.05, exploration_fraction=0.25, batch_size=64) + learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0008, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10, + train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64) learner.learn(100000)