import ast
import re
import time
import traceback

import cv2
from pdfminer.layout import LTLine

# from botr.nsp.predict import nsp_predict
from botr.rules.get_table_by_rules import get_table_by_rule
from botr.utils import line_iou, get_table_iou
from format_convert.convert_need_interface import from_yolo_interface
from format_convert.utils import log, np2bytes


def _text_box_sort_key(text_box):
    """Sort key for TextBox objects: (bbox[1], bbox[0], bbox[3], bbox[2])."""
    bbox = text_box.bbox
    return bbox[1], bbox[0], bbox[3], bbox[2]


def b_table_process(list_line, list_text_boxes, list_cell, table_location):
    """Fill the rule-generated cells of a borderless table with OCR text.

    Args:
        list_line: table lines produced by ``get_table_by_rule``; only checked
            for emptiness here.
        list_text_boxes: TextBox-like objects with a mutable, indexable
            ``bbox`` (presumably ``[x1, y1, x2, y2]``) and a ``text`` str.
        list_cell: rows of cells; each cell is indexed as
            ``((x1, y1), (x2, y2))``.
        table_location: table bbox; indices 1 and 3 are used as its top and
            bottom.

    Returns:
        ``(list_text_boxes, tables, obj_in_table)`` where ``tables`` is a
        one-element list ``[{'bbox': table_location, 'table': rows}]`` (each
        cell a dict with ``bbox``, ``rowspan``, ``columnspan``, ``text``) and
        ``obj_in_table`` lists the text boxes assigned to some cell.
        ``(list_text_boxes, [], set())`` when ``list_line`` is empty;
        ``(list_text_boxes, [], [])`` when no text box lies inside the table
        area; ``([-8], [-8], [-8])`` after printing a traceback on an
        unexpected error.

    Note:
        ``bbox[1]`` of text boxes inside the table area is modified in place
        (row snapping below).
    """
    try:
        if not list_line:
            return list_text_boxes, [], set()

        # Keep only the text boxes whose vertical extent lies inside the table
        # area, with a small tolerance.
        threshold = 7
        top = table_location[1] - threshold
        bottom = table_location[3] + threshold
        area_list_text_boxes = [t_b for t_b in list_text_boxes
                                if top <= t_b.bbox[1] <= t_b.bbox[3] <= bottom]
        if not area_list_text_boxes:
            # No text to place into the cells: this is an empty result, not an
            # error (indexing [0] below would otherwise raise IndexError).
            return list_text_boxes, [], []

        # Group the boxes into rows: a box whose top is within `threshold` of
        # the current row's top is snapped onto it, otherwise it starts a new
        # row. Without this, boxes on the same visual line that sit slightly
        # higher or lower would come out in the wrong text order.
        area_list_text_boxes.sort(key=_text_box_sort_key)
        current_y = area_list_text_boxes[0].bbox[1]
        threshold = 2.
        for t_b in area_list_text_boxes:
            if current_y - threshold <= t_b.bbox[1] <= current_y + threshold:
                t_b.bbox[1] = current_y
            else:
                current_y = t_b.bbox[1]
        area_list_text_boxes.sort(key=_text_box_sort_key)

        # Convert list_cell into the LineTable-style dict structure. Each text
        # box is assigned to at most one cell (the first with IoU >= 0.3).
        obj_in_table = []
        row_list = []
        for row in list_cell:
            col_list = []
            for col in row:
                x1, y1 = col[0][0], col[0][1]
                x2, y2 = col[1][0], col[1][1]
                col_dict = {'bbox': (x1, y1, x2, y2),
                            'rowspan': 1, 'columnspan': 1, 'text': ''}
                for t_b in area_list_text_boxes:
                    if t_b in obj_in_table:
                        continue
                    bbox = t_b.bbox
                    iou = get_table_iou(x1, y1, x2, y2,
                                        bbox[0], bbox[1], bbox[2], bbox[3])
                    if iou >= 0.3:
                        # All whitespace is stripped from cell text.
                        col_dict['text'] += re.sub(r'\s', '', t_b.text)
                        obj_in_table.append(t_b)
                col_list.append(col_dict)
            row_list.append(col_list)

        tables = [{'bbox': table_location, 'table': row_list}]
        return list_text_boxes, tables, obj_in_table
    except Exception:
        traceback.print_exc()
        return [-8], [-8], [-8]


