extract_table.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. import re
  2. import time
  3. import traceback
  4. import cv2
  5. from pdfminer.layout import LTLine
  6. # from botr.nsp.predict import nsp_predict
  7. from botr.rules.get_table_by_rules import get_table_by_rule
  8. from botr.utils import line_iou, get_table_iou
  9. from format_convert.convert_need_interface import from_yolo_interface
  10. from format_convert.utils import log, np2bytes
  def b_table_process(list_line, list_text_boxes, list_cell, table_location):
      """Fill the cells of a rule-generated table with the text boxes they cover.

      Args:
          list_line: Table lines produced by the rule step. Only its
              emptiness is checked here: no lines means no table.
          list_text_boxes: Objects exposing a mutable ``bbox`` sequence
              (used as ``[x1, y1, x2, y2]`` -- presumably, confirm against
              the ``TextBox`` class) and a ``text`` string.
          list_cell: Rows of cells; each cell is read as
              ``cell[0] = (x1, y1)`` and ``cell[1] = (x2, y2)``.
          table_location: Table box; indices 1 and 3 are used as the top and
              bottom y of the table band.

      Returns:
          A 3-tuple ``(list_text_boxes, tables, obj_in_table)``:

          * ``list_text_boxes`` is the input list, returned as-is.
          * ``tables`` is a list holding one dict
            ``{'bbox': table_location, 'table': rows}`` where every cell is
            ``{'bbox', 'rowspan', 'columnspan', 'text'}``; it is empty when
            there are no lines or no text box lies inside the table band.
          * ``obj_in_table`` lists the text boxes assigned to a cell (an
            empty ``set`` when no table is produced, kept for backward
            compatibility).

          If an unexpected error occurs, the traceback is printed and
          ``([-8], [-8], [-8])`` is returned.

      Side effects:
          ``bbox[1]`` of text boxes inside the table band may be overwritten
          so that boxes on the same visual row share one top coordinate.
      """

      def merge_textbox(textbox_list, in_objs):
          """Merge text boxes on the same row that are not part of a table.

          NOTE: currently unused -- its only call site below is commented
          out. Kept so the merge step can be re-enabled.
          """
          delete_obj = []
          threshold = 5
          textbox_list.sort(key=lambda x: x.bbox[0])
          for k in range(len(textbox_list)):
              tb1 = textbox_list[k]
              if tb1 not in in_objs and tb1 not in delete_obj:
                  for m in range(k+1, len(textbox_list)):
                      tb2 = textbox_list[m]
                      if tb2 in in_objs:
                          continue
                      if abs(tb1.bbox[1]-tb2.bbox[1]) <= threshold \
                              and abs(tb1.bbox[3]-tb2.bbox[3]) <= threshold:
                          if tb1.bbox[0] <= tb2.bbox[0]:
                              tb1.text = tb1.text + tb2.text
                          else:
                              tb1.text = tb2.text + tb1.text
                          tb1.bbox[0] = min(tb1.bbox[0], tb2.bbox[0])
                          tb1.bbox[2] = max(tb1.bbox[2], tb2.bbox[2])
                          delete_obj.append(tb2)
          for _obj in delete_obj:
              if _obj in textbox_list:
                  textbox_list.remove(_obj)
          return textbox_list

      def reading_order(t_b):
          # Top-to-bottom, then left-to-right, then by bottom and right edge.
          return t_b.bbox[1], t_b.bbox[0], t_b.bbox[3], t_b.bbox[2]

      try:
          if not list_line:
              return list_text_boxes, [], set()

          # Keep only the text boxes lying vertically inside the table band
          # (with a small tolerance above and below).
          band_tolerance = 7
          band_top = table_location[1] - band_tolerance
          band_bottom = table_location[3] + band_tolerance
          area_list_text_boxes = [
              t_b for t_b in list_text_boxes
              if band_top <= t_b.bbox[1] <= t_b.bbox[3] <= band_bottom
          ]
          if not area_list_text_boxes:
              # Nothing to put into cells. Previously this case crashed on
              # area_list_text_boxes[0] and was reported as the -8 error.
              return list_text_boxes, [], set()

          # Group the boxes into rows: boxes of one row can sit slightly
          # higher or lower than each other, which would scramble the text
          # order, so snap near-equal tops to a shared value and re-sort.
          area_list_text_boxes.sort(key=reading_order)
          row_tolerance = 2.
          row_y = area_list_text_boxes[0].bbox[1]
          for t_b in area_list_text_boxes:
              if row_y - row_tolerance <= t_b.bbox[1] <= row_y + row_tolerance:
                  t_b.bbox[1] = row_y
              else:
                  row_y = t_b.bbox[1]
          area_list_text_boxes.sort(key=reading_order)

          # Convert list_cell into the table-dict form and distribute the
          # text boxes over the cells. Each box goes to at most one cell.
          whitespace = re.compile(r'\s')
          obj_in_table = []
          row_list = []
          for row in list_cell:
              col_list = []
              for col in row:
                  x1, y1 = col[0][0], col[0][1]
                  x2, y2 = col[1][0], col[1][1]
                  col_dict = {'bbox': (x1, y1, x2, y2),
                              'rowspan': 1, 'columnspan': 1, 'text': ''}
                  for t_b in area_list_text_boxes:
                      if t_b in obj_in_table:
                          continue
                      bbox = t_b.bbox
                      iou = get_table_iou(x1, y1, x2, y2,
                                          bbox[0], bbox[1], bbox[2], bbox[3])
                      if iou >= 0.3:
                          col_dict['text'] += whitespace.sub('', t_b.text)
                          obj_in_table.append(t_b)
                  col_list.append(col_dict)
              row_list.append(col_list)

          tables = [{'bbox': table_location, 'table': row_list}]

          # Merging text boxes of the same row is currently disabled:
          # list_text_boxes = merge_textbox(list_text_boxes, obj_in_table)
          return list_text_boxes, tables, obj_in_table
      except Exception:
          traceback.print_exc()
          return [-8], [-8], [-8]
  def get_text_box_obj(_text_list, _bbox_list):
      """Wrap parallel OCR text and box lists into ``TextBox`` objects.

      Args:
          _text_list: Texts, indexed in step with ``_bbox_list``.
          _bbox_list: Boxes given as sequences of points; point 0 and point 2
              supply the four coordinates passed to ``TextBox`` (presumably
              the top-left and bottom-right corners -- confirm with the OCR
              output format).

      Returns:
          A list with one ``TextBox`` per entry of ``_bbox_list``.
      """
      from format_convert.convert_tree import TextBox

      boxes = []
      for idx, quad in enumerate(_bbox_list):
          content = _text_list[idx]
          coords = [quad[0][0], quad[0][1], quad[2][0], quad[2][1]]
          boxes.append(TextBox(coords, content))
      return boxes
  def _vertical_iou(loc_a, loc_b):
      """Return ``line_iou`` of the vertical extents (indices 1 and 3) of two boxes."""
      return line_iou([[0, loc_a[1]], [0, loc_a[3]]],
                      [[0, loc_b[1]], [0, loc_b[3]]], axis=1)


  def _drop_overlapping_locations(locations, show=0):
      """Keep only the best-scoring box of each group of vertically overlapping boxes.

      ``locations`` holds ``[x1, y1, x2, y2, score]`` entries. Candidates are
      visited from highest to lowest score and dropped when they overlap
      (vertical IoU > 0.3) a box that was already kept. The survivors are
      returned in their original order, without the score.
      """
      # sorted() is stable even with reverse=True, so on equal scores the
      # earlier candidate is visited first and therefore wins.
      by_score = sorted(range(len(locations)),
                        key=lambda idx: locations[idx][4], reverse=True)
      kept = []
      for idx in by_score:
          suppressed = False
          for kept_idx in kept:
              # Always pass the earlier box first, in case line_iou is not
              # symmetric in its arguments.
              first, second = sorted((idx, kept_idx))
              iou = _vertical_iou(locations[first], locations[second])
              if show:
                  print('iou2', iou)
              if iou > 0.3:
                  suppressed = True
                  break
          if not suppressed:
              kept.append(idx)
      return [locations[idx][:4] for idx in sorted(kept)]


  def get_table(img, table_list, text_list, bbox_list, text_box_list, show=0):
      """Detect borderless tables on a page image and build their cell structure.

      Steps: run the YOLO borderless-table detector, discard detections that
      vertically overlap a table already present in ``table_list``, resolve
      overlaps among the remaining detections in favour of the higher score,
      then for each region derive table lines/cells by rule from the OCR
      boxes and fill the cells with text.

      Args:
          img: Page image, passed to ``np2bytes``, ``get_table_by_rule`` and,
              when ``show`` is set, drawn on with OpenCV (presumably a numpy
              image array -- confirm with the caller).
          table_list: Already known tables; each is a dict whose ``'bbox'``
              is read. Used only to exclude overlapping detections.
          text_list: OCR texts, indexed in step with ``bbox_list``.
          bbox_list: OCR boxes as point sequences; ``bbox[0][1]`` and
              ``bbox[2][1]`` are used as the top and bottom y.
          text_box_list: Returned unchanged on the normal path.
          show: When truthy, print debug information and display the
              detected regions in an OpenCV window (blocks on a key press).

      Returns:
          ``(text_box_list, table_list, obj_in_table_list)`` where
          ``table_list`` contains only the borderless tables built here and
          ``obj_in_table_list`` the text boxes placed into their cells.
          ``([], [], [])`` is returned when nothing is detected or every
          detection overlaps an existing table.
      """
      log('start')
      start_time_all = time.time()

      # Detect borderless tables.
      start_time = time.time()
      img_bytes = np2bytes(img)
      b_table_list = from_yolo_interface(img_bytes)
      log('yolo detect cost: ' + str(time.time()-start_time))
      b_table_list = b_table_list[0]
      if not b_table_list:
          log('detect not b_table_list')
          return [], [], []

      if show:
          print('b_table_list', b_table_list)
          print('table_list', table_list)

      # Exclude detections that overlap a table already found (OTR result).
      b_table_location_list = []
      for b_table in b_table_list:
          # Clamp the box to the same bounds the original min/max scan used.
          b_loc = [min(b_table[0], 1000000), min(b_table[1], 1000000),
                   max(b_table[2], 0), max(b_table[3], 0), b_table[4]]
          inter_flag = False
          for table in table_list:
              if _vertical_iou(table.get('bbox'), b_loc) > 0.3:
                  inter_flag = True
                  break
          if not inter_flag:
              b_table_location_list.append(b_loc)

      if not b_table_location_list:
          log('except otr, not b_table_location_list')
          return [], [], []

      if show:
          print('len(b_table_location_list)', len(b_table_location_list))

      # Among overlapping detections keep the one with the higher score.
      # NOTE: with a single detection the score element is left in place,
      # exactly as before.
      if len(b_table_location_list) > 1:
          b_table_location_list = _drop_overlapping_locations(b_table_location_list, show=show)

      if show:
          for b_loc in b_table_location_list:
              cv2.rectangle(img, (int(b_loc[0]), int(b_loc[1])), (int(b_loc[2]), int(b_loc[3])),
                            (0, 0, 255), 2)
          cv2.namedWindow('b_table_no_otr', cv2.WINDOW_NORMAL)
          cv2.imshow('b_table_no_otr', img)
          cv2.waitKey(0)

      table_list = []
      obj_in_table_list = []
      threshold = 5
      for b_loc in b_table_location_list:
          # Collect the OCR boxes lying vertically inside this region.
          area_text_list = []
          area_bbox_list = []
          for i, bbox in enumerate(bbox_list):
              if b_loc[1] - threshold <= bbox[0][1] <= bbox[2][1] <= b_loc[3] + threshold:
                  area_bbox_list.append(bbox)
                  area_text_list.append(text_list[i])

          # Generate table lines by rule from the OCR boxes.
          start_time = time.time()
          line_list, cell_list, table_location, bbox_text_dict = get_table_by_rule(
              img, area_text_list, area_bbox_list, b_loc, show=show)
          if not table_location:
              log('get_table_by_rule not table_location')
              continue

          # Take the updated texts and boxes returned by the rule step.
          # NOTE(review): the dict keys are evaluated with eval(); they appear
          # to be stringified boxes produced internally by get_table_by_rule.
          # If they can ever carry external data, switch to
          # ast.literal_eval after confirming the key format.
          area_text_list, area_bbox_list = [], []
          for key, value in bbox_text_dict.items():
              area_bbox_list.append(eval(key))
              area_text_list.append(value)
          b_text_box_list = get_text_box_obj(area_text_list, area_bbox_list)
          log('get_table_by_rule cost: ' + str(time.time()-start_time))

          # Build the cells from the table lines.
          start_time = time.time()
          b_text_box_list, _table_list, _obj_in_table_list = b_table_process(
              line_list, b_text_box_list, cell_list, table_location)
          log('b_table_process cost: ' + str(time.time()-start_time))

          # b_table_process reports failure as [-8]; do not let that sentinel
          # leak into the result lists.
          if _table_list == [-8]:
              log('b_table_process failed')
              continue

          table_list += _table_list
          obj_in_table_list += _obj_in_table_list
          if not _table_list:
              log('table_process not table_list')
              continue

          # Disabled idea (code removed, see version control): merge
          # vertically adjacent cells of a column with an NSP model
          # (botr.nsp.predict.nsp_predict), i.e. decide for each pair of
          # consecutive cell texts whether the second continues the first,
          # append its text to the first cell, increase that cell's rowspan
          # and drop the merged cell.

      log('get_table finish ' + str(time.time() - start_time_all))
      return text_box_list, table_list, obj_in_table_list