import numpy as np
import torch
import stable_baselines3 as sb3
from stable_baselines3.common import logger


class RegDQN(sb3.dqn.DQN):
    def __init__(self, *args, reg_weight=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.reg_weight = reg_weight

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with torch.no_grad():
                # Compute the next Q-values using the target network
                next_q_values = self.q_net_target(replay_data.next_observations)
                # Follow greedy policy: use the one with the highest value
                next_q_values, _ = next_q_values.max(dim=1)
                # Avoid potential broadcast issue
                next_q_values = next_q_values.reshape(-1, 1)
                # 1-step TD target
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates
            current_q_values = self.q_net(replay_data.observations)

            # Retrieve the q-values for the actions from the replay buffer
            current_q_values = torch.gather(current_q_values, dim=1, index=replay_data.actions.long())

            # Regularized loss: squared TD error plus a reg_weight-scaled penalty on the Q-values
            delta = current_q_values - target_q_values
            loss = torch.mean(self.reg_weight * current_q_values + torch.pow(delta, 2))
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        logger.record("train/loss", np.mean(losses))
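

# A minimal usage sketch (an assumption, not part of the original snippet): since RegDQN
# only overrides train(), it can be instantiated and trained exactly like sb3.DQN. The
# environment ID (CartPole-v1) and the hyperparameter values below are illustrative only.
if __name__ == "__main__":
    # reg_weight scales the Q-value penalty added to the squared TD error in train()
    model = RegDQN("MlpPolicy", "CartPole-v1", reg_weight=0.1, verbose=1)
    model.learn(total_timesteps=10_000)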