# PSELoss.py
import numpy as np
import torch
from torch import nn
  4. class PSELoss(nn.Module):
  5. def __init__(self, Lambda, ratio=3, reduction='mean'):
  6. """Implement PSE Loss.
  7. """
  8. super(PSELoss, self).__init__()
  9. assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
  10. self.Lambda = Lambda
  11. self.ratio = ratio
  12. self.reduction = reduction
  13. def forward(self, outputs, labels, training_masks):
  14. texts = outputs[:, -1, :, :]
  15. kernels = outputs[:, :-1, :, :]
  16. gt_texts = labels[:, -1, :, :]
  17. gt_kernels = labels[:, :-1, :, :]
  18. selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
  19. selected_masks = selected_masks.to(outputs.device)
  20. loss_text = self.dice_loss(texts, gt_texts, selected_masks)
  21. loss_kernels = []
  22. mask0 = torch.sigmoid(texts).data.cpu().numpy()
  23. mask1 = training_masks.data.cpu().numpy()
  24. selected_masks = ((mask0 > 0.5) & (mask1 > 0.5)).astype('float32')
  25. selected_masks = torch.from_numpy(selected_masks).float()
  26. selected_masks = selected_masks.to(outputs.device)
  27. kernels_num = gt_kernels.size()[1]
  28. for i in range(kernels_num):
  29. kernel_i = kernels[:, i, :, :]
  30. gt_kernel_i = gt_kernels[:, i, :, :]
  31. loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, selected_masks)
  32. loss_kernels.append(loss_kernel_i)
  33. loss_kernels = torch.stack(loss_kernels).mean(0)
  34. if self.reduction == 'mean':
  35. loss_text = loss_text.mean()
  36. loss_kernels = loss_kernels.mean()
  37. elif self.reduction == 'sum':
  38. loss_text = loss_text.sum()
  39. loss_kernels = loss_kernels.sum()
  40. loss = self.Lambda * loss_text + (1 - self.Lambda) * loss_kernels
  41. return loss_text, loss_kernels, loss
  42. def dice_loss(self, input, target, mask):
  43. input = torch.sigmoid(input)
  44. input = input.contiguous().view(input.size()[0], -1)
  45. target = target.contiguous().view(target.size()[0], -1)
  46. mask = mask.contiguous().view(mask.size()[0], -1)
  47. input = input * mask
  48. target = target * mask
  49. a = torch.sum(input * target, 1)
  50. b = torch.sum(input * input, 1) + 0.001
  51. c = torch.sum(target * target, 1) + 0.001
  52. d = (2 * a) / (b + c)
  53. return 1 - d
  54. def ohem_single(self, score, gt_text, training_mask):
  55. pos_num = (int)(np.sum(gt_text > 0.5)) - (int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))
  56. if pos_num == 0:
  57. # selected_mask = gt_text.copy() * 0 # may be not good
  58. selected_mask = training_mask
  59. selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
  60. return selected_mask
  61. neg_num = (int)(np.sum(gt_text <= 0.5))
  62. neg_num = (int)(min(pos_num * 3, neg_num))
  63. if neg_num == 0:
  64. selected_mask = training_mask
  65. selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
  66. return selected_mask
  67. neg_score = score[gt_text <= 0.5]
  68. # 将负样本得分从高到低排序
  69. neg_score_sorted = np.sort(-neg_score)
  70. threshold = -neg_score_sorted[neg_num - 1]
  71. # 选出 得分高的 负样本 和正样本 的 mask
  72. selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)
  73. selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
  74. return selected_mask
  75. def ohem_batch(self, scores, gt_texts, training_masks):
  76. scores = scores.data.cpu().numpy()
  77. gt_texts = gt_texts.data.cpu().numpy()
  78. training_masks = training_masks.data.cpu().numpy()
  79. selected_masks = []
  80. for i in range(scores.shape[0]):
  81. selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :]))
  82. selected_masks = np.concatenate(selected_masks, 0)
  83. selected_masks = torch.from_numpy(selected_masks).float()
  84. return selected_masks