table_utils.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. import numpy as np
  2. import cv2
  3. def shrink_bbox(img, bbox_list):
  4. def return_first_black_index(image_np):
  5. lower = np.array([0, 0, 0])
  6. upper = np.array([150, 150, 150])
  7. mask = cv2.inRange(image_np, lower, upper)
  8. black_index_list = np.where(mask != 0)
  9. return black_index_list
  10. new_bbox_list = []
  11. for bbox in bbox_list:
  12. img_bbox = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :]
  13. if 0 in img_bbox.shape:
  14. new_bbox_list.append(bbox)
  15. continue
  16. # 左右上下开始扫描,碰到黑像素即停
  17. index_list = return_first_black_index(img_bbox[:, :, :])
  18. if index_list[0].size == 0 or index_list[1].size == 0:
  19. new_bbox_list.append(bbox)
  20. continue
  21. min_h = index_list[0][0]
  22. max_h = index_list[0][-1]
  23. img_bbox1 = np.swapaxes(img_bbox, 0, 1)
  24. index_list = return_first_black_index(img_bbox1[:, :, :])
  25. if index_list[0].size == 0 or index_list[1].size == 0:
  26. new_bbox_list.append(bbox)
  27. continue
  28. min_w = index_list[0][0]
  29. max_w = index_list[0][-1]
  30. real_min_w = bbox[0][0] + min_w
  31. real_max_w = bbox[0][0] + max_w
  32. real_min_h = bbox[0][1] + min_h
  33. real_max_h = bbox[0][1] + max_h
  34. new_bbox = [[real_min_w, real_min_h], [real_min_w, real_max_h], [real_max_w, real_max_h], [real_max_w, real_min_h]]
  35. new_bbox_list.append(new_bbox)
  36. # cv2.imshow('img', img_bbox)
  37. # cv2.imshow('shrink', img[int(new_bbox[0][1]):int(new_bbox[2][1]), int(new_bbox[0][0]):int(new_bbox[2][0]), :])
  38. # cv2.waitKey(0)
  39. return new_bbox_list
  40. def split_bbox(img, bbox, bbox_text_dict):
  41. text = bbox_text_dict.get(str(bbox))
  42. sub_img = img[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :]
  43. split_line_list = []
  44. last_i_status = 1
  45. # 从左到右遍历img
  46. for i in range(1, sub_img.shape[1]):
  47. # 若这一列黑色像素超过一定值
  48. if np.where(sub_img[:, i, :] < 200)[0].size > sub_img.shape[0]/5:
  49. i_status = 0
  50. else:
  51. i_status = 1
  52. # 异或,上个像素列为黑且这个像素列为白,或上个像素列为白且这个像素列为黑
  53. if last_i_status ^ i_status:
  54. split_line_list.append(int(i))
  55. last_i_status = i_status
  56. # 两条分割线太近的去重
  57. min_len = 5
  58. last_l = split_line_list[0]
  59. temp_list = [split_line_list[0]]
  60. for l in split_line_list[1:]:
  61. if l - last_l > min_len:
  62. temp_list.append(l)
  63. last_l = l
  64. split_line_list = temp_list
  65. # 若两个分割线间无黑像素,则是应该分割的
  66. split_pair_list = []
  67. last_line = split_line_list[0]
  68. for line in split_line_list[1:]:
  69. print('last_line, line', last_line, line, np.where(sub_img[:, last_line:line, :] < 100)[0].size)
  70. if line - last_line >= 10 and np.where(sub_img[:, last_line:line, :] < 100)[0].size < 10:
  71. split_pair_list.append([last_line, line])
  72. last_line = line
  73. print('split_pair_list', split_pair_list)
  74. for l in split_line_list:
  75. l = int(l + bbox[0][0])
  76. cv2.line(img, (l, int(bbox[0][1])), (l, int(bbox[2][1])), (0, 255, 0), 2)
  77. cv2.rectangle(img, (int(bbox[0][0]), int(bbox[0][1])), (int(bbox[2][0]), int(bbox[2][1])),
  78. (0, 0, 255), 1)
  79. cv2.imshow('img', img)
  80. cv2.waitKey(0)
  81. # 分割得到新bbox
  82. split_bbox_list = []
  83. if split_pair_list:
  84. start_line = 0
  85. for line1, line2 in split_pair_list:
  86. w1 = start_line + bbox[0][0]
  87. w2 = line1 + bbox[0][0]
  88. start_line = line2
  89. split_bbox_list.append([[w1, bbox[0][1]], [], [w2, bbox[2][1]], []])
  90. w1 = start_line + bbox[0][0]
  91. w2 = bbox[2][0]
  92. split_bbox_list.append([[w1, bbox[0][1]], [], [w2, bbox[2][1]], []])
  93. print('split_bbox_list', split_bbox_list)
  94. # 计算每个字长度
  95. all_len = 0
  96. bbox_len_list = []
  97. for bbox in split_bbox_list:
  98. _len = abs(bbox[2][0] - bbox[0][0])
  99. all_len += _len
  100. bbox_len_list.append(_len)
  101. single_char_len = all_len / len(text)
  102. # 根据bbox长度和单字长度比例计算得到截取后的text
  103. split_text_list = []
  104. text_start = 0
  105. for _len in bbox_len_list:
  106. text_num = int(_len / single_char_len + 0.5)
  107. text_end = text_start+text_num
  108. if text_end >= len(text):
  109. text_end = len(text)
  110. split_text_list.append(text[text_start:text_end])
  111. text_start = text_end
  112. print('split_text_list', split_text_list)
  113. # 更新bbox_text_dict
  114. for i, bbox in enumerate(split_bbox_list):
  115. bbox_text_dict[str(bbox)] = split_text_list[i]
  116. return split_bbox_list, bbox_text_dict
  117. def count_black(image_np, threshold=150):
  118. lower = np.array([0, 0, 0])
  119. upper = np.array([threshold, threshold, threshold])
  120. mask = cv2.inRange(image_np, lower, upper)
  121. cnt = np.sum(mask != 0)
  122. # print("count color ", cnt)
  123. return cnt
  124. def get_points_by_line(img, row_lines, col_lines):
  125. row_img = np.zeros_like(img[:, :, 0], dtype=np.uint8)
  126. col_img = np.zeros_like(img[:, :, 0], dtype=np.uint8)
  127. for r in row_lines:
  128. cv2.line(row_img, [r[0], r[1]], [r[2], r[3]], (255, 255, 255), 1)
  129. for c in col_lines:
  130. cv2.line(col_img, [c[0], c[1]], [c[2], c[3]], (255, 255, 255), 1)
  131. point_img = np.bitwise_and(row_img, col_img)
  132. # 识别黑白图中的白色交叉点,将横纵坐标取出
  133. ys, xs = np.where(point_img > 0)
  134. points = []
  135. for i in range(len(xs)):
  136. points.append((xs[i], ys[i]))
  137. points.sort(key=lambda x: (x[0], x[1]))
  138. return points
  139. def get_table_bbox_list(img, area_row_lines, area_col_lines, table_location_list, bbox_list):
  140. area_table_bbox_list = []
  141. area_table_cell_list = []
  142. for i in range(len(table_location_list)):
  143. row_lines = area_row_lines[i]
  144. col_lines = area_col_lines[i]
  145. # 求线交点
  146. cross_points = get_points_by_line(img, row_lines, col_lines)
  147. # for p in cross_points:
  148. # cv2.circle(img, p, 2, (0, 0, 255), 2)
  149. # cv2.imshow('cross_points', img)
  150. # 交点分行
  151. cross_points.sort(key=lambda x: (x[1], x[0]))
  152. row_point_list = []
  153. if not cross_points:
  154. area_table_bbox_list.append([])
  155. area_table_cell_list.append([])
  156. continue
  157. current_row = [cross_points[0]]
  158. for p in cross_points[1:]:
  159. if current_row[0][1] == p[1]:
  160. current_row.append(p)
  161. else:
  162. row_point_list.append(current_row)
  163. current_row = [p]
  164. if current_row:
  165. row_point_list.append(current_row)
  166. # bbox以表格格式排列
  167. used_bbox_list = []
  168. row_list = []
  169. row_cell_list = []
  170. for j in range(1, len(row_point_list)):
  171. last_row = row_point_list[j-1]
  172. row = row_point_list[j]
  173. col_list = []
  174. col_cell_list = []
  175. for k in range(1, len(row)):
  176. last_p = last_row[k-1]
  177. p = row[k]
  178. cell = []
  179. for bbox in bbox_list:
  180. if bbox in used_bbox_list:
  181. continue
  182. bbox_h_center = (bbox[0][1]+bbox[2][1]) / 2
  183. bbox_w_center = (bbox[0][0]+bbox[2][0]) / 2
  184. if last_p[1] <= bbox_h_center <= p[1] and last_p[0] <= bbox_w_center <= p[0]:
  185. cell.append(bbox)
  186. used_bbox_list.append(bbox)
  187. col_list.append(cell)
  188. col_cell_list.append([last_p, p])
  189. row_list.append(col_list)
  190. row_cell_list.append(col_cell_list)
  191. area_table_bbox_list.append(row_list)
  192. area_table_cell_list.append(row_cell_list)
  193. return area_table_bbox_list[0], area_table_cell_list[0]