123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320 |
- import re
- import time
- import traceback
- import cv2
- from pdfminer.layout import LTLine
- # from botr.nsp.predict import nsp_predict
- from botr.rules.get_table_by_rules import get_table_by_rule
- from botr.utils import line_iou, get_table_iou
- from format_convert.convert_need_interface import from_yolo_interface
- from format_convert.utils import log, np2bytes
def b_table_process(list_line, list_text_boxes, list_cell, table_location):
    """Fill rule-generated borderless-table cells with the OCR text boxes they cover.

    Args:
        list_line: Rule-generated table lines. Only its truthiness is used here:
            when empty, no table is built.
        list_text_boxes: Text-box objects exposing a mutable ``bbox`` list
            ``[x0, y0, x1, y1]`` and a ``text`` string (e.g. as built by
            ``get_text_box_obj``).
        list_cell: Rows of cells; each cell is indexed as ``((x0, y0), (x1, y1))``.
        table_location: Table bounding box; indices 1 and 3 (top / bottom y)
            select which text boxes belong to the table.

    Returns:
        ``(list_text_boxes, tables, obj_in_table)``:

        * ``list_text_boxes`` - the input list, returned as-is. Note that the
          ``bbox[1]`` of boxes inside the table area may have been snapped to a
          shared row y.
        * ``tables`` - ``[{'bbox': table_location, 'table': rows}]``. Every cell
          in ``rows`` is ``{'bbox', 'rowspan': 1, 'columnspan': 1, 'text'}``.
        * ``obj_in_table`` - the text boxes that were assigned to a cell.

        When ``list_line`` is empty, ``(list_text_boxes, [], set())`` is returned.
        On any error the traceback is printed and ``([-8], [-8], [-8])`` is
        returned.
    """
    def merge_textbox(textbox_list, in_objs):
        # Merge boxes not in ``in_objs`` whose top and bottom y both lie within
        # ``threshold`` of each other. Text is concatenated left-to-right and
        # the x-extent is widened. The list is modified in place and returned.
        # Currently unused: same-row merging is disabled below.
        delete_obj = []
        threshold = 5
        textbox_list.sort(key=lambda x: x.bbox[0])
        for k, tb1 in enumerate(textbox_list):
            if tb1 in in_objs or tb1 in delete_obj:
                continue
            for tb2 in textbox_list[k + 1:]:
                # Also skip boxes already absorbed by an earlier box, so that a
                # box's text is never merged twice.
                if tb2 in in_objs or tb2 in delete_obj:
                    continue
                if abs(tb1.bbox[1] - tb2.bbox[1]) <= threshold \
                        and abs(tb1.bbox[3] - tb2.bbox[3]) <= threshold:
                    if tb1.bbox[0] <= tb2.bbox[0]:
                        tb1.text = tb1.text + tb2.text
                    else:
                        tb1.text = tb2.text + tb1.text
                    tb1.bbox[0] = min(tb1.bbox[0], tb2.bbox[0])
                    tb1.bbox[2] = max(tb1.bbox[2], tb2.bbox[2])
                    delete_obj.append(tb2)
        for _obj in delete_obj:
            if _obj in textbox_list:
                textbox_list.remove(_obj)
        return textbox_list

    try:
        if not list_line:
            return list_text_boxes, [], set()

        # Keep only the text boxes lying vertically inside the table area.
        threshold = 7
        area_list_text_boxes = [
            t_b for t_b in list_text_boxes
            if table_location[1] - threshold <= t_b.bbox[1] <= t_b.bbox[3] <= table_location[3] + threshold
        ]

        # Group the boxes into rows by snapping their top y to the row's y.
        # Without this, boxes on the same visual row that sit slightly higher or
        # lower would be read in the wrong text order.
        # NOTE(review): an empty table area raises IndexError on the next lines
        # and is reported through the generic error return below.
        area_list_text_boxes.sort(key=lambda x: (x.bbox[1], x.bbox[0], x.bbox[3], x.bbox[2]))
        current_y = area_list_text_boxes[0].bbox[1]
        current_y2 = area_list_text_boxes[0].bbox[3]
        # Row tolerance: one third of the first box's height, but at least 2.
        threshold = max(2., 1/3 * abs(current_y2 - current_y))
        for t_b in area_list_text_boxes:
            bbox = t_b.bbox
            if current_y - threshold <= bbox[1] <= current_y + threshold:
                bbox[1] = current_y
            else:
                current_y = bbox[1]
        area_list_text_boxes.sort(key=lambda x: (x.bbox[1], x.bbox[0], x.bbox[3], x.bbox[2]))

        # Convert list_cell into the LineTable-style dict structure.
        obj_in_table = []
        row_list = []
        for row in list_cell:
            col_list = []
            for col in row:
                col_dict = {'bbox': (col[0][0], col[0][1], col[1][0], col[1][1]),
                            'rowspan': 1, 'columnspan': 1, 'text': ''}
                for t_b in area_list_text_boxes:
                    # A box goes to the first cell (in list_cell order) it
                    # overlaps enough, and is never reused.
                    if t_b in obj_in_table:
                        continue
                    bbox = t_b.bbox
                    iou = get_table_iou(col[0][0], col[0][1], col[1][0], col[1][1],
                                        bbox[0], bbox[1], bbox[2], bbox[3])
                    if iou >= 0.3:
                        # Whitespace is stripped from cell text.
                        col_dict['text'] += re.sub(r'\s', '', t_b.text)
                        obj_in_table.append(t_b)
                col_list.append(col_dict)
            row_list.append(col_list)

        tables = [{'bbox': table_location, 'table': row_list}]
        # Same-row merging (merge_textbox) is intentionally not applied here.
        return list_text_boxes, tables, obj_in_table
    except Exception:
        traceback.print_exc()
        return [-8], [-8], [-8]
