from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss computed on top of cross-entropy over logits."""

    def __init__(self, weight=None, gamma=2, reduction='mean'):
        super().__init__(weight, reduction=reduction)
        self.gamma = gamma
        self.weight = weight  # weight acts as the alpha parameter to balance class weights

    def forward(self, input, target):
        # Compute per-sample cross-entropy (reduction='none') so the focal
        # modulation is applied per sample, not to the already-reduced loss.
        ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


class FocalLossRob(nn.Module):
    """Binary focal loss on probabilities.

    Taken from https://github.com/mathiaszinnen/focal_loss_torch/blob/main/focal_loss/focal_loss.py
    """

    def __init__(self, alpha=1, gamma=2, reduction: str = 'mean'):
        super().__init__()
        if reduction not in ['mean', 'none', 'sum']:
            raise NotImplementedError('Reduction {} not implemented.'.format(reduction))
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, x, target):
        # Clamp probabilities away from 0 and 1 to keep log(p_t) finite (own addition).
        x = x.clamp(1e-7, 1. - 1e-7)
        p_t = torch.where(target == 1, x, 1 - x)
        fl = -1 * (1 - p_t) ** self.gamma * torch.log(p_t)
        # Weight the positive class by alpha.
        fl = torch.where(target == 1, fl * self.alpha, fl)
        return self._reduce(fl)

    def _reduce(self, x):
        if self.reduction == 'mean':
            return x.mean()
        elif self.reduction == 'sum':
            return x.sum()
        else:
            return x


class DQN_MSELoss(object):
    """MSE loss between the online network's Q-values and the bootstrapped TD target."""

    def __init__(self, agent_net, target_net, gamma):
        self.agent_net = agent_net
        self.target_net = target_net
        self.gamma = gamma

    def __call__(self, batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """
        Calculates the MSE loss using a mini-batch from the replay buffer.

        Args:
            batch: current mini-batch of replay data (states, actions, rewards, dones, next_states)

        Returns:
            loss
        """
        states, actions, rewards, dones, next_states = batch
        actions = actions.to(torch.int64)

        # Q-values of the actions actually taken, predicted by the online network.
        state_action_values = self.agent_net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        # Bootstrapped targets from the frozen target network; terminal states
        # contribute no future value.
        with torch.no_grad():
            next_state_values = self.target_net(next_states).max(1)[0]
            next_state_values[dones] = 0.0

        expected_state_action_values = next_state_values * self.gamma + rewards

        return F.mse_loss(state_action_values, expected_state_action_values)
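

# --- Usage sketch ------------------------------------------------------------
# A minimal smoke test showing how the three losses are called; it is not part
# of the original module. The tensor shapes, the tiny nn.Linear stand-ins for
# the DQN online/target networks, and the __main__ guard are illustrative
# assumptions only.
if __name__ == "__main__":
    torch.manual_seed(0)

    # FocalLoss expects raw logits of shape (N, C) and integer class targets.
    logits = torch.randn(8, 4)
    classes = torch.randint(0, 4, (8,))
    print("FocalLoss:", FocalLoss(gamma=2)(logits, classes).item())

    # FocalLossRob expects probabilities in (0, 1) and binary {0, 1} targets.
    probs = torch.sigmoid(torch.randn(8, 1))
    binary = torch.randint(0, 2, (8, 1)).float()
    print("FocalLossRob:", FocalLossRob(alpha=1, gamma=2)(probs, binary).item())

    # DQN_MSELoss consumes a (states, actions, rewards, dones, next_states) batch.
    obs_dim, n_actions = 6, 3
    q_net = nn.Linear(obs_dim, n_actions)    # hypothetical online network
    tgt_net = nn.Linear(obs_dim, n_actions)  # hypothetical target network
    batch = (
        torch.randn(8, obs_dim),                # states
        torch.randint(0, n_actions, (8,)),      # actions
        torch.randn(8),                         # rewards
        torch.zeros(8, dtype=torch.bool),       # dones
        torch.randn(8, obs_dim),                # next_states
    )
    print("DQN_MSELoss:", DQN_MSELoss(q_net, tgt_net, gamma=0.99)(batch).item())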