from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.modules.loss._WeightedLoss):
    def __init__(self, weight=None, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.weight = weight  # weight acts as the alpha parameter to balance class weights

    def forward(self, input, target):
        # Compute the per-sample cross entropy (reduction='none') so the focal
        # term (1 - pt)^gamma modulates each example individually; applying it
        # to an already-averaged CE would only rescale the aggregate loss.
        ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)
        pt = torch.exp(-ce_loss)  # probability the model assigns to the true class
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
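

# Usage sketch for FocalLoss (illustrative only; the shapes, class count, and
# gamma below are assumptions, not part of the original code). The loss takes
# raw logits of shape (N, C) and integer class targets of shape (N,):
#
#     criterion = FocalLoss(gamma=2)
#     logits = torch.randn(8, 3)            # N=8 samples, C=3 classes
#     targets = torch.randint(0, 3, (8,))   # class indices in [0, C)
#     loss = criterion(logits, targets)     # scalar under reduction='mean'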


class FocalLossRob(nn.Module):
    # Binary focal loss, taken from
    # https://github.com/mathiaszinnen/focal_loss_torch/blob/main/focal_loss/focal_loss.py
    # Expects probabilities (e.g. sigmoid outputs), not logits, with {0, 1} targets.
    def __init__(self, alpha=1, gamma=2, reduction: str = 'mean'):
        super().__init__()
        if reduction not in ['mean', 'none', 'sum']:
            raise NotImplementedError('Reduction {} not implemented.'.format(reduction))
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, x, target):
        x = x.clamp(1e-7, 1. - 1e-7)  # own addition: keep log(p_t) finite
        p_t = torch.where(target == 1, x, 1 - x)  # probability of the true class
        fl = -(1 - p_t) ** self.gamma * torch.log(p_t)
        fl = torch.where(target == 1, fl * self.alpha, fl)  # alpha-weight the positives
        return self._reduce(fl)

    def _reduce(self, x):
        if self.reduction == 'mean':
            return x.mean()
        elif self.reduction == 'sum':
            return x.sum()
        else:
            return x
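

# Usage sketch for FocalLossRob (illustrative only; shapes and alpha/gamma are
# assumptions). Unlike FocalLoss above, it expects probabilities rather than
# logits, so apply a sigmoid first, and use binary {0, 1} targets:
#
#     criterion = FocalLossRob(alpha=0.25, gamma=2)
#     probs = torch.sigmoid(torch.randn(8))        # per-sample probabilities
#     targets = torch.randint(0, 2, (8,)).float()  # binary labels
#     loss = criterion(probs, targets)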


class DQN_MSELoss(object):
    def __init__(self, agent_net, target_net, gamma):
        self.agent_net = agent_net
        self.target_net = target_net
        self.gamma = gamma

    def __call__(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """
        Calculates the MSE loss using a mini-batch from the replay buffer.

        Args:
            batch: current mini-batch of replay data
                   (states, actions, rewards, dones, next_states)

        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch

        # Q(s, a) of the actions actually taken, from the online network.
        actions = actions.to(torch.int64)  # gather requires int64 indices
        state_action_values = self.agent_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        # Bootstrap target: max_a' Q_target(s', a'), with no gradient flowing
        # through the target network and zero value for terminal states.
        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0

        expected_state_action_values = next_state_values * self.gamma + rewards

        return F.mse_loss(state_action_values, expected_state_action_values)
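

if __name__ == "__main__":
    # Illustrative smoke test for DQN_MSELoss (not part of the original file;
    # the network architecture, dimensions, and random batch below are
    # assumptions chosen for brevity). The batch layout follows __call__:
    # (states, actions, rewards, dones, next_states).
    obs_dim, n_actions, batch_size = 4, 2, 8
    agent = nn.Sequential(nn.Linear(obs_dim, 16), nn.ReLU(), nn.Linear(16, n_actions))
    target = nn.Sequential(nn.Linear(obs_dim, 16), nn.ReLU(), nn.Linear(16, n_actions))
    target.load_state_dict(agent.state_dict())  # target starts as a copy of the online net

    batch = (
        torch.randn(batch_size, obs_dim),            # states
        torch.randint(0, n_actions, (batch_size,)),  # actions taken
        torch.randn(batch_size),                     # rewards
        torch.zeros(batch_size, dtype=torch.bool),   # dones (no terminal states here)
        torch.randn(batch_size, obs_dim),            # next_states
    )
    loss = DQN_MSELoss(agent, target, gamma=0.99)(batch)
    print(loss)  # scalar tensor; backpropagate through it to update the online net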