import copy
import math
import random
import time
import traceback

import numpy as np
import cv2
from matplotlib import pyplot as plt
from pdfminer.layout import LTTextContainer, LTRect, LTCurve, LTLine
from scipy.stats import linregress
from shapely.geometry import LineString

from format_convert.utils import log, bbox_iou
from otr.table_line_new import table_line_pdf_post_process

# Page size used by get_cross_line() to clamp extended lines.
# table_line_pdf() overwrites both values for every page it processes.
page_w = 100
page_h = 100


def _plot(_line_list, title, mode=1, show=1):
    """Debug helper: draw every line with matplotlib and show the figure.

    :param _line_list: objects exposing a ``bbox`` attribute (mode 1) or
        plain ``[x0, y0, x1, y1]`` sequences (mode 2).
    :param title: figure title.
    :param mode: 1 or 2, selects how each item is unpacked (see above).
    :param show: falsy value turns the call into a no-op.
    """
    if not show:
        return

    for _line in _line_list:
        if mode == 1:
            x0, y0, x1, y1 = _line.__dict__.get("bbox")
        elif mode == 2:
            x0, y0, x1, y1 = _line
        plt.plot([x0, x1], [y0, y1])
    plt.title(title)
    plt.show()
    return


def is_cross(A, B, C, D):
    """Return True when segment AB and segment CD meet.

    Collinear vertical / horizontal segments are handled first with plain
    interval tests (these tests compare the coordinates in the order given,
    i.e. they expect the first end point of a segment not to exceed the
    second one).  Everything else is delegated to shapely, and only a
    single-point intersection counts as a crossing.

    :param A, B: end points ``[x, y]`` of the first segment.
    :param C, D: end points ``[x, y]`` of the second segment.
    :return: bool
    """
    # All four points on one vertical line: compare the y intervals.
    if A[0] == B[0] == C[0] == D[0]:
        if A[1] <= C[1] <= B[1] or A[1] <= D[1] <= B[1] \
                or C[1] <= A[1] <= D[1] or C[1] <= B[1] <= D[1]:
            return True

    # All four points on one horizontal line: compare the x intervals.
    if A[1] == B[1] == C[1] == D[1]:
        if A[0] <= C[0] <= B[0] or A[0] <= D[0] <= B[0] \
                or C[0] <= A[0] <= D[0] or C[0] <= B[0] <= D[0]:
            return True

    int_pt = LineString([A, B]).intersection(LineString([C, D]))
    # Previously a bare ``except`` around ``int_pt.x`` decided the result;
    # test the geometry explicitly so unrelated errors are not swallowed.
    return (not int_pt.is_empty) and int_pt.geom_type == "Point"


def calculate_k(bbox):
    """Return the slope of the line through (bbox[0], bbox[1]) and (bbox[2], bbox[3]).

    A vertical line (identical x values) and a NaN regression result both
    yield 0.
    """
    # Recent SciPy versions raise ValueError when all x values are identical,
    # older ones return NaN; short-circuit so both behave the same.
    if bbox[0] == bbox[2]:
        return 0

    x = [bbox[0], bbox[2]]
    y = [bbox[1], bbox[3]]
    slope, intercept, r_value, p_value, std_err = linregress(x, y)
    if math.isnan(slope):
        slope = 0
    return slope


def line_iou(line1, line2, axis=0):
    """Overlap ratio of two segments projected on one axis.

    Each line is ``[[x0, y0], [x1, y1]]``.  If one projection fully contains
    the other the result is 1.0; otherwise the overlap length is divided by
    the length of the *shorter* projection (not by the union).

    :param axis: 0 to project on x, 1 to project on y.
    :return: float ratio; 0.0 when the shorter projection has zero length.
        The value is negative when the projections do not overlap.
    """
    a0, a1 = line1[0][axis], line1[1][axis]
    b0, b1 = line2[0][axis], line2[1][axis]

    if a0 <= b0 <= b1 <= a1:
        return 1.0
    if b0 <= a0 <= a1 <= b1:
        return 1.0

    inter = min(a1, b1) - max(a0, b0)
    shorter = min(abs(a0 - a1), abs(b0 - b1))
    if shorter == 0:
        return 0.
    return inter / shorter


def _is_row(line):
    """Return True when ``[x0, y0, x1, y1]`` is at least as wide as it is tall."""
    return abs(line[0] - line[2]) >= abs(line[1] - line[3])


