DetBasicLoss.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2019/12/4 14:39
  3. # @Author : zhoujun
  4. import torch
  5. import torch.nn as nn
  6. class BalanceCrossEntropyLoss(nn.Module):
  7. '''
  8. Balanced cross entropy loss.
  9. Shape:
  10. - Input: :math:`(N, 1, H, W)`
  11. - GT: :math:`(N, 1, H, W)`, same shape as the input
  12. - Mask: :math:`(N, H, W)`, same spatial shape as the input
  13. - Output: scalar.
  14. Examples::
  15. >>> m = nn.Sigmoid()
  16. >>> loss = nn.BCELoss()
  17. >>> input = torch.randn(3, requires_grad=True)
  18. >>> target = torch.empty(3).random_(2)
  19. >>> output = loss(m(input), target)
  20. >>> output.backward()
  21. '''
  22. def __init__(self, negative_ratio=3.0, eps=1e-6):
  23. super(BalanceCrossEntropyLoss, self).__init__()
  24. self.negative_ratio = negative_ratio
  25. self.eps = eps
  26. def forward(self,
  27. pred: torch.Tensor,
  28. gt: torch.Tensor,
  29. mask: torch.Tensor,
  30. return_origin=False):
  31. '''
  32. Args:
  33. pred: shape :math:`(N, 1, H, W)`, the prediction of network
  34. gt: shape :math:`(N, 1, H, W)`, the target
  35. mask: shape :math:`(N, H, W)`, the mask indicates positive regions
  36. '''
  37. positive = (gt * mask).byte()
  38. negative = ((1 - gt) * mask).byte()
  39. positive_count = int(positive.float().sum())
  40. negative_count = min(int(negative.float().sum()), int(positive_count * self.negative_ratio))
  41. loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')
  42. positive_loss = loss * positive.float()
  43. negative_loss = loss * negative.float()
  44. negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
  45. balance_loss = (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + self.eps)
  46. if return_origin:
  47. return balance_loss, loss
  48. return balance_loss
  49. class DiceLoss(nn.Module):
  50. '''
  51. Loss function from https://arxiv.org/abs/1707.03237,
  52. where iou computation is introduced heatmap manner to measure the
  53. diversity bwtween tow heatmaps.
  54. '''
  55. def __init__(self, eps=1e-6):
  56. super(DiceLoss, self).__init__()
  57. self.eps = eps
  58. def forward(self, pred: torch.Tensor, gt, mask, weights=None):
  59. '''
  60. pred: one or two heatmaps of shape (N, 1, H, W),
  61. the losses of tow heatmaps are added together.
  62. gt: (N, 1, H, W)
  63. mask: (N, H, W)
  64. '''
  65. return self._compute(pred, gt, mask, weights)
  66. def _compute(self, pred, gt, mask, weights):
  67. if pred.dim() == 4:
  68. pred = pred[:, 0, :, :]
  69. gt = gt[:, 0, :, :]
  70. assert pred.shape == gt.shape
  71. assert pred.shape == mask.shape
  72. if weights is not None:
  73. assert weights.shape == mask.shape
  74. mask = weights * mask
  75. intersection = (pred * gt * mask).sum()
  76. union = (pred * mask).sum() + (gt * mask).sum() + self.eps
  77. loss = 1 - 2.0 * intersection / union
  78. assert loss <= 1
  79. return loss
  80. class MaskL1Loss(nn.Module):
  81. def __init__(self, eps=1e-6):
  82. super(MaskL1Loss, self).__init__()
  83. self.eps = eps
  84. def forward(self, pred: torch.Tensor, gt, mask):
  85. loss = (torch.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
  86. return loss
  87. class BCELoss(nn.Module):
  88. def __init__(self, reduction='mean'):
  89. super(BCELoss, self).__init__()
  90. self.reduction = reduction
  91. def forward(self, input, label, mask=None, weight=None, name=None):
  92. loss = nn.functional.binary_cross_entropy(input, label, reduction=self.reduction)
  93. return loss
  94. class BalanceLoss(nn.Module):
  95. def __init__(self,
  96. balance_loss=True,
  97. main_loss_type='DiceLoss',
  98. negative_ratio=3,
  99. return_origin=False,
  100. eps=1e-6,
  101. **kwargs):
  102. super(BalanceLoss, self).__init__()
  103. self.balance_loss = balance_loss
  104. self.main_loss_type = main_loss_type
  105. self.negative_ratio = negative_ratio
  106. self.return_origin = return_origin
  107. self.eps = eps
  108. if self.main_loss_type == "CrossEntropy":
  109. self.loss = nn.CrossEntropyLoss()
  110. elif self.main_loss_type == "Euclidean":
  111. self.loss = nn.MSELoss()
  112. elif self.main_loss_type == "DiceLoss":
  113. self.loss = DiceLoss(self.eps)
  114. elif self.main_loss_type == "BCELoss":
  115. self.loss = BCELoss(reduction='none')
  116. elif self.main_loss_type == "MaskL1Loss":
  117. self.loss = MaskL1Loss(self.eps)
  118. else:
  119. loss_type = [
  120. 'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
  121. ]
  122. raise Exception(
  123. "main_loss_type in BalanceLoss() can only be one of {}".format(
  124. loss_type))
  125. def forward(self, pred, gt, mask=None):
  126. """
  127. The BalanceLoss for Differentiable Binarization text detection
  128. args:
  129. pred (variable): predicted feature maps.
  130. gt (variable): ground truth feature maps.
  131. mask (variable): masked maps.
  132. return: (variable) balanced loss
  133. """
  134. positive = (gt * mask).byte()
  135. negative = ((1 - gt) * mask).byte()
  136. positive_count = int(positive.float().sum())
  137. negative_count = int(min(negative.float().sum(), positive_count * self.negative_ratio))
  138. loss = self.loss(pred, gt, mask=mask)
  139. if not self.balance_loss:
  140. return loss
  141. positive_loss = positive.float() * loss
  142. negative_loss = negative.float() * loss
  143. if negative_count > 0:
  144. negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)
  145. balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
  146. positive_count + negative_count + self.eps)
  147. else:
  148. balance_loss = positive_loss.sum() / (positive_count + self.eps)
  149. if self.return_origin:
  150. return balance_loss, loss
  151. return balance_loss