def get_text_box_obj(_text_list, _bbox_list):
    """Wrap parallel OCR results into ``TextBox`` objects.

    For each entry of ``_bbox_list``, points 0 and 2 supply the box as
    ``[p0.x, p0.y, p2.x, p2.y]``. That box is paired with the text at the same
    position in ``_text_list``. The result is a list in ``_bbox_list`` order.
    """
    from format_convert.convert_tree import TextBox
    boxes = []
    for idx, points in enumerate(_bbox_list):
        label = _text_list[idx]
        p0, p2 = points[0], points[2]
        boxes.append(TextBox([p0[0], p0[1], p2[0], p2[1]], label))
    return boxes
def _exclude_otr_overlaps(b_table_list, table_list):
    """Return ``[x0, y0, x1, y1, score]`` for each detected box that does not
    vertically overlap (line IoU > 0.3) any table in ``table_list``."""
    b_table_location_list = []
    for b_table in b_table_list:
        # Only the first five values of a detection are read; index 4 is the
        # score compared when resolving overlaps.
        b_loc = [b_table[0], b_table[1], b_table[2], b_table[3], b_table[4]]
        for table in table_list:
            loc = table.get('bbox')
            iou = line_iou([[0, loc[1]], [0, loc[3]]], [[0, b_loc[1]], [0, b_loc[3]]], axis=1)
            if iou > 0.3:
                break
        else:
            b_table_location_list.append(b_loc)
    return b_table_location_list


def _keep_best_of_overlapping(b_table_location_list, show=0):
    """Resolve vertically overlapping detections, keeping the higher-score box.

    With a single box the input list is returned unchanged, so the box keeps
    its score. Otherwise each box is paired with the first later, still-unused
    box whose vertical line IoU exceeds 0.3. The higher-score box of the pair
    (ties favour the earlier one) is kept and the partner is consumed.
    Returned boxes are ``[x0, y0, x1, y1]``.
    """
    if len(b_table_location_list) <= 1:
        return b_table_location_list
    kept = []
    used = set()
    for i, b_loc1 in enumerate(b_table_location_list):
        if i in used:
            continue
        partner = None
        for j in range(i + 1, len(b_table_location_list)):
            # Boxes already consumed by an earlier pair must not be matched
            # again; otherwise a losing box could be re-emitted or duplicated.
            if j in used:
                continue
            b_loc2 = b_table_location_list[j]
            iou = line_iou([[0, b_loc1[1]], [0, b_loc1[3]]], [[0, b_loc2[1]], [0, b_loc2[3]]], axis=1)
            if show:
                print('iou2', iou)
            if iou > 0.3:
                partner = b_loc2
                used.add(j)
                break
        if partner is None or b_loc1[4] >= partner[4]:
            kept.append(b_loc1[:4])
        else:
            kept.append(partner[:4])
    return kept


def _ocr_in_vertical_range(text_list, bbox_list, b_loc, threshold=5):
    """Return ``(texts, bboxes)`` of OCR entries whose points 0 and 2 fall
    between ``b_loc[1]`` and ``b_loc[3]`` (widened by ``threshold``) in y."""
    area_text_list, area_bbox_list = [], []
    for i, bbox in enumerate(bbox_list):
        if b_loc[1] - threshold <= bbox[0][1] <= bbox[2][1] <= b_loc[3] + threshold:
            area_bbox_list.append(bbox)
            area_text_list.append(text_list[i])
    return area_text_list, area_bbox_list


