DBPostProcess.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import cv2
  2. import numpy as np
  3. import pyclipper
  4. from shapely.geometry import Polygon
  5. from pyclipper import PyclipperOffset
  6. import math
  7. import operator
  8. from functools import reduce
  9. def clockwise_sort_points(_point_coordinates):
  10. """
  11. 以左上角为起点的顺时针排序
  12. 原理就是将笛卡尔坐标转换为极坐标,然后对极坐标的φ进行排序
  13. Args:
  14. _point_coordinates: 待排序的点[(x,y),]
  15. Returns: 排序完成的点
  16. """
  17. center_point = tuple(
  18. map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), _point_coordinates),
  19. [len(_point_coordinates)] * 2))
  20. return sorted(_point_coordinates, key=lambda coord: (180 + math.degrees(
  21. math.atan2(*tuple(map(operator.sub, coord, center_point))[::-1]))) % 360)
  22. class DistillationDBPostProcess(object):
  23. def __init__(self, model_name=None,
  24. key=None,
  25. thresh=0.3,
  26. box_thresh=0.6,
  27. max_candidates=1000,
  28. unclip_ratio=1.5,
  29. use_dilation=False,
  30. score_mode="fast",
  31. **kwargs):
  32. if model_name is None:
  33. model_name = ["student"]
  34. self.model_name = model_name
  35. self.key = key
  36. self.post_process = DBPostProcess(thresh=thresh,
  37. box_thresh=box_thresh,
  38. max_candidates=max_candidates,
  39. unclip_ratio=unclip_ratio,
  40. use_dilation=use_dilation,
  41. score_mode=score_mode)
  42. def __call__(self, predicts, shape_list):
  43. results = {}
  44. for k in self.model_name:
  45. results[k] = self.post_process(predicts[k].detach().cpu().numpy(), shape_list=shape_list)
  46. return results
  47. class DBPostProcess(object):
  48. """
  49. The post process for Differentiable Binarization (DB).
  50. """
  51. def __init__(self,
  52. thresh=0.6,
  53. box_thresh=0.6,
  54. max_candidates=1000,
  55. unclip_ratio=1.5,
  56. use_dilation=False,
  57. **kwargs):
  58. self.thresh = thresh
  59. self.box_thresh = box_thresh
  60. self.max_candidates = max_candidates
  61. self.unclip_ratio = unclip_ratio
  62. self.min_size = 3
  63. self.dilation_kernel = None if not use_dilation else np.array(
  64. [[1, 1], [1, 1]])
  65. def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  66. '''
  67. _bitmap: single map with shape (1, H, W),
  68. whose values are binarized as {0, 1}
  69. '''
  70. bitmap = _bitmap
  71. height, width = bitmap.shape
  72. bitmap = (bitmap * 255).astype(np.uint8)
  73. # structure_element = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
  74. # bitmap = cv2.morphologyEx(bitmap, cv2.MORPH_CLOSE, structure_element)
  75. # cv2.imwrite('bin.jpg',bitmap)
  76. if cv2.__version__.startswith('3'):
  77. _, contours, _ = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  78. elif cv2.__version__.startswith('4'):
  79. contours, _ = cv2.findContours(bitmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  80. else:
  81. raise NotImplementedError(f'opencv {cv2.__version__} not support')
  82. num_contours = min(len(contours), self.max_candidates)
  83. boxes = []
  84. scores = []
  85. for index in range(num_contours):
  86. contour = contours[index]
  87. points, sside = self.get_mini_boxes(contour)
  88. if sside < self.min_size:
  89. continue
  90. points = np.array(points)
  91. score = self.box_score_fast(pred, points.reshape(-1, 2))
  92. # score = self.box_score_slow(pred, contour)
  93. if score < self.box_thresh:
  94. continue
  95. box = self.unclip(points).reshape(-1, 1, 2)
  96. box, sside = self.get_mini_boxes(box)
  97. if sside < self.min_size + 2:
  98. continue
  99. box = np.array(box)
  100. box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
  101. box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
  102. boxes.append(box.astype(np.int16))
  103. scores.append(score)
  104. # try:
  105. # poly = contours[index]
  106. # # cv2.drawContours(debug_mat, poly, -1, (111, 90, 255), -1)
  107. #
  108. # epsilon = 0.001 * cv2.arcLength(poly, True)
  109. # approx = cv2.approxPolyDP(poly, epsilon, True)
  110. # points = approx.reshape((-1, 2))
  111. # if points.shape[0] < 4:
  112. # continue
  113. # score = self.box_score_fast(pred, points)
  114. # if score < self.box_thresh:
  115. # continue
  116. # poly = self.unclip(points)
  117. # if len(poly) == 0 or isinstance(poly[0], list):
  118. # continue
  119. # poly = poly.reshape(-1, 2)
  120. #
  121. # # box, sside = self.get_mini_boxes(poly)
  122. # # if sside < self.min_size + 2:
  123. # # continue
  124. # # box = np.array(box)
  125. # box=np.array(poly)
  126. #
  127. # box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
  128. # box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
  129. # boxes.append(box.astype(np.int16).flatten().tolist())
  130. # scores.append(score)
  131. # except:
  132. # print('1')
  133. # pass
  134. return boxes, scores
  135. def unclip(self, box):
  136. unclip_ratio = self.unclip_ratio
  137. poly = Polygon(box)
  138. distance = poly.area * unclip_ratio / poly.length
  139. offset = pyclipper.PyclipperOffset()
  140. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  141. expanded = np.array(offset.Execute(distance))
  142. return expanded
  143. def get_mini_boxes(self, contour):
  144. try:
  145. rotated_box = cv2.minAreaRect(contour)
  146. except:
  147. print(len(contour))
  148. return None, 0
  149. box_points = cv2.boxPoints(rotated_box)
  150. rotated_points = clockwise_sort_points(box_points)
  151. rotated_points = list(rotated_points)
  152. return rotated_points, min(rotated_box[1])
  153. def box_score_fast(self, bitmap, _box):
  154. h, w = bitmap.shape[:2]
  155. box = _box.copy()
  156. xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
  157. xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
  158. ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
  159. ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
  160. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  161. box[:, 0] = box[:, 0] - xmin
  162. box[:, 1] = box[:, 1] - ymin
  163. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
  164. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  165. def box_score_slow(self, bitmap, contour):
  166. '''
  167. box_score_slow: use polyon mean score as the mean score
  168. '''
  169. h, w = bitmap.shape[:2]
  170. contour = contour.copy()
  171. contour = np.reshape(contour, (-1, 2))
  172. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  173. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  174. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  175. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  176. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  177. contour[:, 0] = contour[:, 0] - xmin
  178. contour[:, 1] = contour[:, 1] - ymin
  179. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
  180. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  181. def __call__(self, outs_dict, shape_list):
  182. pred = outs_dict
  183. pred = pred[:, 0, :, :]
  184. segmentation = np.zeros_like(pred, dtype=np.float32)
  185. np.putmask(segmentation, pred > self.thresh, pred)
  186. boxes_batch = []
  187. scores_batch = []
  188. for batch_index in range(pred.shape[0]):
  189. src_h, src_w = shape_list[batch_index]
  190. if self.dilation_kernel is not None:
  191. mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel)
  192. else:
  193. mask = segmentation[batch_index]
  194. boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, src_w, src_h, )
  195. boxes_batch.append(boxes)
  196. scores_batch.append(scores)
  197. return boxes_batch, scores_batch