Merge remote-tracking branch 'origin/main'
commit 7c8008807f
0   algorithms/__init__.py   (new file)
52  algorithms/dqn_reg.py    (new file)
@@ -0,0 +1,52 @@
import numpy as np
import torch
import stable_baselines3 as sb3
from stable_baselines3.common import logger


class RegDQN(sb3.dqn.DQN):
    def __init__(self, *args, reg_weight=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.reg_weight = reg_weight

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with torch.no_grad():
                # Compute the next Q-values using the target network
                next_q_values = self.q_net_target(replay_data.next_observations)
                # Follow greedy policy: use the one with the highest value
                next_q_values, _ = next_q_values.max(dim=1)
                # Avoid potential broadcast issue
                next_q_values = next_q_values.reshape(-1, 1)
                # 1-step TD target
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-value estimates
            current_q_values = self.q_net(replay_data.observations)

            # Retrieve the Q-values for the actions from the replay buffer
            current_q_values = torch.gather(current_q_values, dim=1, index=replay_data.actions.long())

            # Regularized TD loss: squared TD error plus a penalty on the
            # predicted Q-values, weighted by reg_weight
            delta = current_q_values - target_q_values
            loss = torch.mean(self.reg_weight * current_q_values + torch.pow(delta, 2))
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        logger.record("train/loss", np.mean(losses))
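For reference, the only behavioral change relative to stock SB3 DQN.train() is the loss: where stock DQN applies a Huber (smooth L1) loss to the TD error, each sample here contributes reg_weight * Q(s, a) + (Q(s, a) - y)^2, with y the 1-step TD target, so the extra term penalizes large predicted Q-values. Below is a minimal usage sketch, not part of the commit: the environment id and hyperparameters are illustrative placeholders, and it assumes an SB3 release (pre-1.2) where the module-level logger API used above is available.

from algorithms.dqn_reg import RegDQN

# Hypothetical example: "CartPole-v1" and these hyperparameter values are
# placeholders, not taken from the commit.
model = RegDQN(
    "MlpPolicy",
    "CartPole-v1",
    reg_weight=0.1,      # weight of the Q-value penalty in the loss
    learning_rate=1e-3,
    verbose=1,
)
model.learn(total_timesteps=10_000)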