def get_table(img, table_list, text_list, bbox_list, text_box_list, show=0):
    """Detect borderless tables in ``img`` and build their cell structure.

    Args:
        img: Page image. It is encoded with ``np2bytes`` and sent to
            ``from_yolo_interface`` for detection. When ``show`` is set,
            detections are drawn on it in place.
        table_list: Tables already found by otr, each a dict with a ``'bbox'``
            entry. Detections overlapping one of them vertically are discarded.
        text_list: OCR texts, parallel to ``bbox_list``.
        bbox_list: OCR boxes; points 0 and 2 give their vertical extent.
        text_box_list: Returned unchanged when detection succeeds.
        show: Debug flag: print intermediate values and display detections.

    Returns:
        ``(text_box_list, tables, obj_in_table_list)``. ``tables`` collects the
        table dicts produced by ``b_table_process``; ``obj_in_table_list``
        collects the text boxes placed in a cell. Returns ``([], [], [])`` when
        nothing is detected, or when every detection overlaps an otr table.
        NOTE(review): those early returns yield ``[]`` rather than
        ``text_box_list`` for the first item; confirm callers expect this.
    """
    import ast

    log('start')
    # Detect borderless tables.
    start_time_all = time.time()
    start_time = time.time()
    img_bytes = np2bytes(img)
    b_table_list = from_yolo_interface(img_bytes)
    log('yolo detect cost: ' + str(time.time()-start_time))
    b_table_list = b_table_list[0]
    if not b_table_list:
        log('detect not b_table_list')
        return [], [], []
    if show:
        print('b_table_list', b_table_list)
        print('table_list', table_list)

    # Drop detections that overlap tables already produced by otr.
    b_table_location_list = _exclude_otr_overlaps(b_table_list, table_list)
    if not b_table_location_list:
        log('except otr, not b_table_location_list')
        return [], [], []
    if show:
        print('len(b_table_location_list)', len(b_table_location_list))

    # Among overlapping detections keep the higher-score one.
    # NOTE(review): a single surviving box keeps 5 elements (with score), while
    # deduplicated boxes have 4. Both shapes are passed to get_table_by_rule.
    b_table_location_list = _keep_best_of_overlapping(b_table_location_list, show=show)
    if show:
        for b_loc in b_table_location_list:
            cv2.rectangle(img, (int(b_loc[0]), int(b_loc[1])), (int(b_loc[2]), int(b_loc[3])),
                          (0, 0, 255), 2)
        cv2.namedWindow('b_table_no_otr', cv2.WINDOW_NORMAL)
        cv2.imshow('b_table_no_otr', img)
        cv2.waitKey(0)

    result_table_list = []
    obj_in_table_list = []
    for b_loc in b_table_location_list:
        area_text_list, area_bbox_list = _ocr_in_vertical_range(text_list, bbox_list, b_loc)

        # Generate table lines and cells from the OCR boxes by rules.
        start_time = time.time()
        line_list, cell_list, table_location, bbox_text_dict = get_table_by_rule(img, area_text_list, area_bbox_list, b_loc, show=show)
        if not table_location:
            log('get_table_by_rule not table_location')
            continue

        # Rebuild the area OCR lists from the rule step's (possibly updated)
        # output. Keys are parsed with ast.literal_eval rather than eval, so
        # only literal bbox strings are accepted. NOTE(review): assumes the
        # keys are str() of list/tuple literals.
        area_text_list, area_bbox_list = [], []
        for key, text in bbox_text_dict.items():
            area_bbox_list.append(ast.literal_eval(key))
            area_text_list.append(text)
        b_text_box_list = get_text_box_obj(area_text_list, area_bbox_list)
        log('get_table_by_rule cost: ' + str(time.time()-start_time))

        # Fill the generated cells with text.
        # NOTE(review): on failure b_table_process returns [-8] sentinels, which
        # are appended to the results as-is.
        start_time = time.time()
        b_text_box_list, _table_list, _obj_in_table_list = b_table_process(line_list, b_text_box_list, cell_list, table_location)
        result_table_list += _table_list
        obj_in_table_list += _obj_in_table_list
        log('b_table_process cost: ' + str(time.time()-start_time))
        if not _table_list:
            log('table_process not table_list')
        # Merging vertically adjacent cells with an NSP model was prototyped at
        # this point and is currently disabled.

    log('get_table finish ' + str(time.time() - start_time_all))
    return text_box_list, result_table_list, obj_in_table_list
|