extract_table.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. import re
  2. import time
  3. import traceback
  4. import cv2
  5. from pdfminer.layout import LTLine
  6. # from botr.nsp.predict import nsp_predict
  7. from botr.rules.get_table_by_rules import get_table_by_rule
  8. from botr.utils import line_iou, get_table_iou
  9. from format_convert.convert_need_interface import from_yolo_interface
  10. from format_convert.utils import log, np2bytes
  11. def b_table_process(list_line, list_text_boxes, list_cell, table_location):
  12. def merge_textbox(textbox_list, in_objs):
  13. delete_obj = []
  14. threshold = 5
  15. textbox_list.sort(key=lambda x:x.bbox[0])
  16. for k in range(len(textbox_list)):
  17. tb1 = textbox_list[k]
  18. if tb1 not in in_objs and tb1 not in delete_obj:
  19. for m in range(k+1, len(textbox_list)):
  20. tb2 = textbox_list[m]
  21. if tb2 in in_objs:
  22. continue
  23. if abs(tb1.bbox[1]-tb2.bbox[1]) <= threshold \
  24. and abs(tb1.bbox[3]-tb2.bbox[3]) <= threshold:
  25. if tb1.bbox[0] <= tb2.bbox[0]:
  26. tb1.text = tb1.text + tb2.text
  27. else:
  28. tb1.text = tb2.text + tb1.text
  29. tb1.bbox[0] = min(tb1.bbox[0], tb2.bbox[0])
  30. tb1.bbox[2] = max(tb1.bbox[2], tb2.bbox[2])
  31. delete_obj.append(tb2)
  32. for _obj in delete_obj:
  33. if _obj in textbox_list:
  34. textbox_list.remove(_obj)
  35. return textbox_list
  36. try:
  37. if list_line:
  38. from format_convert.convert_tree import TableLine
  39. list_lines = []
  40. for line in list_line:
  41. list_lines.append(LTLine(1, (line[0], line[1]), (line[2], line[3])))
  42. # 先拿出在表格区域里的TextBox
  43. area_list_text_boxes = []
  44. threshold = 7
  45. for t_b in list_text_boxes:
  46. bbox = t_b.bbox
  47. if table_location[1] - threshold <= bbox[1] <= bbox[3] <= table_location[3] + threshold:
  48. area_list_text_boxes.append(t_b)
  49. # 对TextBox进行分行,否则同样一行有些框偏上有些偏下,影响文本顺序
  50. area_list_text_boxes.sort(key=lambda x: (x.bbox[1], x.bbox[0], x.bbox[3], x.bbox[2]))
  51. current_y = area_list_text_boxes[0].bbox[1]
  52. current_y2 = area_list_text_boxes[0].bbox[3]
  53. # threshold = 2.
  54. threshold = max(2., 1/3 * abs(current_y2 - current_y))
  55. for t_b in area_list_text_boxes:
  56. bbox = t_b.bbox
  57. if current_y - threshold <= bbox[1] <= current_y + threshold:
  58. t_b.bbox[1] = current_y
  59. else:
  60. current_y = bbox[1]
  61. area_list_text_boxes.sort(key=lambda x: (x.bbox[1], x.bbox[0], x.bbox[3], x.bbox[2]))
  62. # list_cell 转化为 LineTable形式
  63. tables = []
  64. obj_in_table = []
  65. table_dict = {'bbox': table_location}
  66. row_list = []
  67. for row in list_cell:
  68. col_list = []
  69. for col in row:
  70. col_dict = {'bbox': (col[0][0], col[0][1], col[1][0], col[1][1]),
  71. 'rowspan': 1, 'columnspan': 1, 'text': ''}
  72. for t_b in area_list_text_boxes:
  73. if t_b in obj_in_table:
  74. continue
  75. text = re.sub('\s', '', t_b.text)
  76. bbox = t_b.bbox
  77. iou = get_table_iou(col[0][0], col[0][1], col[1][0], col[1][1],
  78. bbox[0], bbox[1], bbox[2], bbox[3])
  79. if iou >= 0.3:
  80. col_dict['text'] += text
  81. obj_in_table.append(t_b)
  82. col_list.append(col_dict)
  83. row_list.append(col_list)
  84. table_dict['table'] = row_list
  85. tables.append(table_dict)
  86. # print('b_table_process tables', tables)
  87. # 合并同一行textbox
  88. # list_text_boxes = merge_textbox(list_text_boxes, obj_in_table)
  89. return list_text_boxes, tables, obj_in_table
  90. else:
  91. return list_text_boxes, [], set()
  92. except:
  93. traceback.print_exc()
  94. return [-8], [-8], [-8]
  95. def get_text_box_obj(_text_list, _bbox_list):
  96. from format_convert.convert_tree import TextBox
  97. _text_box_list = []
  98. for i in range(len(_bbox_list)):
  99. bbox = _bbox_list[i]
  100. b_text = _text_list[i]
  101. _text_box_list.append(TextBox([bbox[0][0], bbox[0][1],
  102. bbox[2][0], bbox[2][1]], b_text))
  103. return _text_box_list
  104. def get_table(img, table_list, text_list, bbox_list, text_box_list, show=0):
  105. log('start')
  106. # 检测无边框表格
  107. start_time_all = time.time()
  108. start_time = time.time()
  109. img_bytes = np2bytes(img)
  110. b_table_list = from_yolo_interface(img_bytes)
  111. log('yolo detect cost: ' + str(time.time()-start_time))
  112. b_table_list = b_table_list[0]
  113. if not b_table_list:
  114. log('detect not b_table_list')
  115. return [], [], []
  116. # if show:
  117. # for b_table in b_table_list:
  118. # # for line in b_table:
  119. # cv2.rectangle(img, (int(b_table[0]), int(b_table[1])), (int(b_table[2]), int(b_table[3])),
  120. # (0, 0, 255), 2)
  121. # cv2.namedWindow('b_table', cv2.WINDOW_NORMAL)
  122. # cv2.imshow('b_table', img)
  123. # cv2.waitKey(0)
  124. if show:
  125. print('b_table_list', b_table_list)
  126. print('table_list', table_list)
  127. # 排除otr结果
  128. b_table_location_list = []
  129. for b_table in b_table_list:
  130. # print('b_table', b_table)
  131. min_x, min_y = 1000000, 1000000
  132. max_x, max_y = 0, 0
  133. # for line in b_table:
  134. if b_table[1] < min_y:
  135. min_y = b_table[1]
  136. if b_table[3] > max_y:
  137. max_y = b_table[3]
  138. if b_table[0] < min_x:
  139. min_x = b_table[0]
  140. if b_table[2] > max_x:
  141. max_x = b_table[2]
  142. b_loc = [min_x, min_y, max_x, max_y, b_table[4]]
  143. inter_flag = False
  144. for table in table_list:
  145. loc = table.get('bbox')
  146. rows = table.get('table')
  147. iou = line_iou([[0, loc[1]], [0, loc[3]]], [[0, b_loc[1]], [0, b_loc[3]]], axis=1)
  148. if iou > 0.3:
  149. # if len(rows) <= 1:
  150. # if loc[1] < b_loc[1] < loc[3] < b_loc[3]:
  151. # b_loc[1] = loc[3]
  152. # if b_loc[1] < loc[1] < b_loc[3] < loc[3]:
  153. # b_loc[3] = loc[1]
  154. # continue
  155. inter_flag = True
  156. # cv2.rectangle(img, [int(loc[0]), int(loc[1])], [int(loc[2]), int(loc[3])], (0, 0, 255))
  157. # cv2.rectangle(img, [int(b_loc[0]), int(b_loc[1])], [int(b_loc[2]), int(b_loc[3])], (0, 0, 255))
  158. # cv2.imshow('inter', img)
  159. # cv2.waitKey(0)
  160. break
  161. if not inter_flag:
  162. b_table_location_list.append(b_loc)
  163. if not b_table_location_list:
  164. log('except otr, not b_table_location_list')
  165. return [], [], []
  166. if show:
  167. print('len(b_table_location_list)', len(b_table_location_list))
  168. # 排除有重合的,取概率大的
  169. if len(b_table_location_list) > 1:
  170. temp_list = []
  171. used_b_loc = []
  172. for i in range(len(b_table_location_list)):
  173. b_loc1 = b_table_location_list[i]
  174. if b_loc1 in used_b_loc:
  175. continue
  176. inter_flag = False
  177. for j in range(i+1, len(b_table_location_list)):
  178. b_loc2 = b_table_location_list[j]
  179. iou = line_iou([[0, b_loc1[1]], [0, b_loc1[3]]], [[0, b_loc2[1]], [0, b_loc2[3]]], axis=1)
  180. if show:
  181. print('iou2', iou)
  182. if iou > 0.3:
  183. inter_flag = True
  184. break
  185. if inter_flag:
  186. used_b_loc.append(b_loc2)
  187. if b_loc1[4] >= b_loc2[4]:
  188. temp_list.append(b_loc1[:4])
  189. else:
  190. temp_list.append(b_loc2[:4])
  191. else:
  192. temp_list.append(b_loc1[:4])
  193. b_table_location_list = temp_list
  194. if show:
  195. for b_loc in b_table_location_list:
  196. cv2.rectangle(img, (int(b_loc[0]), int(b_loc[1])), (int(b_loc[2]), int(b_loc[3])),
  197. (0, 0, 255), 2)
  198. cv2.namedWindow('b_table_no_otr', cv2.WINDOW_NORMAL)
  199. cv2.imshow('b_table_no_otr', img)
  200. cv2.waitKey(0)
  201. table_list = []
  202. obj_in_table_list = []
  203. # print('len(b_table_location_list)', len(b_table_location_list))
  204. for b_loc in b_table_location_list:
  205. area_text_list = []
  206. area_bbox_list = []
  207. threshold = 5
  208. for i, bbox in enumerate(bbox_list):
  209. if b_loc[1] - threshold <= bbox[0][1] <= bbox[2][1] <= b_loc[3] + threshold:
  210. area_bbox_list.append(bbox)
  211. area_text_list.append(text_list[i])
  212. # 根据ocr bbox,规则生成表格线
  213. start_time = time.time()
  214. line_list, cell_list, table_location, bbox_text_dict = get_table_by_rule(img, area_text_list, area_bbox_list, b_loc, show=show)
  215. if not table_location:
  216. log('get_table_by_rule not table_location')
  217. continue
  218. # 获取最新的text_list, bbox_list
  219. area_text_list, area_bbox_list = [], []
  220. for key in bbox_text_dict.keys():
  221. area_bbox_list.append(eval(key))
  222. area_text_list.append(bbox_text_dict.get(key))
  223. b_text_box_list = get_text_box_obj(area_text_list, area_bbox_list)
  224. log('get_table_by_rule cost: ' + str(time.time()-start_time))
  225. # 根据表格线生成单元格
  226. start_time = time.time()
  227. b_text_box_list, _table_list, _obj_in_table_list = b_table_process(line_list, b_text_box_list, cell_list, table_location)
  228. table_list += _table_list
  229. obj_in_table_list += _obj_in_table_list
  230. log('b_table_process cost: ' + str(time.time()-start_time))
  231. # if not table_list:
  232. # log('table_process not table_list')
  233. # return [], [], []
  234. if not _table_list:
  235. log('table_process not table_list')
  236. continue
  237. # 单元格合并,nsp模型
  238. # 使用hanlp分词,判断上下句是否该合并 顺便拉数据统计
  239. # 1. 上下句ab,ab相连得到c
  240. # 2.1 c分词,若ab相连处合为一个词语,则ab相连
  241. # 2.2 ab相连处不为一个词语,a, b分别分词
  242. # 2.2.1 若b的第一个词,从其中分第一个字给a,然后
  243. # near_col_list = []
  244. # table = _table_list[0].get('table')
  245. # col_cnt = len(table[0])
  246. # for c_cnt in range(col_cnt):
  247. # for i in range(len(table)-1):
  248. # t = table[i][c_cnt].get('text')
  249. # next_t = table[i+1][c_cnt].get('text')
  250. # if t and next_t:
  251. # near_col_list.append([t, next_t])
  252. # elif t and next_t == '':
  253. # if i+2 <= len(table)-1:
  254. # next_2_t = table[i+2][c_cnt].get('text')
  255. # near_col_list.append([t, next_2_t])
  256. #
  257. # is_next_list = nsp_predict(near_col_list, has_label=False)
  258. #
  259. # next_index = 0
  260. # for c_cnt in range(col_cnt):
  261. # # 先把一列里的需合并的打上标签
  262. # for i in range(len(table)-1):
  263. # t = table[i][c_cnt].get('text')
  264. # next_t = table[i+1][c_cnt].get('text')
  265. # if t and next_t:
  266. # table[i+1][c_cnt]['is_next'] = is_next_list[next_index]
  267. # next_index += 1
  268. # elif t and next_t == '':
  269. # if i+2 <= len(table)-1:
  270. # table[i+1][c_cnt]['is_next'] = is_next_list[next_index]
  271. # table[i+2][c_cnt]['is_next'] = is_next_list[next_index]
  272. # next_index += 1
  273. #
  274. # first_col = None
  275. # for i in range(len(table)):
  276. # if table[i][c_cnt].get('is_next'):
  277. # if first_col is None:
  278. # first_col = table[i-1][c_cnt]
  279. # first_col['text'] += table[i][c_cnt].get('text')
  280. # first_col['rowspan'] += 1
  281. # else:
  282. # first_col = None
  283. #
  284. # # 删除标签为True的
  285. # new_table = []
  286. # for row in table:
  287. # new_row = []
  288. # for col in row:
  289. # if col.get('is_next'):
  290. # continue
  291. # new_row.append(col)
  292. # new_table.append(new_row)
  293. #
  294. # _table_list[0]['table'] = new_table
  295. log('get_table finish ' + str(time.time() - start_time_all))
  296. return text_box_list, table_list, obj_in_table_list