Code Comments, Getting Dirty Env, Naming

This commit is contained in:
Steffen Illium
2021-05-11 10:31:34 +02:00
parent faa27c3cf9
commit ab01006eae
7 changed files with 51 additions and 16 deletions

View File

@ -1,3 +1,5 @@
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -41,3 +43,33 @@ class FocalLossRob(nn.Module):
return x.sum()
else:
return x
class DQN_MSELoss(object):
def __init__(self, agent_net, target_net, gamma):
self.agent_net = agent_net
self.target_net = target_net
self.gamma = gamma
def __call__(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
"""
Calculates the mse loss using a mini batch from the replay buffer
Args:
batch: current mini batch of replay data
Returns:
loss
"""
states, actions, rewards, dones, next_states = batch
actions = actions.to(torch.int64)
state_action_values = self.agent_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
with torch.no_grad():
next_state_values = self.target_net(next_states).max(1)[0]
next_state_values[dones] = 0.0
next_state_values = next_state_values.detach()
expected_state_action_values = next_state_values * self.gamma + rewards
return F.mse_loss(state_action_values, expected_state_action_values)