def get_cross_line(_line_list, threshold=1, cross_times=0):
    """Keep only the lines that intersect other lines.

    Every line is first stretched by ``threshold`` at both ends along its
    dominant direction (clamped to ``[0, page_w]`` / ``[0, page_h]``, the
    module-level page size).  A stretched line is kept as soon as the number
    of other lines it meets reaches ``cross_times``; the count is only checked
    after a hit, so at least one intersection is always required.

    :param _line_list: lines as ``[x0, y0, x1, y1]``.
    :param threshold: stretch length and search margin.
    :param cross_times: required number of intersections.
    :return: list of the *stretched* lines that passed the test.
    """
    start_time = time.time()

    # Stretch horizontal lines along x and vertical lines along y.
    new_line_list = []
    for line in _line_list:
        if _is_row(line):
            new_line = [max(0, line[0] - threshold), line[1],
                        min(line[2] + threshold, page_w), line[3]]
        else:
            new_line = [line[0], max(0, line[1] - threshold),
                        line[2], min(line[3] + threshold, page_h)]
        new_line_list.append(new_line)

    row_flags = [_is_row(line) for line in new_line_list]

    _cross_line_list = []
    for i, line1 in enumerate(new_line_list):
        # Search window around line1 used for same-orientation candidates.
        line1_area = [max(0, line1[0] - threshold), max(0, line1[1] - threshold),
                      min(page_w, line1[2] + threshold), min(page_h, line1[3] + threshold)]
        line1_is_row = row_flags[i]

        _times = 0
        for j, line2 in enumerate(new_line_list):
            if i == j:
                continue
            line2_is_row = row_flags[j]

            # One horizontal and one vertical line: test the crossing point
            # directly with coordinate comparisons, never via shapely.
            if line1_is_row != line2_is_row:
                if (line1_is_row and line1[0] <= line2[0] <= line1[2]
                        and line2[1] <= line1[1] <= line2[3]) \
                        or (line2_is_row and line2[0] <= line1[0] <= line2[2]
                            and line1[1] <= line2[1] <= line1[3]):
                    _times += 1
                    if _times >= cross_times:
                        _cross_line_list.append(line1)
                        break
                continue

            # Same orientation: skip line2 unless one of its end points lies
            # inside the search window of line1.
            start_inside = (line1_area[0] <= line2[0] <= line1_area[2]
                            and line1_area[1] <= line2[1] <= line1_area[3])
            end_inside = (line1_area[0] <= line2[2] <= line1_area[2]
                          and line1_area[1] <= line2[3] <= line1_area[3])
            if not (start_inside or end_inside):
                continue

            if is_cross(line1[:2], line1[2:4], line2[:2], line2[2:4]):
                _times += 1
                if _times >= cross_times:
                    _cross_line_list.append(line1)
                    break

    log('get_cross_line cost: ' + str(time.time()-start_time))
    return _cross_line_list


