FCELoss.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. import numpy as np
  2. from torch import nn
  3. import torch
  4. import torch.nn.functional as F
  5. from functools import partial
  def multi_apply(func, *args, **kwargs):
      """Call ``func`` position-wise over ``args`` and regroup its outputs.

      ``func`` is invoked once per position of the parallel iterables in
      ``args`` (with ``kwargs`` bound to every call). Each call returns a
      sequence of outputs. The result is a tuple whose i-th entry is a list of
      the i-th output of every call (truncated to the shortest output, as
      ``zip`` does).
      """
      if kwargs:
          func = partial(func, **kwargs)
      per_call_results = map(func, *args)
      return tuple(list(slot) for slot in zip(*per_call_results))
  class FCELoss(nn.Module):
      """The class for implementing FCENet loss.

      FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
      Text Detection [https://arxiv.org/abs/2104.10442]

      Every loss term is created on the device of the tensors it is computed
      from, so the module runs on CPU as well as on GPU.

      Args:
          fourier_degree (int): The maximum Fourier transform degree k.
          num_sample (int): The sampling points number of regression
              loss. If it is too small, fcenet tends to be overfitting.
          ohem_ratio (float): The negative/positive ratio in OHEM.
      """

      def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
          super().__init__()
          self.fourier_degree = fourier_degree
          self.num_sample = num_sample
          self.ohem_ratio = ohem_ratio

      def forward(self, preds, labels):
          """Compute the FCENet loss summed over the p3/p4/p5 levels.

          Args:
              preds (dict): Must hold ``'levels'``, a sequence with one
                  ``(cls_pred, reg_pred)`` pair per level, ordered p3, p4, p5.
              labels (dict): Target maps under ``'p3_maps'``, ``'p4_maps'``
                  and ``'p5_maps'``. Each is a channels-first 4-D tensor
                  whose dim 1 has ``4 * fourier_degree + 5`` channels.

          Returns:
              dict: ``loss`` (the total) and its parts ``loss_text``,
              ``loss_center``, ``loss_reg_x`` and ``loss_reg_y``.
          """
          assert isinstance(preds, dict)
          preds = preds['levels']
          p3_maps, p4_maps, p5_maps = (
              labels['p3_maps'], labels['p4_maps'], labels['p5_maps'])
          assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5, \
              'fourier degree not equal in FCEhead and FCEtarget'
          gts = [p3_maps, p4_maps, p5_maps]

          # Accumulate each component over the levels, starting from a
          # float32 zero on the targets' device (no hard-coded .cuda()).
          zero = torch.zeros((), dtype=torch.float32, device=p3_maps.device)
          loss_tr = loss_tcl = loss_reg_x = loss_reg_y = zero
          for pred, gt in zip(preds, gts):
              tr, tcl, reg_x, reg_y = self.forward_single(pred, gt)
              loss_tr = loss_tr + tr
              loss_tcl = loss_tcl + tcl
              loss_reg_x = loss_reg_x + reg_x
              loss_reg_y = loss_reg_y + reg_y
          loss_all = loss_tr + loss_tcl + loss_reg_x + loss_reg_y

          return dict(
              loss=loss_all,
              loss_text=loss_tr,
              loss_center=loss_tcl,
              loss_reg_x=loss_reg_x,
              loss_reg_y=loss_reg_y,
          )

      def forward_single(self, pred, gt):
          """Compute the loss terms for one pyramid level.

          Args:
              pred (Sequence[Tensor]): ``(cls_pred, reg_pred)``, channels-first
                  4-D tensors. ``cls_pred`` has 4 channels (2 text-region
                  logits, then 2 center-line logits). ``reg_pred`` has
                  ``2 * k`` channels with ``k = 2 * fourier_degree + 1``
                  (real Fourier coefficients, then imaginary ones).
              gt (Tensor): Channels-first 4-D target whose channels are
                  ``[tr_mask, tcl_mask, train_mask, k real coeffs,
                  k imag coeffs]``.

          Returns:
              tuple[Tensor, Tensor, Tensor, Tensor]: Scalar
              ``(loss_tr, loss_tcl, loss_reg_x, loss_reg_y)``.
          """
          cls_pred = pred[0].permute(0, 2, 3, 1)
          reg_pred = pred[1].permute(0, 2, 3, 1)
          gt = gt.permute(0, 2, 3, 1)
          k = 2 * self.fourier_degree + 1

          # Flatten everything to one row per pixel.
          tr_pred = cls_pred[..., :2].reshape(-1, 2)
          tcl_pred = cls_pred[..., 2:].reshape(-1, 2)
          x_pred = reg_pred[..., 0:k].reshape(-1, k)
          y_pred = reg_pred[..., k:2 * k].reshape(-1, k)

          tr_mask = gt[..., 0].reshape(-1)
          tcl_mask = gt[..., 1].reshape(-1)
          train_mask = gt[..., 2].reshape(-1)
          x_map = gt[..., 3:3 + k].reshape(-1, k)
          y_map = gt[..., 3 + k:].reshape(-1, k)

          # Text-region pixels that are not ignored.
          tr_train_mask = (train_mask * tr_mask).bool()
          # Note: the complement also contains ignored (train_mask == 0) pixels.
          tr_neg_mask = tr_train_mask.logical_not()

          # Text-region classification with OHEM.
          loss_tr = self.ohem(tr_pred, tr_mask, train_mask)

          zero = torch.zeros((), dtype=torch.float32, device=tr_pred.device)
          if not tr_train_mask.any():
              # No positive pixels: center-line and regression terms are zero.
              return loss_tr, zero, zero, zero

          # Center-line classification: positives at full weight, the rest at 0.5.
          loss_tcl_pos = F.cross_entropy(
              tcl_pred[tr_train_mask], tcl_mask[tr_train_mask].long())
          if tr_neg_mask.any():
              loss_tcl_neg = F.cross_entropy(
                  tcl_pred[tr_neg_mask], tcl_mask[tr_neg_mask].long())
          else:
              # A mean over an empty selection would be NaN; when every pixel
              # is a positive there is simply no negative term.
              loss_tcl_neg = zero
          loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg

          # Fourier regression, compared in polygon (sampled point) space.
          # fourier2poly works row by row, so transforming only the selected
          # pixels gives the same rows as transforming the whole map.
          weight = ((tr_mask[tr_train_mask].float()
                     + tcl_mask[tr_train_mask].float()) / 2).reshape(-1, 1)
          ft_x, ft_y = self.fourier2poly(
              x_map[tr_train_mask], y_map[tr_train_mask])
          ft_x_pre, ft_y_pre = self.fourier2poly(
              x_pred[tr_train_mask], y_pred[tr_train_mask])
          loss_reg_x = torch.mean(weight * F.smooth_l1_loss(
              ft_x_pre, ft_x, reduction='none'))
          loss_reg_y = torch.mean(weight * F.smooth_l1_loss(
              ft_y_pre, ft_y, reduction='none'))
          return loss_tr, loss_tcl, loss_reg_x, loss_reg_y

      def ohem(self, predict, target, train_mask):
          """Cross-entropy with online hard negative mining.

          All positive pixels contribute. Among the negatives, only the
          ``n_neg`` with the largest loss are kept, where ``n_neg`` is
          ``ohem_ratio * n_pos`` (capped by the number of negatives), or a
          fixed 100 when there are no positives.

          Args:
              predict (Tensor): Per-pixel 2-class logits; ``forward_single``
                  passes an ``(N, 2)`` tensor.
              target (Tensor): Per-pixel 0/1 labels, shape ``(N,)``.
              train_mask (Tensor): Per-pixel 0/1 mask, shape ``(N,)``;
                  pixels where it is 0 are ignored.

          Returns:
              Tensor: Scalar ``(sum of kept losses) / (n_pos + n_neg)``.
          """
          pos = (target * train_mask).bool()
          neg = ((1 - target) * train_mask).bool()
          n_pos = pos.float().sum()

          loss_neg = F.cross_entropy(
              predict[neg], target[neg].long(), reduction='none')
          if n_pos.item() > 0:
              loss_pos = F.cross_entropy(
                  predict[pos], target[pos].long(), reduction='sum')
              n_neg = min(
                  int(neg.float().sum().item()),
                  int(self.ohem_ratio * n_pos.float()))
          else:
              loss_pos = torch.zeros(
                  (), dtype=torch.float32, device=predict.device)
              # No positives to scale by: keep a fixed budget of negatives.
              n_neg = 100
          if len(loss_neg) > n_neg:
              # Keep only the hardest (highest-loss) negatives.
              loss_neg, _ = torch.topk(loss_neg, n_neg)
          return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()

      def fourier2poly(self, real_maps, imag_maps):
          """Transform Fourier coefficient maps to polygon maps.

          Args:
              real_maps (tensor): A map composed of the real parts of the
                  Fourier coefficients, whose shape is (-1, 2k+1)
              imag_maps (tensor): A map composed of the imag parts of the
                  Fourier coefficients, whose shape is (-1, 2k+1)

          Returns:
              x_maps (tensor): A map composed of the x value of the polygon
                  represented by n sample points (xn, yn), whose shape is (-1, n)
              y_maps (tensor): A map composed of the y value of the polygon
                  represented by n sample points (xn, yn), whose shape is (-1, n)

          The transform is built on CPU in float32, then moved to the device
          and dtype of ``real_maps`` rather than to a hard-coded GPU.
          """
          k_vect = torch.arange(
              -self.fourier_degree, self.fourier_degree + 1,
              dtype=torch.float32).reshape([-1, 1])
          i_vect = torch.arange(
              0, self.num_sample, dtype=torch.float32).reshape([1, -1])
          transform_matrix = 2 * np.pi / self.num_sample * torch.matmul(
              k_vect, i_vect)
          cos_t = torch.cos(transform_matrix).to(
              device=real_maps.device, dtype=real_maps.dtype)
          sin_t = torch.sin(transform_matrix).to(
              device=real_maps.device, dtype=real_maps.dtype)

          # x_n = sum_k Re(c_k) cos(2*pi*k*n/N) - Im(c_k) sin(2*pi*k*n/N)
          # y_n = sum_k Re(c_k) sin(2*pi*k*n/N) + Im(c_k) cos(2*pi*k*n/N)
          x_maps = real_maps @ cos_t - imag_maps @ sin_t
          y_maps = real_maps @ sin_t + imag_maps @ cos_t
          return x_maps, y_maps