def get_text_box_obj(_text_list, _bbox_list):
    """Build TextBox objects from parallel OCR text and bbox lists.

    Each bbox is indexed as a list of corner points; corners 0 and 2 become
    the TextBox bbox ``[bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]]``.
    ``_text_list`` must be at least as long as ``_bbox_list``.
    """
    from format_convert.convert_tree import TextBox
    return [TextBox([bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]], _text_list[i])
            for i, bbox in enumerate(_bbox_list)]


def _b_table_to_location(b_table):
    """Return ``[x1, y1, x2, y2, probability]`` for one YOLO detection.

    Coordinates are clamped the same way the original running min/max did:
    minima never exceed 1000000 and maxima never drop below 0.
    """
    return [min(b_table[0], 1000000), min(b_table[1], 1000000),
            max(b_table[2], 0), max(b_table[3], 0), b_table[4]]


def _exclude_otr_overlap(b_table_list, table_list):
    """Drop detections whose vertical extent overlaps an existing table.

    ``table_list`` holds the already-recognized (OTR) tables; a detection is
    discarded when the vertical IoU with any of them exceeds 0.3.
    """
    b_table_location_list = []
    for b_table in b_table_list:
        b_loc = _b_table_to_location(b_table)
        overlaps = any(
            line_iou([[0, table.get('bbox')[1]], [0, table.get('bbox')[3]]],
                     [[0, b_loc[1]], [0, b_loc[3]]], axis=1) > 0.3
            for table in table_list)
        if not overlaps:
            b_table_location_list.append(b_loc)
    return b_table_location_list


def _dedup_b_table_locations(b_table_location_list, show=0):
    """Resolve vertically overlapping detections, keeping the more probable.

    Each detection is paired with the first later, not-yet-absorbed detection
    whose vertical IoU exceeds 0.3; the one with the higher probability
    (index 4) is kept. Returns ``[x1, y1, x2, y2]`` lists (probability
    dropped).
    """
    temp_list = []
    used_b_loc = []
    for i, b_loc1 in enumerate(b_table_location_list):
        if b_loc1 in used_b_loc:
            continue
        partner = None
        for b_loc2 in b_table_location_list[i + 1:]:
            # Skip detections already absorbed by an earlier pair, so a loser
            # cannot be resurrected and a winner cannot be emitted twice.
            if b_loc2 in used_b_loc:
                continue
            iou = line_iou([[0, b_loc1[1]], [0, b_loc1[3]]],
                           [[0, b_loc2[1]], [0, b_loc2[3]]], axis=1)
            if show:
                print('iou2', iou)
            if iou > 0.3:
                partner = b_loc2
                break
        if partner is None:
            temp_list.append(b_loc1[:4])
        else:
            used_b_loc.append(partner)
            best = b_loc1 if b_loc1[4] >= partner[4] else partner
            temp_list.append(best[:4])
    return temp_list


def _collect_area_ocr(b_loc, text_list, bbox_list, threshold=5):
    """Return the OCR texts/bboxes lying vertically inside ``b_loc``.

    A bbox is kept when its corner-0 y and corner-2 y fall within
    ``[b_loc[1] - threshold, b_loc[3] + threshold]``.
    """
    area_text_list = []
    area_bbox_list = []
    for i, bbox in enumerate(bbox_list):
        if b_loc[1] - threshold <= bbox[0][1] <= bbox[2][1] <= b_loc[3] + threshold:
            area_bbox_list.append(bbox)
            area_text_list.append(text_list[i])
    return area_text_list, area_bbox_list