def merge_line(_line_list, threshold=2):
    """Merge touching / overlapping lines that lie on the same column or row.

    Vertical lines are grouped into columns and horizontal lines into rows,
    then consecutive lines of a group are fused while they intersect or
    overlap by at least 10%.  Fused lines are snapped to the coordinate of the
    first line of their group.

    Note: ``_line_list`` is sorted in place (twice).

    :param _line_list: lines as ``[x0, y0, x1, y1]``.
    :param threshold: tolerance when deciding whether a line belongs to the
        current column / row.
    :return: new list with merged vertical lines followed by merged
        horizontal lines.
    """
    start_time = time.time()
    new_line_list = []

    # ---- columns -----------------------------------------------------------
    _line_list.sort(key=lambda x: (x[0], x[1]))
    cols = []
    col = []
    current_w = None
    for line in _line_list:
        # Strictly horizontal lines are handled in the row pass.
        if abs(line[0] - line[2]) > abs(line[1] - line[3]):
            continue
        if not col:
            # Only reached for the very first vertical line.  The line is
            # deliberately NOT skipped afterwards: it is usually appended a
            # second time below, and that duplicate is what makes a lone
            # first line get snapped to the column coordinate later on.
            col.append(line)
            current_w = line[0]

        _iou = line_iou([[0, line[1]], [0, line[3]]],
                        [[0, col[0][1]], [0, col[0][3]]], axis=1)
        lo_x, hi_x = min(line[0], line[2]), max(line[0], line[2])
        if lo_x - threshold <= current_w <= hi_x + threshold \
                and is_cross(line[0:2], line[2:4], col[-1][0:2], col[-1][2:4]):
            col.append(line)
        elif lo_x - 2*threshold <= current_w <= hi_x + 2*threshold \
                and _iou >= 0.1:
            col.append(line)
        else:
            if col:
                cols.append(col)
            col = [line]
            current_w = line[0]
    if col:
        cols.append(col)

    for col in cols:
        temp_c = col[0]
        col_w = col[0][0]
        for c, next_c in zip(col, col[1:]):
            if is_cross(c[0:2], c[2:4], next_c[0:2], next_c[2:4]) \
                    or line_iou([[0, c[1]], [0, c[3]]],
                                [[0, next_c[1]], [0, next_c[3]]], axis=1) >= 0.1:
                temp_c = [col_w, min(temp_c[1], c[1], c[3], next_c[1], next_c[3]),
                          col_w, max(temp_c[3], c[1], c[3], next_c[1], next_c[3])]
            else:
                new_line_list.append(temp_c)
                temp_c = next_c
        if not new_line_list or new_line_list[-1] != temp_c:
            new_line_list.append(temp_c)

    # ---- rows --------------------------------------------------------------
    _line_list.sort(key=lambda x: (x[1], x[0]))
    rows = []
    row = []
    current_h = None
    for line in _line_list:
        # Strictly vertical lines were handled in the column pass.
        if abs(line[0] - line[2]) < abs(line[1] - line[3]):
            continue
        if not row:
            # Same intentional duplicate as in the column pass.
            row = [line]
            current_h = line[1]

        if min(line[1], line[3]) - threshold <= current_h <= max(line[1], line[3]) + threshold:
            row.append(line)
        else:
            if row:
                rows.append(row)
            row = [line]
            current_h = line[1]
    if row:
        rows.append(row)

    for row in rows:
        temp_r = row[0]
        row_h = row[0][1]
        for r, next_r in zip(row, row[1:]):
            if line_iou([r[0:2], r[2:4]], [next_r[0:2], next_r[2:4]], axis=0) >= 0.1:
                temp_r = [min(temp_r[0], r[0], r[2], next_r[0], next_r[2]), row_h,
                          max(temp_r[2], r[0], r[2], next_r[0], next_r[2]), row_h]
            else:
                new_line_list.append(temp_r)
                temp_r = next_r
        if not new_line_list or new_line_list[-1] != temp_r:
            new_line_list.append(temp_r)

    log('merge_line1 cost: ' + str(time.time()-start_time))
    return new_line_list


def remove_outline_no_cross(_line_list):
    """Drop the left-most and right-most vertical lines when they look like an outer frame.

    For each of the two columns the horizontal lines are counted in three
    buckets: crossing the column outside its central area, crossing it inside
    the central area, and lying at central height without crossing it.  The
    central area is the column span shrunk by 1/8 of its height at both ends.
    Both columns are removed only if each has >= 2 outer crossings, no central
    crossing and >= 2 non-crossing central rows, and the central counts of
    the two columns agree.

    Note: the removal happens in place on ``_line_list``.

    :param _line_list: lines as ``[x0, y0, x1, y1]``.
    :return: the same list object.
    """
    start_time = time.time()

    row_list = []
    col_list = []
    for line in _line_list:
        width = abs(line[0] - line[2])
        height = abs(line[1] - line[3])
        if width > height:
            row_list.append(line)
        if width < height:
            col_list.append(line)
    if not col_list:
        return _line_list

    # Left and right border candidates.
    col_list.sort(key=lambda x: (x[0], x[1]))
    left_col = col_list[0]
    right_col = col_list[-1]

    # Count crossings at the ends versus in the central area of each border.
    compare_list = []
    for col in [left_col, right_col]:
        add_h = abs(col[1]-col[3]) / 8
        center_lo, center_hi = col[1] + add_h, col[3] - add_h
        cross_cnt = 0
        center_cross_cnt = 0
        center_row_cnt = 0
        for row in row_list:
            in_center = center_lo <= row[1] <= center_hi
            if is_cross(row[0:2], row[2:4], col[0:2], col[2:4]):
                if in_center:
                    center_cross_cnt += 1
                else:
                    cross_cnt += 1
            elif in_center:
                center_row_cnt += 1
        compare_list.append([cross_cnt, center_cross_cnt, center_row_cnt])

    _flag = all(c[0] >= 2 and c[1] == 0 and c[2] >= 2 for c in compare_list)
    if _flag and compare_list[0][1] == compare_list[1][1] \
            and compare_list[0][2] == compare_list[1][2]:
        for col in [left_col, right_col]:
            if col in _line_list:
                _line_list.remove(col)

    # The message used to say 'merge_line cost', which was a copy-paste slip.
    log('remove_outline_no_cross cost: ' + str(time.time()-start_time))
    return _line_list


