extract_table.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import time
  2. import traceback
  3. import cv2
  4. from pdfminer.layout import LTLine
  5. # from botr.nsp.predict import nsp_predict
  6. from botr.rules.get_table_by_rules import get_table_by_rule
  7. from botr.utils import line_iou, get_table_iou
  8. from format_convert.convert_need_interface import from_yolo_interface
  9. from format_convert.utils import log, np2bytes
  10. def b_table_process(list_line, list_text_boxes, list_cell, table_location):
  11. def merge_textbox(textbox_list, in_objs):
  12. delete_obj = []
  13. threshold = 5
  14. textbox_list.sort(key=lambda x:x.bbox[0])
  15. for k in range(len(textbox_list)):
  16. tb1 = textbox_list[k]
  17. if tb1 not in in_objs and tb1 not in delete_obj:
  18. for m in range(k+1, len(textbox_list)):
  19. tb2 = textbox_list[m]
  20. if tb2 in in_objs:
  21. continue
  22. if abs(tb1.bbox[1]-tb2.bbox[1]) <= threshold \
  23. and abs(tb1.bbox[3]-tb2.bbox[3]) <= threshold:
  24. if tb1.bbox[0] <= tb2.bbox[0]:
  25. tb1.text = tb1.text + tb2.text
  26. else:
  27. tb1.text = tb2.text + tb1.text
  28. tb1.bbox[0] = min(tb1.bbox[0], tb2.bbox[0])
  29. tb1.bbox[2] = max(tb1.bbox[2], tb2.bbox[2])
  30. delete_obj.append(tb2)
  31. for _obj in delete_obj:
  32. if _obj in textbox_list:
  33. textbox_list.remove(_obj)
  34. return textbox_list
  35. try:
  36. if list_line:
  37. from format_convert.convert_tree import TableLine
  38. list_lines = []
  39. for line in list_line:
  40. list_lines.append(LTLine(1, (line[0], line[1]), (line[2], line[3])))
  41. # 先拿出在表格区域里的TextBox
  42. area_list_text_boxes = []
  43. threshold = 7
  44. for t_b in list_text_boxes:
  45. bbox = t_b.bbox
  46. if table_location[1] - threshold <= bbox[1] <= bbox[3] <= table_location[3] + threshold:
  47. area_list_text_boxes.append(t_b)
  48. # 对TextBox进行分行,否则同样一行有些框偏上有些偏下,影响文本顺序
  49. area_list_text_boxes.sort(key=lambda x: (x.bbox[1], x.bbox[0], x.bbox[3], x.bbox[2]))
  50. current_y = area_list_text_boxes[0].bbox[1]
  51. threshold = 2.
  52. for t_b in area_list_text_boxes:
  53. bbox = t_b.bbox
  54. if current_y - threshold <= bbox[1] <= current_y + threshold:
  55. t_b.bbox[1] = current_y
  56. else:
  57. current_y = bbox[1]
  58. area_list_text_boxes.sort(key=lambda x: (x.bbox[1], x.bbox[0], x.bbox[3], x.bbox[2]))
  59. # list_cell 转化为 LineTable形式
  60. tables = []
  61. obj_in_table = []
  62. table_dict = {'bbox': table_location}
  63. row_list = []
  64. for row in list_cell:
  65. col_list = []
  66. for col in row:
  67. col_dict = {'bbox': (col[0][0], col[0][1], col[1][0], col[1][1]),
  68. 'rowspan': 1, 'columnspan': 1, 'text': ''}
  69. for t_b in area_list_text_boxes:
  70. if t_b in obj_in_table:
  71. continue
  72. text = t_b.text
  73. bbox = t_b.bbox
  74. iou = get_table_iou(col[0][0], col[0][1], col[1][0], col[1][1],
  75. bbox[0], bbox[1], bbox[2], bbox[3])
  76. if iou >= 0.3:
  77. col_dict['text'] += text
  78. obj_in_table.append(t_b)
  79. col_list.append(col_dict)
  80. row_list.append(col_list)
  81. table_dict['table'] = row_list
  82. tables.append(table_dict)
  83. # print('b_table_process tables', tables)
  84. # 合并同一行textbox
  85. # list_text_boxes = merge_textbox(list_text_boxes, obj_in_table)
  86. return list_text_boxes, tables, obj_in_table
  87. else:
  88. return list_text_boxes, [], set()
  89. except:
  90. traceback.print_exc()
  91. return [-8], [-8], [-8]
  92. def get_table(img, table_list, text_list, bbox_list, text_box_list, show=0):
  93. log('start')
  94. # 检测无边框表格
  95. start_time_all = time.time()
  96. start_time = time.time()
  97. img_bytes = np2bytes(img)
  98. b_table_list = from_yolo_interface(img_bytes)
  99. log('yolo detect cost: ' + str(time.time()-start_time))
  100. b_table_list = b_table_list[0]
  101. if not b_table_list:
  102. log('detect not b_table_list')
  103. return [], [], []
  104. if show:
  105. for b_table in b_table_list:
  106. # for line in b_table:
  107. cv2.rectangle(img, (int(b_table[0]), int(b_table[1])), (int(b_table[2]), int(b_table[3])),
  108. (0, 0, 255))
  109. cv2.imshow('b_table', img)
  110. cv2.waitKey(0)
  111. if show:
  112. print('b_table_list', b_table_list)
  113. # 排除otr结果
  114. b_table_location_list = []
  115. for b_table in b_table_list:
  116. # print('b_table', b_table)
  117. min_x, min_y = 1000000, 1000000
  118. max_x, max_y = 0, 0
  119. # for line in b_table:
  120. if b_table[1] < min_y:
  121. min_y = b_table[1]
  122. if b_table[3] > max_y:
  123. max_y = b_table[3]
  124. if b_table[0] < min_x:
  125. min_x = b_table[0]
  126. if b_table[2] > max_x:
  127. max_x = b_table[2]
  128. b_loc = [min_x, min_y, max_x, max_y]
  129. inter_flag = False
  130. for table in table_list:
  131. loc = table.get('bbox')
  132. rows = table.get('table')
  133. if line_iou([[0, loc[1]], [0, loc[3]]], [[0, b_loc[1]], [0, b_loc[3]]], axis=1) > 0.3:
  134. if len(rows) <= 1:
  135. if loc[1] < b_loc[1] < loc[3] < b_loc[3]:
  136. b_loc[1] = loc[3]
  137. if b_loc[1] < loc[1] < b_loc[3] < loc[3]:
  138. b_loc[3] = loc[1]
  139. continue
  140. inter_flag = True
  141. # cv2.rectangle(img, [int(loc[0]), int(loc[1])], [int(loc[2]), int(loc[3])], (0, 0, 255))
  142. # cv2.rectangle(img, [int(b_loc[0]), int(b_loc[1])], [int(b_loc[2]), int(b_loc[3])], (0, 0, 255))
  143. # cv2.imshow('inter', img)
  144. # cv2.waitKey(0)
  145. break
  146. if not inter_flag:
  147. b_table_location_list.append(b_loc)
  148. if not b_table_location_list:
  149. log('not b_table_location_list')
  150. return [], [], []
  151. if show:
  152. for b_loc in b_table_location_list:
  153. cv2.rectangle(img, (int(b_loc[0]), int(b_loc[1])), (int(b_loc[2]), int(b_loc[3])),
  154. (0, 0, 255))
  155. cv2.imshow('b_table no otr', img)
  156. cv2.waitKey(0)
  157. table_list = []
  158. obj_in_table_list = []
  159. # print('len(b_table_location_list)', len(b_table_location_list))
  160. for b_loc in b_table_location_list:
  161. area_text_list = []
  162. area_bbox_list = []
  163. threshold = 5
  164. for i, bbox in enumerate(bbox_list):
  165. if b_loc[1] - threshold <= bbox[0][1] <= bbox[2][1] <= b_loc[3] + threshold:
  166. area_bbox_list.append(bbox)
  167. area_text_list.append(text_list[i])
  168. # 根据ocr bbox,规则生成表格线
  169. start_time = time.time()
  170. line_list, cell_list, table_location = get_table_by_rule(img, area_text_list, area_bbox_list, b_loc)
  171. if not table_location:
  172. log('get_table_by_rule not table_location')
  173. continue
  174. log('get_table_by_rule cost: ' + str(time.time()-start_time))
  175. # 根据表格线生成单元格
  176. start_time = time.time()
  177. text_box_list, _table_list, _obj_in_table_list = b_table_process(line_list, text_box_list, cell_list, table_location)
  178. table_list += _table_list
  179. obj_in_table_list += _obj_in_table_list
  180. log('b_table_process cost: ' + str(time.time()-start_time))
  181. # if not table_list:
  182. # log('table_process not table_list')
  183. # return [], [], []
  184. if not _table_list:
  185. log('table_process not table_list')
  186. continue
  187. # 单元格合并,nsp模型
  188. # 使用hanlp分词,判断上下句是否该合并 顺便拉数据统计
  189. # 1. 上下句ab,ab相连得到c
  190. # 2.1 c分词,若ab相连处合为一个词语,则ab相连
  191. # 2.2 ab相连处不为一个词语,a, b分别分词
  192. # 2.2.1 若b的第一个词,从其中分第一个字给a,然后
  193. # near_col_list = []
  194. # table = _table_list[0].get('table')
  195. # col_cnt = len(table[0])
  196. # for c_cnt in range(col_cnt):
  197. # for i in range(len(table)-1):
  198. # t = table[i][c_cnt].get('text')
  199. # next_t = table[i+1][c_cnt].get('text')
  200. # if t and next_t:
  201. # near_col_list.append([t, next_t])
  202. # elif t and next_t == '':
  203. # if i+2 <= len(table)-1:
  204. # next_2_t = table[i+2][c_cnt].get('text')
  205. # near_col_list.append([t, next_2_t])
  206. #
  207. # is_next_list = nsp_predict(near_col_list, has_label=False)
  208. #
  209. # next_index = 0
  210. # for c_cnt in range(col_cnt):
  211. # # 先把一列里的需合并的打上标签
  212. # for i in range(len(table)-1):
  213. # t = table[i][c_cnt].get('text')
  214. # next_t = table[i+1][c_cnt].get('text')
  215. # if t and next_t:
  216. # table[i+1][c_cnt]['is_next'] = is_next_list[next_index]
  217. # next_index += 1
  218. # elif t and next_t == '':
  219. # if i+2 <= len(table)-1:
  220. # table[i+1][c_cnt]['is_next'] = is_next_list[next_index]
  221. # table[i+2][c_cnt]['is_next'] = is_next_list[next_index]
  222. # next_index += 1
  223. #
  224. # first_col = None
  225. # for i in range(len(table)):
  226. # if table[i][c_cnt].get('is_next'):
  227. # if first_col is None:
  228. # first_col = table[i-1][c_cnt]
  229. # first_col['text'] += table[i][c_cnt].get('text')
  230. # first_col['rowspan'] += 1
  231. # else:
  232. # first_col = None
  233. #
  234. # # 删除标签为True的
  235. # new_table = []
  236. # for row in table:
  237. # new_row = []
  238. # for col in row:
  239. # if col.get('is_next'):
  240. # continue
  241. # new_row.append(col)
  242. # new_table.append(new_row)
  243. #
  244. # _table_list[0]['table'] = new_table
  245. log('get_table finish ' + str(time.time() - start_time_all))
  246. return text_box_list, table_list, obj_in_table_list