def get_table(img, table_list, text_list, bbox_list, text_box_list, show=0):
    """Detect borderless tables in ``img`` and fill their cells from OCR.

    Args:
        img: page image; serialized via ``np2bytes`` for the YOLO detector.
        table_list: already-recognized (OTR) tables as dicts with a ``bbox``;
            detections overlapping them vertically are ignored.
        text_list: OCR texts, parallel to ``bbox_list``.
        bbox_list: OCR boxes as corner-point lists (corners 0 and 2 used).
        text_box_list: returned unchanged as the first element.
        show: debug flag; prints intermediate data and opens cv2 windows.

    Returns:
        ``(text_box_list, table_list, obj_in_table_list)`` where
        ``table_list`` contains only the borderless tables built here and
        ``obj_in_table_list`` the TextBox objects placed into their cells.
        ``([], [], [])`` when nothing is detected, or every detection overlaps
        an existing table.
    """
    log('start')
    # Detect borderless tables.
    start_time_all = time.time()
    start_time = time.time()
    img_bytes = np2bytes(img)
    b_table_list = from_yolo_interface(img_bytes)
    log('yolo detect cost: ' + str(time.time()-start_time))
    b_table_list = b_table_list[0]
    if not b_table_list:
        log('detect not b_table_list')
        return [], [], []

    if show:
        print('b_table_list', b_table_list)
        print('table_list', table_list)

    # Exclude regions already recognized as bordered (OTR) tables.
    b_table_location_list = _exclude_otr_overlap(b_table_list, table_list)
    if not b_table_location_list:
        log('except otr, not b_table_location_list')
        return [], [], []

    if show:
        print('len(b_table_location_list)', len(b_table_location_list))

    # Among overlapping detections keep the one with the higher probability.
    # NOTE(review): a single detection keeps its 5th (probability) element,
    # while deduplicated ones are cut to 4 elements — kept as-is.
    if len(b_table_location_list) > 1:
        b_table_location_list = _dedup_b_table_locations(b_table_location_list, show=show)

    if show:
        for b_loc in b_table_location_list:
            cv2.rectangle(img, (int(b_loc[0]), int(b_loc[1])), (int(b_loc[2]), int(b_loc[3])),
                          (0, 0, 255), 2)
        cv2.namedWindow('b_table_no_otr', cv2.WINDOW_NORMAL)
        cv2.imshow('b_table_no_otr', img)
        cv2.waitKey(0)

    # The input table_list is deliberately rebound: only the borderless
    # tables found here are returned.
    table_list = []
    obj_in_table_list = []
    for b_loc in b_table_location_list:
        area_text_list, area_bbox_list = _collect_area_ocr(b_loc, text_list, bbox_list)

        # Generate table lines from the OCR bboxes by rules.
        start_time = time.time()
        line_list, cell_list, table_location, bbox_text_dict = get_table_by_rule(
            img, area_text_list, area_bbox_list, b_loc, show=show)
        if not table_location:
            log('get_table_by_rule not table_location')
            continue

        # Take the updated texts/bboxes. Keys are stringified bboxes; parse
        # them as literals instead of evaluating arbitrary expressions.
        area_text_list, area_bbox_list = [], []
        for key, text in bbox_text_dict.items():
            area_bbox_list.append(ast.literal_eval(key))
            area_text_list.append(text)
        b_text_box_list = get_text_box_obj(area_text_list, area_bbox_list)
        log('get_table_by_rule cost: ' + str(time.time()-start_time))

        # Build cells from the table lines.
        start_time = time.time()
        _, _table_list, _obj_in_table_list = b_table_process(
            line_list, b_text_box_list, cell_list, table_location)
        log('b_table_process cost: ' + str(time.time()-start_time))
        if _table_list == [-8]:
            # b_table_process failed (traceback already printed). Skip this
            # region rather than mixing the -8 error code into the results.
            log('b_table_process error, skip this b_table')
            continue
        table_list += _table_list
        obj_in_table_list += _obj_in_table_list
        if not _table_list:
            log('table_process not table_list')

        # TODO: cell merging with an NSP (next-sentence-prediction) model was
        # prototyped here and is disabled (see the commented-out nsp_predict
        # import). Planned approach: for vertically adjacent cells a and b in
        # a column, join them into c and segment c with hanlp; if the a|b
        # boundary falls inside one word, merge a and b (rowspan += 1),
        # otherwise segment a and b separately.

    log('get_table finish ' + str(time.time() - start_time_all))
    return text_box_list, table_list, obj_in_table_list