Code Comments, Getting Dirty Env, Naming
@@ -1,3 +1,5 @@
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -41,3 +43,33 @@ class FocalLossRob(nn.Module):
            return x.sum()
        else:
            return x


class DQN_MSELoss(object):

    def __init__(self, agent_net, target_net, gamma):
        self.agent_net = agent_net
        self.target_net = target_net
        self.gamma = gamma

    def __call__(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """
        Calculates the MSE loss for a mini-batch drawn from the replay buffer.

        Args:
            batch: current mini-batch of replay data

        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch

        # Q-values of the actions actually taken in the batch
        actions = actions.to(torch.int64)
        state_action_values = self.agent_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        # Bootstrapped targets from the target network; terminal states contribute no future value
        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0
            next_state_values = next_state_values.detach()

        expected_state_action_values = next_state_values * self.gamma + rewards

        return F.mse_loss(state_action_values, expected_state_action_values)
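For context, a minimal usage sketch of the new DQN_MSELoss class; the toy networks, batch shapes, and hyperparameters below are assumptions for illustration and are not part of this commit:

    import torch
    import torch.nn as nn

    # Hypothetical Q-networks: any nn.Module mapping states to per-action Q-values works here.
    agent_net = nn.Linear(4, 2)
    target_net = nn.Linear(4, 2)

    loss_fn = DQN_MSELoss(agent_net, target_net, gamma=0.99)

    # Fake replay-buffer batch of 8 transitions: (states, actions, rewards, dones, next_states).
    batch = (
        torch.randn(8, 4),                 # states
        torch.randint(0, 2, (8,)),         # actions
        torch.randn(8),                    # rewards
        torch.zeros(8, dtype=torch.bool),  # dones
        torch.randn(8, 4),                 # next_states
    )

    loss = loss_fn(batch)  # MSE between Q(s, a) and r + gamma * max_a' Q_target(s', a')
    loss.backward()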