add CER sampling and Munchausen DQN

This commit is contained in:
romue 2021-06-18 13:55:38 +02:00
parent eee4760e72
commit e541e34270


@@ -1,4 +1,4 @@
-from typing import NamedTuple, Union
+from typing import NamedTuple, Union, Iterable
 from collections import namedtuple, deque
 import numpy as np
 import random
@@ -30,8 +30,9 @@ class BaseBuffer:
     def add(self, experience):
         self.experience.append(experience)

-    def sample(self, k):
-        sample = random.choices(self.experience, k=k)
+    def sample(self, k, cer=4):
+        sample = random.choices(self.experience, k=k-cer)
+        for i in range(1, cer+1): sample += [self.experience[-i]]  # CER: append the cer newest transitions (range starts at 1, since [-0] would be the oldest entry)
         observations = torch.stack([torch.from_numpy(e.observation) for e in sample], 0).float()
         next_observations = torch.stack([torch.from_numpy(e.next_observation) for e in sample], 0).float()
         actions = torch.tensor([e.action for e in sample]).long()
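The new cer argument implements Combined Experience Replay (Zhang & Sutton, 2017): k - cer transitions are drawn uniformly, and the cer most recent transitions are appended, so fresh experience reaches the learner in the very next batch instead of waiting to be drawn by chance. A minimal standalone sketch of the idea, with buffer as an illustrative deque rather than this file's class:

    import random
    from collections import deque

    def sample_cer(buffer: deque, k: int, cer: int = 4):
        batch = random.choices(buffer, k=k - cer)         # uniform part
        batch += [buffer[-i] for i in range(1, cer + 1)]  # the cer newest transitions
        return batch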
@@ -40,18 +41,6 @@ class BaseBuffer:
         return Experience(observations, next_observations, actions, rewards, dones)

-class PERBuffer(BaseBuffer):
-    def __init__(self, size, alpha=0.2):
-        super(PERBuffer, self).__init__(size)
-        self.alpha = alpha
-
-    def sample(self, k):
-        pr = [abs(e.priority)**self.alpha for e in self.experience]
-        pr = np.array(pr) / sum(pr)
-        idxs = random.choices(range(len(self)), weights=pr, k=k)
-        pass
-

 class BaseDQN(nn.Module):
     def __init__(self):
         super(BaseDQN, self).__init__()
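The unfinished PERBuffer is dropped here; its sample computed sampling weights but ended in pass without returning anything. For reference, a hypothetical completion of the proportional-prioritization scheme it sketched could look like the following (returning raw transitions rather than stacked tensors is an assumption, since the removed code never got that far):

    class PERBuffer(BaseBuffer):
        # hypothetical completion of the removed class, for reference only
        def __init__(self, size, alpha=0.2):
            super(PERBuffer, self).__init__(size)
            self.alpha = alpha  # alpha=0 recovers uniform sampling

        def sample(self, k):
            # P(i) proportional to |priority_i| ** alpha
            pr = np.array([abs(e.priority)**self.alpha for e in self.experience])
            idxs = random.choices(range(len(self.experience)), weights=pr / pr.sum(), k=k)
            return [self.experience[i] for i in idxs]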
@@ -80,14 +69,21 @@ class BaseDQN(nn.Module):
         return random.randrange(0, 5)


+def soft_update(local_model, target_model, tau):
+    # taken from https://github.com/BY571/Munchausen-RL/blob/master/M-DQN.ipynb
+    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
+        target_param.data.copy_(tau*local_param.data + (1.-tau)*target_param.data)
+
+
 class BaseQlearner:
     def __init__(self, q_net, target_q_net, env, buffer, target_update, eps_end, n_agents=1,
-                 gamma=0.99, train_every_n_steps=4, n_grad_steps=1,
+                 gamma=0.99, train_every_n_steps=4, n_grad_steps=1, tau=1.0, max_grad_norm=10,
                  exploration_fraction=0.2, batch_size=64, lr=1e-4, reg_weight=0.0):
         self.q_net = q_net
         self.target_q_net = target_q_net
-        self.q_net.apply(self.weights_init)
+        #self.q_net.apply(self.weights_init)
         self.target_q_net.eval()
+        soft_update(self.q_net, self.target_q_net, tau=1.0)
         self.env = env
         self.buffer = buffer
         self.target_update = target_update
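soft_update performs a Polyak average, theta_target = tau * theta_local + (1 - tau) * theta_target. Note that tau here weights the local net, the opposite of the small-tau convention in DDPG-style code, so the tau=1.0 call above is simply a hard copy that synchronizes the freshly built target net. A quick illustration on toy modules (not from this file):

    import torch
    import torch.nn as nn

    local, target = nn.Linear(2, 2), nn.Linear(2, 2)
    soft_update(local, target, tau=1.0)           # hard copy: target == local
    assert torch.equal(target.weight, local.weight)

    with torch.no_grad():
        local.weight += 1.0                       # pretend the local net learned something
    soft_update(local, target, tau=0.95)          # target moves 95% of the way to local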
@@ -99,10 +95,12 @@ class BaseQlearner:
         self.train_every_n_steps = train_every_n_steps
         self.n_grad_steps = n_grad_steps
         self.lr = lr
+        self.tau = tau
         self.reg_weight = reg_weight
         self.n_agents = n_agents
         self.device = 'cpu'
         self.optimizer = torch.optim.AdamW(self.q_net.parameters(), lr=self.lr)
+        self.max_grad_norm = max_grad_norm
         self.running_reward = deque(maxlen=5)
         self.running_loss = deque(maxlen=5)
         self._n_updates = 0
@@ -112,7 +110,7 @@ class BaseQlearner:
         return self

     @staticmethod
-    def weights_init(module, activation='relu'):
+    def weights_init(module, activation='leaky_relu'):
         if isinstance(module, (nn.Linear, nn.Conv2d)):
             nn.init.xavier_normal_(module.weight, gain=torch.nn.init.calculate_gain(activation))
             if module.bias is not None:
@@ -154,35 +152,38 @@ class BaseQlearner:
                     self._n_updates += 1
                     if step % self.target_update == 0:
                         print('UPDATE')
-                        polyak_update(self.q_net.parameters(), self.target_q_net.parameters(), 1)
+                        soft_update(self.q_net, self.target_q_net, tau=self.tau)

             self.running_reward.append(total_reward)
             if step % 10 == 0:
                 print(f'Step: {step} ({(step/n_steps)*100:.2f}%)\tRunning reward: {sum(list(self.running_reward))/len(self.running_reward):.2f}\t'
                       f' eps: {self.eps:.4f}\tRunning loss: {sum(list(self.running_loss))/len(self.running_loss):.4f}\tUpdates:{self._n_updates}')

-    def _training_routine(self, obs, next_obs, action):
+    def _training_routine(self, obs, next_obs, action, reward):
         current_q_values = self.q_net(obs)
         current_q_values = torch.gather(current_q_values, dim=-1, index=action)
         next_q_values_raw = self.target_q_net(next_obs).max(dim=-1)[0].reshape(-1, 1).detach()
         return current_q_values, next_q_values_raw

     def train(self):
         if len(self.buffer) < self.batch_size: return

         for _ in range(self.n_grad_steps):
-            experience = self.buffer.sample(self.batch_size)
+            experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps)
+            #print(experience.observation.shape, experience.next_observation.shape, experience.action.shape, experience.reward.shape, experience.done.shape)
             if self.n_agents <= 1:
-                pred_q, target_q_raw = self._training_routine(experience.observation, experience.next_observation, experience.action)
+                pred_q, target_q_raw = self._training_routine(experience.observation,
+                                                              experience.next_observation,
+                                                              experience.action,
+                                                              experience.reward)
             else:
-                pred_q, target_q_raw = torch.zeros((self.batch_size, 1)), torch.zeros((self.batch_size, 1))
+                pred_q, target_q_raw, reward = [torch.zeros((self.batch_size, 1)) for _ in range(3)]  # independent tensors; [t]*3 would alias them under +=
                 for agent_i in range(self.n_agents):
                     q_values, next_q_values_raw = self._training_routine(experience.observation[:, agent_i],
                                                                          experience.next_observation[:, agent_i],
-                                                                         experience.action[:, agent_i].unsqueeze(-1)
-                                                                         )
+                                                                         experience.action[:, agent_i].unsqueeze(-1),
+                                                                         experience.reward)
                     pred_q += q_values
                     target_q_raw += next_q_values_raw
             target_q = experience.reward + (1 - experience.done) * self.gamma * target_q_raw
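A pitfall worth flagging in the multi-agent branch above: initializing the accumulators with [torch.zeros(...)]*3 would bind three names to one tensor, and the in-place += in the agent loop would then corrupt all of them, which is why independent tensors are created instead. A two-assert demonstration:

    import torch

    a = b = torch.zeros(2, 1)                         # two names, one tensor (like [t]*3)
    a += torch.ones(2, 1)                             # in-place add mutates b as well
    assert b.sum().item() == 2.0

    c, d = [torch.zeros(2, 1) for _ in range(2)]      # independent tensors
    c += torch.ones(2, 1)
    assert d.sum().item() == 0.0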
@@ -193,7 +194,56 @@ class BaseQlearner:
             # Optimize the model
             self.optimizer.zero_grad()
             loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 10)
+            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
+            self.optimizer.step()
+
+
+class MDQN(BaseQlearner):
+    def __init__(self, *args, temperature=0.03, alpha=0.9, clip_l0=-1.0, **kwargs):
+        super(MDQN, self).__init__(*args, **kwargs)
+        assert self.n_agents == 1, 'M-DQN currently only supports single agent training'
+        self.temperature = temperature
+        self.alpha = alpha
+        self.clip0 = clip_l0
+
+    def train(self):
+        if len(self.buffer) < self.batch_size: return
+
+        for _ in range(self.n_grad_steps):
+            experience = self.buffer.sample(self.batch_size, cer=self.train_every_n_steps)
+
+            # soft Bellman backup from the target net
+            q_target_next = self.target_q_net(experience.next_observation).detach()
+            advantages_next = (q_target_next - q_target_next.max(-1)[0].unsqueeze(-1))
+            logsum = torch.logsumexp(advantages_next / self.temperature, -1).unsqueeze(-1)
+            tau_log_pi_next = advantages_next - self.temperature * logsum
+            pi_target = F.softmax(q_target_next / self.temperature, dim=-1)
+            q_target = (self.gamma * (pi_target * (q_target_next - tau_log_pi_next) * (1 - experience.done)).sum(-1)).unsqueeze(-1)
+
+            # Munchausen term: scaled, clipped log-policy of the taken action
+            q_k_targets = self.target_q_net(experience.observation).detach()
+            v_k_target = q_k_targets.max(-1)[0].unsqueeze(-1)
+            logsum = torch.logsumexp((q_k_targets - v_k_target) / self.temperature, -1).unsqueeze(-1)
+            log_pi = q_k_targets - v_k_target - self.temperature * logsum
+            munchausen_addon = log_pi.gather(-1, experience.action)
+            munchausen_reward = (experience.reward + self.alpha * torch.clamp(munchausen_addon, min=self.clip0, max=0))
+
+            # Compute Q targets for current states
+            m_q_target = munchausen_reward + q_target
+
+            # Get expected Q values from local model
+            q_k = self.q_net(experience.observation)
+            pred_q = q_k.gather(-1, experience.action)
+
+            # Compute loss
+            loss = torch.mean(self.reg_weight * pred_q + torch.pow(pred_q - m_q_target, 2))
+
+            # log loss
+            self.running_loss.append(loss.item())
+
+            # Optimize the model
+            self.optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
             self.optimizer.step()
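MDQN.train implements the Munchausen DQN target of Vieillard et al. (2020): the log-policy bonus alpha * tau * ln pi(a|s), clipped to [l0, 0], is added to the reward, and the bootstrap term becomes the soft value gamma * sum_a' pi(a'|s') * [q(s',a') - tau * ln pi(a'|s')] with pi = softmax(q/tau) taken from the target net. The advantage-shifted logsumexp used above is just the numerically stable form of tau * log-softmax; a quick sanity check of that identity (not part of the commit):

    import torch
    import torch.nn.functional as F

    q = torch.randn(4, 5)                 # fake batch of target Q-values
    tau = 0.03
    adv = q - q.max(-1, keepdim=True)[0]  # shift by the row max before exponentiating
    tau_log_pi = adv - tau * torch.logsumexp(adv / tau, -1, keepdim=True)
    assert torch.allclose(tau_log_pi, tau * F.log_softmax(q / tau, -1), atol=1e-4)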
@@ -221,6 +271,6 @@ if __name__ == '__main__':
     dqn, target_dqn = BaseDQN(), BaseDQN()
-    learner = BaseQlearner(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0008, gamma=0.99, n_agents=N_AGENTS,
-                           train_every_n_steps=4, eps_end=0.05, n_grad_steps=1, reg_weight=0.05, exploration_fraction=0.25, batch_size=64)
+    learner = MDQN(dqn, target_dqn, env, BaseBuffer(40000), target_update=3500, lr=0.0008, gamma=0.99, n_agents=N_AGENTS, tau=0.95, max_grad_norm=10,
+                   train_every_n_steps=4, eps_end=0.025, n_grad_steps=1, reg_weight=0.1, exploration_fraction=0.25, batch_size=64)
     learner.learn(100000)