From e0a11b9421cad8571f9ae50fc25bd025901e3bb2 Mon Sep 17 00:00:00 2001
From: romue
Date: Tue, 25 May 2021 14:30:10 +0200
Subject: [PATCH] added RegDQN

---
 algorithms/__init__.py |  0
 algorithms/dqn_reg.py  | 52 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 52 insertions(+)
 create mode 100644 algorithms/__init__.py
 create mode 100644 algorithms/dqn_reg.py

diff --git a/algorithms/__init__.py b/algorithms/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/algorithms/dqn_reg.py b/algorithms/dqn_reg.py
new file mode 100644
index 0000000..34ec42b
--- /dev/null
+++ b/algorithms/dqn_reg.py
@@ -0,0 +1,52 @@
+import numpy as np
+import torch
+import stable_baselines3 as sb3
+from stable_baselines3.common import logger
+
+
+class RegDQN(sb3.dqn.DQN):
+    def __init__(self, *args, reg_weight=0.1, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.reg_weight = reg_weight
+
+    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
+        # Update learning rate according to schedule
+        self._update_learning_rate(self.policy.optimizer)
+
+        losses = []
+        for _ in range(gradient_steps):
+            # Sample replay buffer
+            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
+
+            with torch.no_grad():
+                # Compute the next Q-values using the target network
+                next_q_values = self.q_net_target(replay_data.next_observations)
+                # Follow greedy policy: use the one with the highest value
+                next_q_values, _ = next_q_values.max(dim=1)
+                # Avoid potential broadcast issue
+                next_q_values = next_q_values.reshape(-1, 1)
+                # 1-step TD target
+                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
+
+            # Get current Q-values estimates
+            current_q_values = self.q_net(replay_data.observations)
+
+            # Retrieve the q-values for the actions from the replay buffer
+            current_q_values = torch.gather(current_q_values, dim=1, index=replay_data.actions.long())
+
+            delta = current_q_values - target_q_values
+            loss = torch.mean(self.reg_weight * current_q_values + torch.pow(delta, 2))
+            losses.append(loss.item())
+
+            # Optimize the policy
+            self.policy.optimizer.zero_grad()
+            loss.backward()
+            # Clip gradient norm
+            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+            self.policy.optimizer.step()
+
+        # Increase update counter
+        self._n_updates += gradient_steps
+
+        logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
+        logger.record("train/loss", np.mean(losses))
\ No newline at end of file
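
Usage sketch (not part of the patch): the train() step above mirrors stable-baselines3's DQN.train() but replaces the smooth L1 loss with the squared TD error plus reg_weight times the Q-value of the taken action, i.e. it penalizes inflated Q-estimates. A minimal way to run it, assuming stable-baselines3 ~1.0 with gym installed; the environment id and hyperparameters below are illustrative assumptions, only reg_weight comes from the patch:

    import gym
    from algorithms.dqn_reg import RegDQN

    # Illustrative setup; "CartPole-v1" and total_timesteps are assumptions, not from the patch.
    env = gym.make("CartPole-v1")
    model = RegDQN("MlpPolicy", env, reg_weight=0.1, verbose=1)
    model.learn(total_timesteps=10_000)

Since RegDQN forwards *args/**kwargs to sb3.dqn.DQN.__init__, any regular DQN keyword argument (learning_rate, buffer_size, target_update_interval, ...) can be passed alongside reg_weight.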