FCEPostProcess.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. """
  2. This code is adapted from:
  3. https://github.com/open-mmlab/mmocr/blob/v0.3.0/mmocr/models/textdet/postprocess/wrapper.py
  4. """
  5. import cv2
  6. import torch
  7. import numpy as np
  8. from numpy.fft import ifft
  9. from torchocr.utils.poly_nms import poly_nms, valid_boundary
  10. def fill_hole(input_mask):
  11. h, w = input_mask.shape
  12. canvas = np.zeros((h + 2, w + 2), np.uint8)
  13. canvas[1:h + 1, 1:w + 1] = input_mask.copy()
  14. mask = np.zeros((h + 4, w + 4), np.uint8)
  15. cv2.floodFill(canvas, mask, (0, 0), 1)
  16. canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool)
  17. return ~canvas | input_mask
  18. def fourier2poly(fourier_coeff, num_reconstr_points=50):
  19. """ Inverse Fourier transform
  20. Args:
  21. fourier_coeff (ndarray): Fourier coefficients shaped (n, 2k+1),
  22. with n and k being candidates number and Fourier degree
  23. respectively.
  24. num_reconstr_points (int): Number of reconstructed polygon points.
  25. Returns:
  26. Polygons (ndarray): The reconstructed polygons shaped (n, n')
  27. """
  28. a = np.zeros((len(fourier_coeff), num_reconstr_points), dtype='complex')
  29. k = (len(fourier_coeff[0]) - 1) // 2
  30. a[:, 0:k + 1] = fourier_coeff[:, k:]
  31. a[:, -k:] = fourier_coeff[:, :k]
  32. poly_complex = ifft(a) * num_reconstr_points
  33. polygon = np.zeros((len(fourier_coeff), num_reconstr_points, 2))
  34. polygon[:, :, 0] = poly_complex.real
  35. polygon[:, :, 1] = poly_complex.imag
  36. return polygon.astype('int32').reshape((len(fourier_coeff), -1))
  37. class FCEPostProcess(object):
  38. """
  39. The post process for FCENet.
  40. """
  41. def __init__(self,
  42. scales,
  43. fourier_degree=5,
  44. num_reconstr_points=50,
  45. decoding_type='fcenet',
  46. score_thr=0.3,
  47. nms_thr=0.1,
  48. alpha=1.0,
  49. beta=1.0,
  50. box_type='poly',
  51. **kwargs):
  52. self.scales = scales
  53. self.fourier_degree = fourier_degree
  54. self.num_reconstr_points = num_reconstr_points
  55. self.decoding_type = decoding_type
  56. self.score_thr = score_thr
  57. self.nms_thr = nms_thr
  58. self.alpha = alpha
  59. self.beta = beta
  60. self.box_type = box_type
  61. def __call__(self, preds, shape_list):
  62. score_maps = []
  63. for key, value in preds.items():
  64. if isinstance(value, torch.Tensor):
  65. value = value.cpu().detach().numpy()
  66. cls_res = value[:, :4, :, :]
  67. reg_res = value[:, 4:, :, :]
  68. score_maps.append([cls_res, reg_res])
  69. return self.get_boundary(score_maps, shape_list)
  70. def resize_boundary(self, boundaries, scale_factor):
  71. """Rescale boundaries via scale_factor.
  72. Args:
  73. boundaries (list[list[float]]): The boundary list. Each boundary
  74. with size 2k+1 with k>=4.
  75. scale_factor(ndarray): The scale factor of size (4,).
  76. Returns:
  77. boundaries (list[list[float]]): The scaled boundaries.
  78. """
  79. boxes = []
  80. scores = []
  81. for b in boundaries:
  82. sz = len(b)
  83. valid_boundary(b, True)
  84. scores.append(b[-1])
  85. b = (np.array(b[:sz - 1]) *
  86. (np.tile(scale_factor[:2], int(
  87. (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
  88. boxes.append(np.array(b).reshape([-1, 2]))
  89. return np.array(boxes, dtype=np.float32), scores
  90. def get_boundary(self, score_maps, shape_list):
  91. assert len(score_maps) == len(self.scales)
  92. boundaries = []
  93. for idx, score_map in enumerate(score_maps):
  94. scale = self.scales[idx]
  95. boundaries = boundaries + self._get_boundary_single(score_map,scale)
  96. # nms
  97. boundaries = poly_nms(boundaries, self.nms_thr)
  98. boundaries, scores = self.resize_boundary(
  99. boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
  100. # boxes_batch = [dict(points=boundaries, scores=scores)]
  101. return boundaries.tolist(),scores
  102. def _get_boundary_single(self, score_map, scale):
  103. assert len(score_map) == 2
  104. assert score_map[1].shape[1] == 4 * self.fourier_degree + 2
  105. return self.fcenet_decode(
  106. preds=score_map,
  107. fourier_degree=self.fourier_degree,
  108. num_reconstr_points=self.num_reconstr_points,
  109. scale=scale,
  110. alpha=self.alpha,
  111. beta=self.beta,
  112. box_type=self.box_type,
  113. score_thr=self.score_thr,
  114. nms_thr=self.nms_thr)
  115. def fcenet_decode(self,
  116. preds,
  117. fourier_degree,
  118. num_reconstr_points,
  119. scale,
  120. alpha=1.0,
  121. beta=2.0,
  122. box_type='poly',
  123. score_thr=0.3,
  124. nms_thr=0.1):
  125. """Decoding predictions of FCENet to instances.
  126. Args:
  127. preds (list(Tensor)): The head output tensors.
  128. fourier_degree (int): The maximum Fourier transform degree k.
  129. num_reconstr_points (int): The points number of the polygon
  130. reconstructed from predicted Fourier coefficients.
  131. scale (int): The down-sample scale of the prediction.
  132. alpha (float) : The parameter to calculate final scores. Score_{final}
  133. = (Score_{text region} ^ alpha)
  134. * (Score_{text center region}^ beta)
  135. beta (float) : The parameter to calculate final score.
  136. box_type (str): Boundary encoding type 'poly' or 'quad'.
  137. score_thr (float) : The threshold used to filter out the final
  138. candidates.
  139. nms_thr (float) : The threshold of nms.
  140. Returns:
  141. boundaries (list[list[float]]): The instance boundary and confidence
  142. list.
  143. """
  144. assert isinstance(preds, list)
  145. assert len(preds) == 2
  146. assert box_type in ['poly', 'quad']
  147. cls_pred = preds[0][0]
  148. tr_pred = cls_pred[0:2]
  149. tcl_pred = cls_pred[2:]
  150. reg_pred = preds[1][0].transpose([1, 2, 0])
  151. x_pred = reg_pred[:, :, :2 * fourier_degree + 1]
  152. y_pred = reg_pred[:, :, 2 * fourier_degree + 1:]
  153. score_pred = (tr_pred[1]**alpha) * (tcl_pred[1]**beta)
  154. tr_pred_mask = (score_pred) > score_thr
  155. tr_mask = fill_hole(tr_pred_mask)
  156. tr_contours, _ = cv2.findContours(
  157. tr_mask.astype(np.uint8), cv2.RETR_TREE,
  158. cv2.CHAIN_APPROX_SIMPLE) # opencv4
  159. mask = np.zeros_like(tr_mask)
  160. boundaries = []
  161. for cont in tr_contours:
  162. deal_map = mask.copy().astype(np.int8)
  163. cv2.drawContours(deal_map, [cont], -1, 1, -1)
  164. score_map = score_pred * deal_map
  165. score_mask = score_map > 0
  166. xy_text = np.argwhere(score_mask)
  167. dxy = xy_text[:, 1] + xy_text[:, 0] * 1j
  168. x, y = x_pred[score_mask], y_pred[score_mask]
  169. c = x + y * 1j
  170. c[:, fourier_degree] = c[:, fourier_degree] + dxy
  171. c *= scale
  172. polygons = fourier2poly(c, num_reconstr_points)
  173. score = score_map[score_mask].reshape(-1, 1)
  174. polygons = poly_nms(np.hstack((polygons, score)).tolist(), nms_thr)
  175. boundaries = boundaries + polygons
  176. boundaries = poly_nms(boundaries, nms_thr)
  177. if box_type == 'quad':
  178. new_boundaries = []
  179. for boundary in boundaries:
  180. poly = np.array(boundary[:-1]).reshape(-1, 2).astype(np.float32)
  181. score = boundary[-1]
  182. points = cv2.boxPoints(cv2.minAreaRect(poly))
  183. points = np.int0(points)
  184. new_boundaries.append(points.reshape(-1).tolist() + [score])
  185. boundaries = new_boundaries
  186. return boundaries