import numpy as np
import torch
import stable_baselines3 as sb3
from stable_baselines3.common import logger


class RegDQN(sb3.dqn.DQN):
    """DQN variant whose loss adds a penalty on the magnitude of the predicted Q-values."""

    def __init__(self, *args, reg_weight=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.reg_weight = reg_weight

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with torch.no_grad():
                # Compute the next Q-values using the target network
                next_q_values = self.q_net_target(replay_data.next_observations)
                # Follow greedy policy: use the one with the highest value
                next_q_values, _ = next_q_values.max(dim=1)
                # Avoid potential broadcast issue
                next_q_values = next_q_values.reshape(-1, 1)
                # 1-step TD target
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-value estimates
            current_q_values = self.q_net(replay_data.observations)
            # Retrieve the Q-values for the actions from the replay buffer
            current_q_values = torch.gather(current_q_values, dim=1, index=replay_data.actions.long())

            # Regularized loss: squared TD error plus a penalty on the Q-value
            # magnitude, weighted by reg_weight
            delta = current_q_values - target_q_values
            loss = torch.mean(self.reg_weight * current_q_values + torch.pow(delta, 2))
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        logger.record("train/loss", np.mean(losses))