# -*- coding: utf-8 -*-
# @Time    : 2019/12/4 14:39
# @Author  : zhoujun
import torch
import torch.nn as nn


class BalanceCrossEntropyLoss(nn.Module):
    '''
    Balanced cross-entropy loss: all positive pixels are kept, while only the
    hardest negatives are kept, up to a fixed negative:positive ratio.

    Shape:
        - Input: :math:`(N, 1, H, W)`
        - GT: :math:`(N, 1, H, W)`, same shape as the input
        - Mask: :math:`(N, H, W)`, same spatial shape as the input
        - Output: scalar.

    Examples::

        >>> loss = BalanceCrossEntropyLoss()
        >>> pred = torch.rand(1, 1, 32, 32)
        >>> gt = (torch.rand(1, 1, 32, 32) > 0.5).float()
        >>> mask = torch.ones(1, 32, 32)
        >>> output = loss(pred, gt, mask)
    '''

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        super(BalanceCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: torch.Tensor,
                return_origin=False):
        '''
        Args:
            pred: shape :math:`(N, 1, H, W)`, the prediction of the network
            gt: shape :math:`(N, 1, H, W)`, the target
            mask: shape :math:`(N, H, W)`, the mask indicating positive regions
        '''
        positive = (gt * mask).byte()
        negative = ((1 - gt) * mask).byte()
        positive_count = int(positive.float().sum())
        # Hard negative mining: keep at most negative_ratio negatives per positive.
        negative_count = min(int(negative.float().sum()), int(positive_count * self.negative_ratio))
        loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')
        positive_loss = loss * positive.float()
        negative_loss = loss * negative.float()
        # Keep only the hardest (largest-loss) negatives.
        negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
        balance_loss = (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + self.eps)

        if return_origin:
            return balance_loss, loss
        return balance_loss
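

# A minimal usage sketch for BalanceCrossEntropyLoss (illustrative only, not
# part of the original module). The loss expects probabilities, so raw logits
# are passed through a sigmoid first; pred, gt and mask are all given the
# same (N, H, W) shape here so that the element-wise products broadcast
# cleanly, which matches how DBNet training code typically slices off the
# channel dimension before calling this loss.
def _demo_balance_bce():
    logits = torch.randn(2, 32, 32, requires_grad=True)
    gt = (torch.rand(2, 32, 32) > 0.5).float()
    mask = torch.ones(2, 32, 32)  # supervise every pixel
    loss = BalanceCrossEntropyLoss()(torch.sigmoid(logits), gt, mask)
    loss.backward()  # gradients flow back to the logits
    return loss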


class DiceLoss(nn.Module):
    '''
    Loss function from https://arxiv.org/abs/1707.03237,
    where the IoU computation is carried out on heatmaps to measure the
    dissimilarity between two heatmaps.
    '''

    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, gt, mask, weights=None):
        '''
        pred: a heatmap of shape (N, 1, H, W)
        gt: (N, 1, H, W)
        mask: (N, H, W)
        '''
        return self._compute(pred, gt, mask, weights)

    def _compute(self, pred, gt, mask, weights):
        if pred.dim() == 4:
            pred = pred[:, 0, :, :]
            gt = gt[:, 0, :, :]
        assert pred.shape == gt.shape
        assert pred.shape == mask.shape
        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask
        # Dice coefficient restricted to the masked region.
        intersection = (pred * gt * mask).sum()
        union = (pred * mask).sum() + (gt * mask).sum() + self.eps
        loss = 1 - 2.0 * intersection / union
        assert loss <= 1
        return loss
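

# A minimal usage sketch for DiceLoss (illustrative only, not part of the
# original module). The channel dimension of the (N, 1, H, W) heatmaps is
# squeezed inside _compute, so the mask keeps its (N, H, W) shape.
def _demo_dice_loss():
    pred = torch.rand(2, 1, 32, 32)
    gt = (torch.rand(2, 1, 32, 32) > 0.5).float()
    mask = torch.ones(2, 32, 32)
    return DiceLoss()(pred, gt, mask)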


class MaskL1Loss(nn.Module):
    def __init__(self, eps=1e-6):
        super(MaskL1Loss, self).__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, gt, mask):
        # L1 distance over the masked pixels only, normalised by the mask area.
        loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
        return loss
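

# A minimal usage sketch for MaskL1Loss (illustrative only, not part of the
# original module): the mask restricts supervision to a subset of pixels,
# as in DBNet's threshold-map supervision.
def _demo_mask_l1():
    pred = torch.rand(2, 32, 32)
    gt = torch.rand(2, 32, 32)
    mask = (torch.rand(2, 32, 32) > 0.5).float()
    return MaskL1Loss()(pred, gt, mask)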


class BCELoss(nn.Module):
    def __init__(self, reduction='mean'):
        super(BCELoss, self).__init__()
        self.reduction = reduction

    def forward(self, input, label, mask=None, weight=None, name=None):
        # mask, weight and name are accepted for interface compatibility with
        # the other losses in this module but are not used here.
        loss = nn.functional.binary_cross_entropy(input, label, reduction=self.reduction)
        return loss


class BalanceLoss(nn.Module):
    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 negative_ratio=3,
                 return_origin=False,
                 eps=1e-6,
                 **kwargs):
        super(BalanceLoss, self).__init__()
        self.balance_loss = balance_loss
        self.main_loss_type = main_loss_type
        self.negative_ratio = negative_ratio
        self.return_origin = return_origin
        self.eps = eps

        if self.main_loss_type == "CrossEntropy":
            self.loss = nn.CrossEntropyLoss()
        elif self.main_loss_type == "Euclidean":
            self.loss = nn.MSELoss()
        elif self.main_loss_type == "DiceLoss":
            self.loss = DiceLoss(self.eps)
        elif self.main_loss_type == "BCELoss":
            self.loss = BCELoss(reduction='none')
        elif self.main_loss_type == "MaskL1Loss":
            self.loss = MaskL1Loss(self.eps)
        else:
            loss_type = [
                'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
            ]
            raise ValueError(
                "main_loss_type in BalanceLoss() can only be one of {}".format(
                    loss_type))

    def forward(self, pred, gt, mask=None):
        """
        The BalanceLoss for Differentiable Binarization text detection.

        Args:
            pred (Tensor): predicted feature maps.
            gt (Tensor): ground-truth feature maps.
            mask (Tensor): mask of the supervised regions; defaults to all ones.

        Returns:
            Tensor: the balanced loss.
        """
        if mask is None:
            mask = torch.ones_like(gt)
        positive = (gt * mask).byte()
        negative = ((1 - gt) * mask).byte()
        positive_count = int(positive.float().sum())
        negative_count = int(min(negative.float().sum(), positive_count * self.negative_ratio))

        # Only the custom losses in this module accept a mask keyword;
        # nn.CrossEntropyLoss and nn.MSELoss do not.
        if isinstance(self.loss, (DiceLoss, MaskL1Loss, BCELoss)):
            loss = self.loss(pred, gt, mask=mask)
        else:
            loss = self.loss(pred, gt)
        if not self.balance_loss:
            return loss

        # Balancing assumes an element-wise loss map (e.g. BCELoss with
        # reduction='none'): keep all positives and only the hardest negatives.
        positive_loss = positive.float() * loss
        negative_loss = negative.float() * loss
        if negative_count > 0:
            negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
            balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
                positive_count + negative_count + self.eps)
        else:
            balance_loss = positive_loss.sum() / (positive_count + self.eps)
        if self.return_origin:
            return balance_loss, loss
        return balance_loss
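

if __name__ == '__main__':
    # Smoke test (illustrative only, not part of the original module):
    # BalanceLoss with an element-wise BCE map, which is the configuration
    # the OHEM-style balancing above is designed for.
    pred = torch.rand(2, 32, 32)
    gt = (torch.rand(2, 32, 32) > 0.5).float()
    mask = torch.ones(2, 32, 32)
    print('balanced BCE:', BalanceLoss(main_loss_type='BCELoss')(pred, gt, mask))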