def _long_edges(bbox, horizontal):
    """Return two opposite sides of ``bbox`` as ``[x0, y0, x1, y1]`` segments.

    :param horizontal: True for the bottom/top sides, False for left/right.
    """
    x0, y0, x1, y1 = bbox
    if horizontal:
        return [[x0, y0, x1, y0], [x0, y1, x1, y1]]
    return [[x0, y0, x0, y1], [x1, y0, x1, y1]]


def _fill_reaches_threshold(color, threshold):
    """Return True when any component of ``color`` is >= ``threshold``.

    ``color`` may be None, a scalar, or a tuple/list of components.
    Components that cannot be converted to float are reported with a
    traceback and otherwise ignored.
    """
    if color is None:
        return False
    components = color if isinstance(color, (tuple, list)) else [color]
    for component in components:
        try:
            if float(component) >= threshold:
                return True
        except (TypeError, ValueError):
            traceback.print_exc()
    return False


def table_line_pdf(layout, page_no, show=0):
    """Extract table ruling lines from one pdfminer page layout.

    Pipeline: collect candidate segments from line/curve objects and from
    rectangles, de-duplicate, merge (``merge_line``), keep intersecting lines
    (``get_cross_line``), drop an outer frame (``remove_outline_no_cross``),
    run ``table_line_pdf_post_process`` and finally discard short horizontal /
    vertical lines that only cross each other.

    Side effect: the module-level ``page_w`` / ``page_h`` are overwritten with
    the (possibly enlarged) size of this page.

    :param layout: iterable page layout exposing ``height`` and ``width``.
    :param page_no: page number, used for logging only.
    :param show: truthy to print counters and plot intermediate results.
    :return: list of ``LTLine``; empty when the page is larger than 10000 in
        either dimension or no candidate line is found.
    """
    log('into table_line_pdf')
    page_h = layout.height
    page_w = layout.width

    # Refuse oversized pages.
    if page_h > 10000 or page_w > 10000:
        log('1 page_h or page_w > 10000 ' + str(page_h) + ' ' + str(page_w))
        return []

    line_list = []
    lt_text_container_list = []
    lt_rect_list = []
    lt_line_list = []
    lt_curve_list = []
    line_rect_list = []
    non_line_rect_list = []
    delete_lt_rect_list = []

    start_time = time.time()

    # Sort the layout objects: text boxes, rectangles, lines, curves.
    # The isinstance order matters: LTRect and LTLine are tested before the
    # more general LTCurve.
    min_y = 10000
    max_x, max_y = 0, 0
    thin_limit = 2
    for element in layout:
        if isinstance(element, LTTextContainer):
            lt_text_container_list.append(element)

        elif isinstance(element, LTRect):
            lt_rect_list.append(element)

            # Exactly one thin side -> line-shaped rectangle; no thin side ->
            # real rectangle; both sides thin -> discarded.
            if (element.height <= thin_limit) ^ (element.width <= thin_limit):
                line_rect_list.append(element)
            elif element.height > thin_limit and element.width > thin_limit:
                non_line_rect_list.append(element)
            else:
                delete_lt_rect_list.append(element)

            # Track the extent covered by rectangles.
            _bbox = element.bbox
            min_y = min(min_y, _bbox[1], _bbox[3])
            max_y = max(max_y, _bbox[1], _bbox[3])
            max_x = max(max_x, _bbox[0], _bbox[2])

        elif isinstance(element, LTLine):
            lt_line_list.append(element)

        elif isinstance(element, LTCurve):
            lt_curve_list.append(element)

    if show:
        print('len(lt_text_container_list)', len(lt_text_container_list))
        print('len(lt_rect_list)', len(lt_rect_list))
        print('len(lt_line_list)', len(lt_line_list))
        print('len(lt_curve_list)', len(lt_curve_list))
        print('len(line_rect_list)', len(line_rect_list))
        print('len(non_line_rect_list)', len(non_line_rect_list))
        print('len(delete_lt_rect_list)', len(delete_lt_rect_list))

    # Rectangles may extend beyond the declared page: grow the page to fit.
    if max_y > page_h:
        page_h = max_y + 20
    if max_x > page_w:
        page_w = max_x + 20

    # Refuse oversized pages (after growing).
    if page_h > 10000 or page_w > 10000:
        log('2 page_h or page_w > 10000 ' + str(page_h) + ' ' + str(page_w))
        return []

    # Publish the page size for get_cross_line(), which reads the globals.
    globals().update({'page_h': page_h, 'page_w': page_w})

    # Some rectangles have negative y: replace each negative y with 10.
    if min_y < 0:
        for lt_rect in lt_rect_list:
            if lt_rect.y0 < 0 or lt_rect.y1 < 0:
                new_y0 = 10 if lt_rect.y0 < 0 else lt_rect.y0
                new_y1 = 10 if lt_rect.y1 < 0 else lt_rect.y1
                lt_rect.set_bbox((lt_rect.x0, new_y0, lt_rect.x1, new_y1))

    _plot([x.bbox for x in lt_rect_list + lt_line_list], 'get_page_lines start', mode=2, show=show)

    # Long sides of the text boxes; only used for the debug plot.
    text_line_list = []
    for lt_text in lt_text_container_list:
        _b = lt_text.bbox
        text_line_list += _long_edges(_b, abs(_b[0]-_b[2]) >= abs(_b[1]-_b[3]))
    _plot(text_line_list, 'lt_text_container_list', mode=2, show=show)

    # Lines from line / curve objects longer than 10 in either direction.
    # A tie between height and width is treated as vertical here.
    for lt_line in lt_line_list + lt_curve_list:
        if lt_line.height > 10 or lt_line.width > 10:
            line_list += _long_edges(lt_line.bbox, not (lt_line.height >= lt_line.width))
    _plot(line_list, 'lt_line_list+lt_curve_list', mode=2, show=show)

    # Lines from line-shaped rectangles (a tie is treated as horizontal).
    for lt_rect in line_rect_list:
        if lt_rect in delete_lt_rect_list:
            continue
        _b = lt_rect.bbox
        line_list += _long_edges(_b, abs(_b[0]-_b[2]) >= abs(_b[1]-_b[3]))
    _plot(line_list, 'line_rect_list', mode=2, show=show)

    # Lines from ordinary rectangles: all four sides, unless a component of
    # the fill colour reaches the threshold (presumably such rectangles are
    # light / colourless backgrounds -- TODO confirm).
    color_threshold = 0.3
    for lt_rect in non_line_rect_list:
        if lt_rect in delete_lt_rect_list:
            continue
        if _fill_reaches_threshold(lt_rect.non_stroking_color, color_threshold):
            continue
        _b = lt_rect.bbox
        line_list += [[_b[0], _b[1], _b[0], _b[3]], [_b[0], _b[1], _b[2], _b[1]],
                      [_b[2], _b[1], _b[2], _b[3]], [_b[0], _b[3], _b[2], _b[3]]]
    _plot(line_list, 'non_line_rect_list', mode=2, show=show)

    if not line_list:
        return []

    # De-duplicate while keeping first-seen order (no str()/eval() round trip).
    line_list = [list(t) for t in dict.fromkeys(tuple(x) for x in line_list)]

    # Merge lines.
    line_list = merge_line(line_list)
    if show:
        print('get_page_lines len(line_list)', len(line_list))
    _plot(line_list, 'line_list+bias_line_list', mode=2, show=show)

    # Table lines are the ones that intersect other lines.
    cross_line_list = get_cross_line(line_list, threshold=2, cross_times=1)
    if show:
        print('get_page_lines len(cross_line_list)', len(cross_line_list))
    _plot(cross_line_list, 'get_cross_line', mode=2, show=show)

    # Remove an outer nested frame.
    cross_line_list = remove_outline_no_cross(cross_line_list)

    # Reuse part of the otr post-processing (fills in missing lines,
    # according to the original comment -- behaviour defined elsewhere).
    cross_line_list = table_line_pdf_post_process(cross_line_list, page_w, page_h)
    _plot(cross_line_list, 'cross_line_process1', mode=2, show=show)

    # Remove short (<= 30) horizontal / vertical lines that cross each other.
    short_line_list = []
    for line in cross_line_list:
        if line[1] == line[3] and abs(line[2] - line[0]) <= 30:
            short_line_list.append(line)
        if line[0] == line[2] and abs(line[3] - line[1]) <= 30:
            short_line_list.append(line)
    for line in short_line_list:
        for line2 in short_line_list:
            if line == line2:
                continue
            if is_cross(line[:2], line[2:4], line2[:2], line2[2:4]):
                if line in cross_line_list:
                    cross_line_list.remove(line)
                if line2 in cross_line_list:
                    cross_line_list.remove(line2)

    if show:
        print('len(cross_line_list)', len(cross_line_list))
    _plot(cross_line_list, 'cross_line_process2', mode=2, show=show)

    # Wrap the result in pdfminer LTLine objects (line width 1).
    lt_line_list = [
        LTLine(1, (float(line[0]), float(line[1])), (float(line[2]), float(line[3])))
        for line in cross_line_list
    ]
    log("pdf page %s has %s lines cost: %s" % (str(page_no), str(len(lt_line_list)), str(time.time()-start_time)))
    return lt_line_list