# -*- coding: utf-8 -*-
# @Time    : 2019/12/4 14:39
# @Author  : zhoujun
import torch
import torch.nn as nn


class BalanceCrossEntropyLoss(nn.Module):
    '''
    Balanced cross entropy loss: every positive pixel contributes, but only
    the hardest negatives (at most ``negative_ratio`` times the number of
    positives) are kept, i.e. online hard example mining (OHEM).

    Shape:
        - Input: :math:`(N, 1, H, W)`
        - GT: :math:`(N, 1, H, W)`, same shape as the input
        - Mask: :math:`(N, H, W)`, same spatial shape as the input
        - Output: scalar.

    Examples::

        >>> loss_fn = BalanceCrossEntropyLoss()
        >>> pred = torch.rand(2, 1, 32, 32, requires_grad=True)
        >>> gt = torch.empty(2, 1, 32, 32).random_(2)
        >>> mask = torch.ones(2, 32, 32)
        >>> output = loss_fn(pred, gt, mask)
        >>> output.backward()
    '''

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        super(BalanceCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: torch.Tensor,
                return_origin=False):
        '''
        Args:
            pred: shape :math:`(N, 1, H, W)`, the prediction of the network
            gt: shape :math:`(N, 1, H, W)`, the target
            mask: shape :math:`(N, H, W)`, the mask indicating positive regions
        '''
        # Align the mask with the channel dimension; without this,
        # (N, 1, H, W) * (N, H, W) broadcasts to (N, N, H, W) for N > 1.
        if pred.dim() == 4 and mask.dim() == 3:
            mask = mask.unsqueeze(1)
        positive = (gt * mask).byte()
        negative = ((1 - gt) * mask).byte()
        positive_count = int(positive.float().sum())
        negative_count = min(int(negative.float().sum()),
                             int(positive_count * self.negative_ratio))
        loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')
        positive_loss = loss * positive.float()
        negative_loss = loss * negative.float()
        # Keep only the hardest negatives (largest per-pixel losses).
        negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)

        balance_loss = (positive_loss.sum() + negative_loss.sum()) / \
                       (positive_count + negative_count + self.eps)

        if return_origin:
            return balance_loss, loss
        return balance_loss
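# ---------------------------------------------------------------------------
# Worked example (illustrative, not part of the original file): with one
# positive pixel, seven negatives and negative_ratio=3, the loss averages the
# positive pixel with the 3 hardest negatives. A uniform prediction of 0.5
# gives -ln(0.5) ~= 0.6931 per pixel, so the balanced loss is
# (1 * 0.6931 + 3 * 0.6931) / (1 + 3) ~= 0.6931:
#
#     >>> loss_fn = BalanceCrossEntropyLoss(negative_ratio=3.0)
#     >>> gt = torch.zeros(1, 1, 2, 4)
#     >>> gt[0, 0, 0, 0] = 1.                  # 1 positive, 7 negative pixels
#     >>> mask = torch.ones(1, 2, 4)
#     >>> pred = torch.full((1, 1, 2, 4), 0.5)
#     >>> loss_fn(pred, gt, mask)              # tensor(0.6931...)
# ---------------------------------------------------------------------------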
class DiceLoss(nn.Module):
    '''
    Loss function from https://arxiv.org/abs/1707.03237, where the IoU
    computation is applied to heatmaps to measure the similarity between
    two heatmaps.
    '''

    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, gt, mask, weights=None):
        '''
        pred: a heatmap of shape (N, 1, H, W) or (N, H, W)
        gt: (N, 1, H, W) or (N, H, W)
        mask: (N, H, W)
        '''
        return self._compute(pred, gt, mask, weights)

    def _compute(self, pred, gt, mask, weights):
        if pred.dim() == 4:
            pred = pred[:, 0, :, :]
            gt = gt[:, 0, :, :]
        assert pred.shape == gt.shape
        assert pred.shape == mask.shape
        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask
        intersection = (pred * gt * mask).sum()
        union = (pred * mask).sum() + (gt * mask).sum() + self.eps
        loss = 1 - 2.0 * intersection / union
        assert loss <= 1
        return loss


class MaskL1Loss(nn.Module):
    def __init__(self, eps=1e-6):
        super(MaskL1Loss, self).__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, gt, mask):
        # Mean absolute error over the masked region only.
        loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
        return loss


class BCELoss(nn.Module):
    def __init__(self, reduction='mean'):
        super(BCELoss, self).__init__()
        self.reduction = reduction

    def forward(self, input, label, mask=None, weight=None, name=None):
        # `mask`, `weight` and `name` are accepted for interface
        # compatibility with the other losses but are not used here.
        loss = nn.functional.binary_cross_entropy(input, label, reduction=self.reduction)
        return loss


class BalanceLoss(nn.Module):
    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 negative_ratio=3,
                 return_origin=False,
                 eps=1e-6,
                 **kwargs):
        super(BalanceLoss, self).__init__()
        self.balance_loss = balance_loss
        self.main_loss_type = main_loss_type
        self.negative_ratio = negative_ratio
        self.return_origin = return_origin
        self.eps = eps

        if self.main_loss_type == "CrossEntropy":
            self.loss = nn.CrossEntropyLoss()
        elif self.main_loss_type == "Euclidean":
            self.loss = nn.MSELoss()
        elif self.main_loss_type == "DiceLoss":
            self.loss = DiceLoss(self.eps)
        elif self.main_loss_type == "BCELoss":
            self.loss = BCELoss(reduction='none')
        elif self.main_loss_type == "MaskL1Loss":
            self.loss = MaskL1Loss(self.eps)
        else:
            loss_type = [
                'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss',
                'MaskL1Loss'
            ]
            raise Exception(
                "main_loss_type in BalanceLoss() can only be one of {}".format(
                    loss_type))

    def forward(self, pred, gt, mask=None):
        """
        The BalanceLoss for Differentiable Binarization text detection.
        Args:
            pred (Tensor): predicted feature maps.
            gt (Tensor): ground truth feature maps.
            mask (Tensor): mask of valid regions; if None, all pixels are used.
        Returns:
            Tensor: balanced loss.
        """
        if mask is None:
            # Treat every pixel as valid when no mask is given.
            mask = torch.ones_like(gt)
        positive = (gt * mask).byte()
        negative = ((1 - gt) * mask).byte()
        positive_count = int(positive.float().sum())
        negative_count = int(min(negative.float().sum(),
                                 positive_count * self.negative_ratio))

        # nn.CrossEntropyLoss and nn.MSELoss do not accept a `mask` keyword,
        # so it is only forwarded to the losses defined in this module.
        if isinstance(self.loss, (nn.CrossEntropyLoss, nn.MSELoss)):
            loss = self.loss(pred, gt)
        else:
            loss = self.loss(pred, gt, mask=mask)

        if not self.balance_loss:
            return loss

        # Balancing is only meaningful when `loss` is element-wise, e.g. the
        # BCELoss above with reduction='none'.
        positive_loss = positive.float() * loss
        negative_loss = negative.float() * loss
        if negative_count > 0:
            negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
            balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
                positive_count + negative_count + self.eps)
        else:
            balance_loss = positive_loss.sum() / (positive_count + self.eps)

        if self.return_origin:
            return balance_loss, loss
        return balance_loss
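# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module): the shapes
# and the plain unweighted sum of the three losses below are illustrative
# assumptions, not the repo's actual DB training loss.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    N, H, W = 2, 32, 32
    pred = torch.rand(N, 1, H, W, requires_grad=True)  # probabilities in [0, 1)
    gt = torch.empty(N, 1, H, W).random_(2)            # binary target map
    mask = torch.ones(N, H, W)                         # every pixel valid

    bce = BalanceCrossEntropyLoss()(pred, gt, mask)
    dice = DiceLoss()(pred, gt, mask)
    l1 = MaskL1Loss()(pred[:, 0], gt[:, 0], mask)      # drop channel dim so shapes match
    # BalanceLoss expects a mask that broadcasts against (N, 1, H, W) maps.
    bal = BalanceLoss(main_loss_type='BCELoss')(pred, gt, mask.unsqueeze(1))

    total = bce + dice + l1                            # illustrative 1:1:1 weighting
    total.backward()
    print('balanced BCE: %.4f, dice: %.4f, masked L1: %.4f, balance: %.4f'
          % (bce.item(), dice.item(), l1.item(), bal.item()))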