Parcourir la source

1.无边框表格优化
2.有边框表格优化
3.表格ocr文本分割
4.新增pdf指定页码提取
5.新增最终结果长度设置

fangjiasheng il y a 1 an
Parent
commit
5247cc225e

+ 74 - 22
botr/extract_table.py

@@ -1,3 +1,4 @@
+import re
 import time
 import traceback
 import cv2
@@ -74,10 +75,10 @@ def b_table_process(list_line, list_text_boxes, list_cell, table_location):
                     for t_b in area_list_text_boxes:
                         if t_b in obj_in_table:
                             continue
-                        text = t_b.text
+                        text = re.sub('\s', '', t_b.text)
                         bbox = t_b.bbox
                         iou = get_table_iou(col[0][0], col[0][1], col[1][0], col[1][1],
-                                      bbox[0], bbox[1], bbox[2], bbox[3])
+                                            bbox[0], bbox[1], bbox[2], bbox[3])
                         if iou >= 0.3:
                             col_dict['text'] += text
                             obj_in_table.append(t_b)
@@ -98,6 +99,17 @@ def b_table_process(list_line, list_text_boxes, list_cell, table_location):
         return [-8], [-8], [-8]
 
 
+def get_text_box_obj(_text_list, _bbox_list):
+    from format_convert.convert_tree import TextBox
+    _text_box_list = []
+    for i in range(len(_bbox_list)):
+        bbox = _bbox_list[i]
+        b_text = _text_list[i]
+        _text_box_list.append(TextBox([bbox[0][0], bbox[0][1],
+                                       bbox[2][0], bbox[2][1]], b_text))
+    return _text_box_list
+
+
 def get_table(img, table_list, text_list, bbox_list, text_box_list, show=0):
     log('start')
     # 检测无边框表格
@@ -111,16 +123,18 @@ def get_table(img, table_list, text_list, bbox_list, text_box_list, show=0):
         log('detect not b_table_list')
         return [], [], []
 
-    if show:
-        for b_table in b_table_list:
-            # for line in b_table:
-            cv2.rectangle(img, (int(b_table[0]), int(b_table[1])), (int(b_table[2]), int(b_table[3])),
-                     (0, 0, 255))
-        cv2.imshow('b_table', img)
-        cv2.waitKey(0)
+    # if show:
+    #     for b_table in b_table_list:
+    #         # for line in b_table:
+    #         cv2.rectangle(img, (int(b_table[0]), int(b_table[1])), (int(b_table[2]), int(b_table[3])),
+    #                  (0, 0, 255), 2)
+    #     cv2.namedWindow('b_table', cv2.WINDOW_NORMAL)
+    #     cv2.imshow('b_table', img)
+    #     cv2.waitKey(0)
 
     if show:
         print('b_table_list', b_table_list)
+        print('table_list', table_list)
 
     # 排除otr结果
     b_table_location_list = []
@@ -137,18 +151,19 @@ def get_table(img, table_list, text_list, bbox_list, text_box_list, show=0):
             min_x = b_table[0]
         if b_table[2] > max_x:
             max_x = b_table[2]
-        b_loc = [min_x, min_y, max_x, max_y]
+        b_loc = [min_x, min_y, max_x, max_y, b_table[4]]
         inter_flag = False
         for table in table_list:
             loc = table.get('bbox')
             rows = table.get('table')
-            if line_iou([[0, loc[1]], [0, loc[3]]], [[0, b_loc[1]], [0, b_loc[3]]], axis=1) > 0.3:
-                if len(rows) <= 1:
-                    if loc[1] < b_loc[1] < loc[3] < b_loc[3]:
-                        b_loc[1] = loc[3]
-                    if b_loc[1] < loc[1] < b_loc[3] < loc[3]:
-                        b_loc[3] = loc[1]
-                    continue
+            iou = line_iou([[0, loc[1]], [0, loc[3]]], [[0, b_loc[1]], [0, b_loc[3]]], axis=1)
+            if iou > 0.3:
+                # if len(rows) <= 1:
+                #     if loc[1] < b_loc[1] < loc[3] < b_loc[3]:
+                #         b_loc[1] = loc[3]
+                #     if b_loc[1] < loc[1] < b_loc[3] < loc[3]:
+                #         b_loc[3] = loc[1]
+                #     continue
                 inter_flag = True
                 # cv2.rectangle(img, [int(loc[0]), int(loc[1])], [int(loc[2]), int(loc[3])], (0, 0, 255))
                 # cv2.rectangle(img, [int(b_loc[0]), int(b_loc[1])], [int(b_loc[2]), int(b_loc[3])], (0, 0, 255))
@@ -158,14 +173,45 @@ def get_table(img, table_list, text_list, bbox_list, text_box_list, show=0):
         if not inter_flag:
             b_table_location_list.append(b_loc)
     if not b_table_location_list:
-        log('not b_table_location_list')
+        log('except otr, not b_table_location_list')
         return [], [], []
 
+    if show:
+        print('len(b_table_location_list)', len(b_table_location_list))
+
+    # 排除有重合的,取概率大的
+    if len(b_table_location_list) > 1:
+        temp_list = []
+        used_b_loc = []
+        for i in range(len(b_table_location_list)):
+            b_loc1 = b_table_location_list[i]
+            if b_loc1 in used_b_loc:
+                continue
+            inter_flag = False
+            for j in range(i+1, len(b_table_location_list)):
+                b_loc2 = b_table_location_list[j]
+                iou = line_iou([[0, b_loc1[1]], [0, b_loc1[3]]], [[0, b_loc2[1]], [0, b_loc2[3]]], axis=1)
+                if show:
+                    print('iou2', iou)
+                if iou > 0.3:
+                    inter_flag = True
+                    break
+            if inter_flag:
+                used_b_loc.append(b_loc2)
+                if b_loc1[4] >= b_loc2[4]:
+                    temp_list.append(b_loc1[:4])
+                else:
+                    temp_list.append(b_loc2[:4])
+            else:
+                temp_list.append(b_loc1[:4])
+        b_table_location_list = temp_list
+
     if show:
         for b_loc in b_table_location_list:
             cv2.rectangle(img, (int(b_loc[0]), int(b_loc[1])), (int(b_loc[2]), int(b_loc[3])),
-                          (0, 0, 255))
-        cv2.imshow('b_table no otr', img)
+                          (0, 0, 255), 2)
+        cv2.namedWindow('b_table_no_otr', cv2.WINDOW_NORMAL)
+        cv2.imshow('b_table_no_otr', img)
         cv2.waitKey(0)
 
     table_list = []
@@ -182,15 +228,21 @@ def get_table(img, table_list, text_list, bbox_list, text_box_list, show=0):
 
         # 根据ocr bbox,规则生成表格线
         start_time = time.time()
-        line_list, cell_list, table_location = get_table_by_rule(img, area_text_list, area_bbox_list, b_loc)
+        line_list, cell_list, table_location, bbox_text_dict = get_table_by_rule(img, area_text_list, area_bbox_list, b_loc, show=show)
         if not table_location:
             log('get_table_by_rule not table_location')
             continue
+        # 获取最新的text_list, bbox_list
+        area_text_list, area_bbox_list = [], []
+        for key in bbox_text_dict.keys():
+            area_bbox_list.append(eval(key))
+            area_text_list.append(bbox_text_dict.get(key))
+        b_text_box_list = get_text_box_obj(area_text_list, area_bbox_list)
         log('get_table_by_rule cost: ' + str(time.time()-start_time))
 
         # 根据表格线生成单元格
         start_time = time.time()
-        text_box_list, _table_list, _obj_in_table_list = b_table_process(line_list, text_box_list, cell_list, table_location)
+        b_text_box_list, _table_list, _obj_in_table_list = b_table_process(line_list, b_text_box_list, cell_list, table_location)
         table_list += _table_list
         obj_in_table_list += _obj_in_table_list
         log('b_table_process cost: ' + str(time.time()-start_time))

+ 179 - 78
botr/rules/get_table_by_rules.py

@@ -1,4 +1,5 @@
 import copy
+import re
 import cv2
 import numpy as np
 from botr.rules.table_utils import shrink_bbox, split_bbox, get_table_bbox_list, count_black
@@ -6,19 +7,69 @@ from botr.utils import line_iou
 # from format_convert.utils import log
 
 
-def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=0):
+def get_table_by_rule(img, text_list, bbox_list, table_location, show=0):
+    if show:
+        print('get_table_by_rule bbox_list', bbox_list)
+
     if not bbox_list:
-        return [], [], []
+        return [], [], [], {}
+
+    if show:
+        img_show = copy.deepcopy(img)
+        img_result = copy.deepcopy(img)
 
     # 处理bbox,缩小框
     bbox_list = shrink_bbox(img, bbox_list)
 
     # 创建对应dict
     bbox_text_dict = {}
+    temp_list = []
     for i in range(len(text_list)):
-        bbox_text_dict[str(bbox_list[i])] = text_list[i]
+        # 排除text为空的
+        if not text_list[i]:
+            continue
+        if re.sub(' ', '', text_list[i]) == '':
+            continue
 
-    # print('bbox_text_dict', bbox_text_dict)
+        # text中间为空格,其实是两列的
+        match = re.search('[ ]{3,}', text_list[i])
+        if match:
+            # print(text_list[i][match.span()[1]:], re.match('[((]', text_list[i][match.span()[1]:]))
+            text = text_list[i]
+            bbox = bbox_list[i]
+            blank_index = (match.span()[0] + match.span()[1]) / 2
+            chinese_cnt = len(re.findall('[\u4e00-\u9fff()?。,!【】¥《》]', text)) * 1.5
+            char_cnt = len(re.findall('[ .?!,+*&^%$#@~=:;/<>()a-zA-Z0-9{}]', text))
+            # print(text, match.span()[0], match.span()[1], blank_index, chinese_cnt, char_cnt)
+            char_cnt += chinese_cnt
+            char_pixel = abs(bbox[0][0] - bbox[2][0]) / char_cnt
+            index_pixel = char_pixel * blank_index
+            # print(abs(bbox[0][0] - bbox[2][0]), char_cnt, char_pixel, index_pixel)
+            bbox1 = [bbox[0], bbox[1], [bbox[0][0] + index_pixel, bbox[2][1]], bbox[3]]
+            bbox1 = shrink_bbox(img, [bbox1])[0]
+            text1 = text[:match.span()[0]]
+            bbox2 = [[bbox[0][0]+index_pixel, bbox[0][1]], bbox[1], bbox[2], bbox[3]]
+            bbox2 = shrink_bbox(img, [bbox2])[0]
+            text2 = text[match.span()[1]:]
+            if re.sub(' ', '', text1) != '':
+                bbox_text_dict[str(bbox1)] = text1
+                temp_list.append(bbox1)
+            if re.sub(' ', '', text2) != '':
+                bbox_text_dict[str(bbox2)] = text2
+                temp_list.append(bbox2)
+        # 正常的bbox
+        else:
+            bbox_text_dict[str(bbox_list[i])] = text_list[i]
+            temp_list.append(bbox_list[i])
+    bbox_list = temp_list
+
+    if show:
+        print('bbox_text_dict', bbox_text_dict)
+        for bbox in bbox_list:
+            cv2.rectangle(img_show, (int(bbox[0][0]), int(bbox[0][1])),
+                          (int(bbox[2][0]), int(bbox[2][1])), (255, 0, 0), 2)
+        cv2.imshow('bbox_list', img_show)
+        cv2.waitKey(0)
 
     # 根据bbox_list,计算与table_location左上角坐标距离,锁定第一个bbox
     table_left_up_point = [table_location[0], table_location[1]]
@@ -110,12 +161,13 @@ def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=0):
         temp_row.sort(key=lambda x: x[0][1])
         first_col_row.append(temp_row)
 
-    # print('len(first_row)', len(first_row))
-    # print('first_row', [bbox_text_dict.get(str(x)) for x in first_row])
-    # print('first_col', [bbox_text_dict.get(str(x)) for x in first_col])
-    # print('len(first_col)', len(first_col))
-    # print('len(first_row_col)', len(first_row_col))
-    # print('len(first_col_row)', len(first_col_row))
+    if show:
+        print('len(first_row)', len(first_row))
+        print('first_row', [bbox_text_dict.get(str(x)) for x in first_row])
+        print('first_col', [bbox_text_dict.get(str(x)) for x in first_col])
+        print('len(first_col)', len(first_col))
+        print('len(first_row_col)', len(first_row_col))
+        print('len(first_col_row)', len(first_col_row))
 
     # 划线 列
     col_line_list = []
@@ -148,9 +200,10 @@ def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=0):
         #     row_line_list.append([table_location[0], int((min_h+last_max_h)/2), table_location[2], int((min_h+last_max_h)/2)])
         last_max_h = max_h
 
-    # print('len(col_line_list)', len(col_line_list))
-    # print('col_line_list', col_line_list)
-    # print('len(row_line_list)', len(row_line_list))
+    if show:
+        print('len(col_line_list)', len(col_line_list))
+        print('col_line_list', col_line_list)
+        print('len(row_line_list)', len(row_line_list))
 
     # 判断列线有没有压在黑色像素上,若有则移动
     temp_list = []
@@ -170,69 +223,82 @@ def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=0):
         black_cnt2 = count_black(img[line2[1]:line2[3], line2[0]:line2[2]+1, :], threshold=black_threshold)
         # print('col black_cnt1', i, black_cnt1)
         # print('col black_cnt2', i, black_cnt2)
-        if False and black_cnt2 <= max_black_cnt and black_cnt1 <= max_black_cnt:
-            if black_cnt1 >= black_cnt2:
-                temp_list.append(line2)
-            else:
-                temp_list.append(line1)
-        elif False and black_cnt2 <= max_black_cnt:
-            temp_list.append(line2)
-        elif False and black_cnt1 <= max_black_cnt:
-            temp_list.append(line1)
+        # if black_cnt2 <= max_black_cnt and black_cnt1 <= max_black_cnt:
+        #     if black_cnt1 >= black_cnt2:
+        #         temp_list.append(line2)
+        #     else:
+        #         temp_list.append(line1)
+        # elif black_cnt2 <= max_black_cnt:
+        #     temp_list.append(line2)
+        # elif black_cnt1 <= max_black_cnt:
+        #     temp_list.append(line1)
         # 两条线都不符合
+        # else:
+        # 先找出最近的bbox,不能跨bbox
+        min_distance = 100000
+        min_dis_bbox = bbox_list[0]
+        # for bbox in bbox_list:
+        for bbox in first_col_row[0]:
+            if bbox[2][0] < line2[0]:
+                _dis = line2[0] - bbox[2][0]
+                if _dis < min_distance:
+                    min_distance = _dis
+                    min_dis_bbox = bbox
+        # 从右向左移寻找
+        right_left_index_list = []
+        right_left_cnt_list = []
+        find_flag = False
+        for j in range(line2[0], int(min_dis_bbox[2][0]), -1):
+            # 需连续3个像素列满足要求
+            if len(right_left_index_list) == 3:
+                find_flag = True
+                break
+            black_cnt = count_black(img[line1[1]:line1[3], j:j+1, :], threshold=black_threshold)
+            # print('col black_cnt', black_cnt)
+            right_left_cnt_list.append(black_cnt)
+            # 直接找到无黑色像素的
+            if black_cnt == 0:
+                right_left_index_list.append(j)
+            else:
+                right_left_index_list = []
+
+        if show:
+            print('find_flag', find_flag)
+
+        if find_flag:
+            temp_list.append([right_left_index_list[1], line2[1], right_left_index_list[1], line2[3]])
         else:
-            # 先找出最近的bbox,不能跨bbox
-            min_distance = 100000
-            min_dis_bbox = bbox_list[0]
-            for bbox in bbox_list:
-                if bbox[2][0] < line2[0]:
-                    _dis = line2[0] - bbox[2][0]
-                    if _dis < min_distance:
-                        min_distance = _dis
-                        min_dis_bbox = bbox
-
-            # 从右向左移寻找
-            right_left_index_list = []
-            right_left_cnt_list = []
-            find_flag = False
-            for j in range(line2[0], int(min_dis_bbox[2][0]), -1):
-                # 需连续3个像素列满足要求
-                if len(right_left_index_list) == 3:
-                    find_flag = True
+            # 为0的找不到,就找最小的
+            # 每个位置加上前后n位求平均
+            n = 1
+            min_cnt = 1000000.
+            min_cnt_index = 0
+            for j, cnt in enumerate(right_left_cnt_list):
+                if show:
+                    print('min_cnt', min_cnt)
+                if j < n or j > len(right_left_cnt_list) - 1 - n:
+                    continue
+                # 小到一定程度提前结束
+                if min_cnt <= 0.001:
                     break
-                black_cnt = count_black(img[line1[1]:line1[3], j:j+1, :], threshold=black_threshold)
-                # print('col black_cnt', black_cnt)
-                right_left_cnt_list.append(black_cnt)
-                # 直接找到无黑色像素的
-                if black_cnt == 0:
-                    right_left_index_list.append(j)
-                else:
-                    right_left_index_list = []
-            if find_flag:
-                temp_list.append([right_left_index_list[1], line2[1], right_left_index_list[1], line2[3]])
-            else:
-                # 为0的找不到,就找最小的
-                # 每个位置加上前后n位求平均
-                n = 1
-                min_cnt = 1000000.
-                min_cnt_index = 0
-                for j, cnt in enumerate(right_left_cnt_list):
-                    if j < n or j > len(right_left_cnt_list) - 1 - n:
-                        continue
-                    # 小到一定程度提前结束
-                    if min_cnt <= 0.001:
-                        break
-                    last_cnt = right_left_cnt_list[j-1]
-                    next_cnt = right_left_cnt_list[j+1]
-                    avg_cnt = (last_cnt + cnt + next_cnt) / 3
-                    if avg_cnt < min_cnt:
-                        min_cnt = avg_cnt
-                        min_cnt_index = j
-
-                min_cnt_index = line2[0] - min_cnt_index
-                temp_list.append([min_cnt_index, line2[1], min_cnt_index, line2[3]])
+                last_cnt = right_left_cnt_list[j-1]
+                next_cnt = right_left_cnt_list[j+1]
+                avg_cnt = (last_cnt + cnt + next_cnt) / 3
+                if avg_cnt < min_cnt:
+                    min_cnt = avg_cnt
+                    min_cnt_index = j
+            min_cnt_index = line2[0] - min_cnt_index
+            temp_list.append([min_cnt_index, line2[1], min_cnt_index, line2[3]])
     col_line_list = temp_list
 
+    if show:
+        print('len(col_line_list)', len(col_line_list))
+        for col in col_line_list:
+            col = [int(x) for x in col]
+            cv2.line(img_show, col[:2], col[2:4], (0, 255, 0), 2)
+        cv2.imshow('col_line_list', img_show)
+        cv2.waitKey(0)
+
     # 根据列的划线对bbox分列
     last_line = [0, 0, 0, 0]
     col_bbox_list = []
@@ -330,6 +396,14 @@ def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=0):
         temp_list.append(_line)
     row_line_list = temp_list
 
+    if show:
+        print('len(row_line_list)', len(row_line_list))
+        print('len(col_line_list)', len(col_line_list))
+
+    # 只有一行或一列的直接跳过
+    if len(row_line_list) < 1 or len(col_line_list) < 1:
+        return [], [], [], {}
+
     # 加上表格轮廓线
     threshold = 5
     min_w = max(table_location[0], 0+threshold)
@@ -341,6 +415,32 @@ def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=0):
     col_line_list.append([min_w, min_h, min_w, max_h])
     col_line_list.append([max_w, min_h, max_w, max_h])
 
+    # # 行线、列线两两之间没有bbox则合并
+    # col_line_list.sort(key=lambda x: x[0])
+    # temp_list = []
+    # used_bbox_list = []
+    # last_col = col_line_list[0]
+    # for col in col_line_list[1:]:
+    #     find_flag = False
+    #     for bbox in bbox_list:
+    #         if bbox in used_bbox_list:
+    #             continue
+    #         if last_col[0] <= (bbox[0][0] + bbox[2][0]) / 2 <= col[0]:
+    #             print('bbox', bbox, bbox_text_dict.get(str(bbox)))
+    #             used_bbox_list.append(bbox)
+    #             find_flag = True
+    #             break
+    #     print('last_col, col, find_flag', last_col, col, find_flag)
+    #     if not find_flag:
+    #         new_w = int((last_col[0] + col[0])/2)
+    #         temp_list.append([new_w, col[1], new_w, col[3]])
+    #     else:
+    #         temp_list.append(last_col)
+    #     last_col = col
+    # if find_flag:
+    #     temp_list.append(col_line_list[-1])
+    # col_line_list = temp_list
+
     # 由线得到按行列排列的bbox
     row_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in row_line_list]
     col_line_list = [[int(x[0]), int(x[1]), int(x[2]), int(x[3])] for x in col_line_list]
@@ -350,18 +450,19 @@ def get_table_by_rule(img, text_list, bbox_list, table_location, is_test=0):
     line_list = row_line_list + col_line_list
 
     # show
-    if is_test:
+    if show:
         for r in table_cell_list:
             for c in r:
-                cv2.rectangle(img, c[0], c[1], (0, 255, 0), 1)
-        cv2.imshow('table_cell', img)
+                cv2.rectangle(img_result, c[0], c[1], (0, 255, 0), 1)
+        cv2.namedWindow('table_cell', cv2.WINDOW_NORMAL)
+        cv2.imshow('table_cell', img_result)
 
         for line in col_line_list:
-            cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (0, 0, 255), 2)
+            cv2.line(img_result, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (0, 0, 255), 2)
         for line in row_line_list:
-            cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 2)
+            cv2.line(img_result, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 2)
         cv2.namedWindow('img', cv2.WINDOW_NORMAL)
-        cv2.imshow('img', cv2.resize(img, (768, 1024)))
+        cv2.imshow('img', cv2.resize(img_result, (768, 1024)))
         cv2.waitKey(0)
 
-    return line_list, table_cell_list, table_location
+    return line_list, table_cell_list, table_location, bbox_text_dict

+ 4 - 1
botr/utils.py

@@ -108,7 +108,10 @@ def get_table_iou(x1_min, y1_min, x1_max, y1_max, x2_min, y2_min, x2_max, y2_max
         union_area = min(area1, area2)
 
     # 计算IoU
-    iou = intersection_area / union_area
+    if int(union_area) == 0:
+        iou = 0
+    else:
+        iou = intersection_area / union_area
 
     return iou
 

+ 15 - 1
botr/yolov8/model.py

@@ -1,3 +1,4 @@
+import hashlib
 import logging
 import os
 from copy import deepcopy
@@ -158,7 +159,12 @@ class Predictor:
         """
         same_shapes = all(x.shape == im[0].shape for x in im)
         auto = same_shapes
-        return [LetterBox(self.imgsz, auto=auto, stride=self.stride)(image=x) for x in im]
+        img_list = [LetterBox(self.imgsz, auto=auto, stride=self.stride)(image=x) for x in im]
+        # for img in img_list:
+        #     print('LetterBox img.shape', img.shape)
+        #     cv2.imshow('LetterBox', img)
+        # cv2.waitKey(0)
+        return img_list
 
     def postprocess(self, preds, img, orig_imgs):
         """Postprocesses predictions and returns a list of Results objects."""
@@ -202,9 +208,17 @@ class Predictor:
         for batch in self.dataset:
             path, im0s, vid_cap, s = batch
 
+            # print('im0s', im0s[0].shape)
+            # _md5 = hashlib.md5(im0s[0])
+            # print('md5', _md5.hexdigest())
+            # cv2.imshow('im0s', im0s[0])
+            # cv2.waitKey(0)
+
             # Preprocess
             im = self.preprocess(im0s)
 
+            # print('im', im.shape)
+
             # Inference
             preds = self.model(im)
 

+ 1 - 1
botr/yolov8/predict.py

@@ -22,7 +22,7 @@ def detect(image_np, predictor=None, show=False):
 
 if __name__ == '__main__':
     p = r'C:\Users\Administrator\Desktop\test_b_table\real2.png'
-    paths = glob(r'C:\Users\Administrator\Desktop\test_b_table\*.png')
+    paths = glob(r'C:\Users\Administrator\Desktop\test_b_table\error10.png')
     for p in paths:
         img = cv2.imread(p)
         detect(img, show=True)

+ 27 - 38
botr/yolov8/yolo_interface.py

@@ -71,52 +71,41 @@ def yolo(data, predictor):
 
 def test_yolo_model(from_remote=True):
     _global._init()
-    from format_convert.convert_image import get_best_predict_size, image_process
-    if get_platform() == "Windows":
-        file_path = "C:/Users/Administrator/Desktop/error2.png"
-        file_path = "C:/Users/Administrator/Downloads/1652672734044.jpg"
-    else:
-        file_path = "1.jpg"
+    file_path = "C:/Users/Administrator/Desktop/test_b_table/yolo_error/error6.png"
+    # file_path = "C:/Users/Administrator/Desktop/test_b_table/error10.png"
+    # file_path = "C:/Users/Administrator/Downloads/1652672734044.jpg"
+
+    from format_convert.convert_image import get_best_predict_size
+    from format_convert.utils import np2bytes, pil_resize
+
     image_np = cv2.imread(file_path)
-    best_h, best_w = get_best_predict_size(image_np)
-    image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
-    cv2.imwrite(file_path, image_resize)
-    with open(file_path, "rb") as f:
-        file_bytes = f.read()
+    # best_h, best_w = get_best_predict_size(image_np)
+    # image_np = pil_resize(image_np, best_h, best_w)
+
+    file_bytes = np2bytes(image_np)
     file_base64 = base64.b64encode(file_bytes)
     _md5 = get_md5_from_bytes(file_bytes)[0]
 
     _global.update({"port": 15010, "md5": _md5})
 
-    if from_remote:
-        file_json = {"data": file_base64, "is_from_pdf": False, "md5": _md5}
+    file_json = {"data": file_base64, "md5": _md5}
+    # _url = "http://192.168.2.104:18080/yolo"
+    _url = "http://127.0.0.1:18080/yolo"
 
-        # _url = "http://192.168.2.104:18000/otr"
-        _url = "http://127.0.0.1:18000/otr"
-        r = json.loads(request_post(_url, file_json))
-    else:
-        # otr_model = OtrModels().get_model()
-        # r = otr(file_base64, otr_model, is_from_pdf=False)
-        r = image_process(image_resize, file_path)
+    r = json.loads(request_post(_url, file_json))
     print(r)
+    r = r.get('b_table_list')[0]
+    for bbox in r:
+        bbox = [int(x) for x in bbox]
+        cv2.rectangle(image_np, bbox[0:2], bbox[2:4], (0, 0, 255), 2)
+    cv2.namedWindow('img', cv2.WINDOW_NORMAL)
+    cv2.imshow('img', image_np)
+    cv2.waitKey(0)
 
 
 if __name__ == '__main__':
-    if len(sys.argv) == 2:
-        port = int(sys.argv[1])
-    elif len(sys.argv) == 3:
-        port = int(sys.argv[1])
-        using_gpu_index = int(sys.argv[2])
-    else:
-        port = 18080
-        using_gpu_index = 0
-    # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
-    app.run(host='0.0.0.0', port=port)
-    log("YOLO running "+str(port))
-
-    # test_yolo_model(False)
-
-    # print(json.dumps([-2]))
-
-    # otr_model = OtrModels().get_model()
-    # otr("11", otr_model)
+    # port = 18080
+    # using_gpu_index = 0
+    # app.run(host='0.0.0.0', port=port)
+
+    test_yolo_model(True)

+ 22 - 15
format_convert/convert.py

@@ -50,7 +50,7 @@ else:
 
 
 @memory_decorator
-def getText(_type, path_or_stream, time_out=300):
+def getText(_type, path_or_stream, _page_no, time_out=300):
     @timeout(time_out, timeout_exception=TimeoutError, use_signals=False)
     def get_html_1(_class):
         return _class.get_html()
@@ -69,8 +69,8 @@ def getText(_type, path_or_stream, time_out=300):
 
     if _type == "pdf":
         if MAX_COMPUTE:
-            return PDFConvert(path_or_stream, unique_type_dir).get_html()
-        return get_html_1(PDFConvert(path_or_stream, unique_type_dir))
+            return PDFConvert(path_or_stream, unique_type_dir, _page_no).get_html()
+        return get_html_1(PDFConvert(path_or_stream, unique_type_dir, _page_no))
     if _type == "docx":
         if MAX_COMPUTE:
             return DocxConvert(path_or_stream, unique_type_dir).get_html()
@@ -153,7 +153,7 @@ def remove_underline(image_np):
 
 # @timeout_decorator.timeout(100, timeout_exception=TimeoutError)
 # @timeout(globals().get("time_out"), timeout_exception=TimeoutError, use_signals=False)
-def unique_temp_file_process(stream, _type, _md5, time_out=300):
+def unique_temp_file_process(stream, _type, _md5, _page_no, time_out=300):
     if get_platform() == "Windows":
         _global._init()
 
@@ -182,7 +182,7 @@ def unique_temp_file_process(stream, _type, _md5, time_out=300):
         with open(file_path, "wb") as ff:
             ff.write(stream)
 
-        text = getText(_type, file_path, time_out=time_out)
+        text = getText(_type, file_path, _page_no, time_out=time_out)
 
         # 获取swf转换的图片
         swf_images = []
@@ -222,16 +222,19 @@ def unique_temp_file_process(stream, _type, _md5, time_out=300):
 def cut_str(text_list, only_text_list, max_bytes_length=2000000):
     log("into cut_str")
     try:
+        if max_bytes_length and str(max_bytes_length) == '-1':
+            max_bytes_length = 2000000000000
+        else:
+            max_bytes_length = 2000000
+
         # 计算有格式总字节数
         bytes_length = 0
         for text in text_list:
             bytes_length += len(bytes(text, encoding='utf-8'))
 
-        # print("text_list", bytes_length)
-
         # 小于直接返回
         if bytes_length < max_bytes_length:
-            print("return text_list no cut")
+            # print("return text_list no cut")
             return text_list
 
         # 全部文件连接,重新计算无格式字节数
@@ -241,7 +244,6 @@ def cut_str(text_list, only_text_list, max_bytes_length=2000000):
             bytes_length += len(bytes(text, encoding='utf-8'))
             all_text += text
 
-        # print("only_text_list", bytes_length)
         # 小于直接返回
         if bytes_length < max_bytes_length:
             print("return only_text_list no cut")
@@ -249,9 +251,6 @@ def cut_str(text_list, only_text_list, max_bytes_length=2000000):
 
         # 截取字符
         all_text = all_text[:int(max_bytes_length/3)]
-
-        # print("text bytes ", len(bytes(all_text, encoding='utf-8')))
-        # print("return only_text_list has cut")
         return [all_text]
     except Exception as e:
         log("cut_str " + str(e))
@@ -370,6 +369,7 @@ def _convert():
     {[-11], 0}: 新接口idc、isr、atc报错
     {[-12], 0}: 表格跨页连接报错
     {[-13], 0}: pdf表格线处理报错
+    {[-14], 0}: 指定页码报错
     :return: {"result_html": str([]), "result_text":str([]) "is_success": int}
     """
 
@@ -407,13 +407,20 @@ def _convert():
         _md5 = get_md5_from_bytes(stream)
         _md5 = _md5[0]
         _global.update({"md5": _md5})
+        # 指定页码范围
+        _page_no = data.get('page_no')
+        if _type not in ['pdf']:
+            _page_no = None
+
+        # 最终结果截取的最大字节数
+        max_bytes = data.get("max_bytes")
 
         if get_platform() == "Windows":
             # 解除超时装饰器,直接访问原函数
             # origin_unique_temp_file_process = unique_temp_file_process.__wrapped__
             # text, swf_images = origin_unique_temp_file_process(stream, _type)
             try:
-                text, swf_images = unique_temp_file_process(stream, _type, _md5)
+                text, swf_images = unique_temp_file_process(stream, _type, _md5, _page_no)
             except TimeoutError:
                 log("convert time out! 300 sec")
                 text = [-5]
@@ -421,7 +428,7 @@ def _convert():
         else:
             # Linux 通过装饰器设置整个转换超时时间
             try:
-                text, swf_images = unique_temp_file_process(stream, _type, _md5)
+                text, swf_images = unique_temp_file_process(stream, _type, _md5, _page_no)
             except TimeoutError:
                 log("convert time out! 300 sec")
                 text = [-5]
@@ -461,7 +468,7 @@ def _convert():
             classification = [str(classification[0])]
 
         # 判断长度,过长截取
-        text = cut_str(text, only_text)
+        text = cut_str(text, only_text, max_bytes)
         only_text = cut_str(only_text, only_text)
 
         if len(only_text) == 0:

+ 149 - 41
format_convert/convert_image.py

@@ -179,6 +179,23 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
                             [int(point[1][0]*ratio[0]), int(point[1][1]*ratio[1])],
                             [int(point[2][0]*ratio[0]), int(point[2][1]*ratio[1])],
                             [int(point[3][0]*ratio[0]), int(point[3][1]*ratio[1])]]
+
+        # 去除水印字 根据识别是否为矩形框
+        temp_text_list = []
+        temp_bbox_list = []
+        for i in range(len(bbox_list)):
+            bbox = bbox_list[i]
+            text = text_list[i]
+            if len(re.findall('[\u4e00-\u9fa5]', text)) == len(text):
+                if (abs(bbox[0][1] - bbox[1][1]) <= 2 and abs(bbox[2][1] - bbox[3][1]) <= 2) \
+                        or (abs(bbox[0][0] - bbox[3][0]) <= 4 and abs(bbox[2][0] - bbox[1][0]) <= 4):
+                    temp_text_list.append(text)
+                    temp_bbox_list.append(bbox)
+            else:
+                temp_text_list.append(text)
+                temp_bbox_list.append(bbox)
+        text_list = temp_text_list
+        bbox_list = temp_bbox_list
         return text_list, bbox_list
 
     def otr_process(_image_np):
@@ -237,7 +254,7 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
                                      [int(obj.bbox[2]), int(obj.bbox[1])],
                                      []
                                      ])
-                pdf_text_list.append(text)
+                pdf_text_list.append(re.sub('[\n]', '', obj.get_text()))
 
             pdf_text_box_list = get_text_box_obj(pdf_text_list, pdf_box_list)
 
@@ -246,24 +263,28 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
             text_box_list2 = pdf_text_box_list
 
         _text_box_list, _table_list, _obj_in_table_list = get_table(_image_np, table_list2, text_list2, box_list2, text_box_list2)
-        # print('_text_box_list', len(_text_box_list))
-        # print('_obj_in_table_list', len(_obj_in_table_list))
-        # print('text_box_list2', len(text_box_list2))
-        # print('obj_in_table_list2', len(obj_in_table_list2))
 
         # 保存无边框表格文件
         if _table_list:
-            save_b_table(_image_np, text_box_list2, from_pdf)
+            try:
+                save_b_table(_image_np, text_box_list2, from_pdf)
+            except:
+                pass
 
-        text_box_list2 += _text_box_list
-        text_box_list2 = list(set(text_box_list2))
-        table_list2 += _table_list
-        obj_in_table_list2 = obj_in_table_list2.union(_obj_in_table_list)
-        # print('text_box_list2', len(text_box_list2))
-        # print('obj_in_table_list2', len(obj_in_table_list2))
-        return text_box_list2, table_list2, obj_in_table_list2
+        # print('_text_box_list', _text_box_list)
+        # print('_table_list', _table_list)
+        if from_pdf:
+            text_box_list2 = []
+            table_list2 = []
 
-    def table_process(list_line, list_text_boxes):
+        if _table_list and _text_box_list:
+            text_box_list2 += _text_box_list
+            text_box_list2 = list(set(text_box_list2))
+            # table_list2 += _table_list
+            # obj_in_table_list2 = obj_in_table_list2.union(_obj_in_table_list)
+        return text_box_list2, _table_list, _obj_in_table_list
+
+    def table_process(list_line, list_text_boxes, _image_np):
         # 调用现成方法形成表格
         try:
             if list_line:
@@ -273,12 +294,19 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
                     list_lines.append(LTLine(1, (line[0], line[1]), (line[2], line[3])))
 
                 lt = LineTable()
-                tables, obj_in_table, _ = lt.recognize_table(list_text_boxes, list_lines, False)
+                tables, obj_in_table, _, connect_textbox_list = lt.recognize_table(list_text_boxes, list_lines,
+                                                                                   sourceP_LB=False, splited=False,
+                                                                                   from_pdf=is_from_pdf)
+                # 需分割textbox
+                if connect_textbox_list:
+                    list_text_boxes = table_textbox_split(_image_np, connect_textbox_list, list_text_boxes)
+                    # 新的textbox,重新做表格
+                    tables, obj_in_table, _, connect_textbox_list = lt.recognize_table(list_text_boxes, list_lines,
+                                                                                       sourceP_LB=False, splited=True,
+                                                                                       from_pdf=is_from_pdf)
+
                 if not tables:
                     return list_text_boxes, tables, obj_in_table
-
-                # 合并同一行textbox
-                # list_text_boxes = merge_textbox(list_text_boxes, obj_in_table)
                 return list_text_boxes, tables, obj_in_table
             else:
                 return list_text_boxes, [], set()
@@ -304,7 +332,7 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
         if os.path.exists(_path):
             file_list = glob(_path + '/*')
             if file_list:
-                file_index_list = [int(re.split('[/.\\\\-]', x)[-2]) for x in file_list]
+                file_index_list = [int(re.split('[/.\\\\-]', x)[-3]) for x in file_list]
                 file_index_list.sort(key=lambda x: x)
                 index = file_index_list[-1] + 1
             else:
@@ -316,19 +344,92 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
             from format_convert import _global
             _md5 = _global.get("md5")
 
-            _image_path = _path + '/' + str(_md5) + '-' + str(index) + '.png'
+            _image_path = _path + '/' + str(index) + '-' + str(_md5) + '.png'
             cv2.imwrite(_image_path, image_np2)
             log('save b_table image success!')
 
-            if from_pdf:
-                _file_path = _path + '/' + str(_md5) + '-' + str(index) + '.txt'
-                new_text_box_list2 = [str(x) + '\n' for x in text_box_list2]
-                with open(_file_path, 'w') as f:
-                    f.writelines(new_text_box_list2)
-                log('save b_table txt success!')
+            # if from_pdf:
+            #     _file_path = _path + '/' + str(_md5) + '-' + str(index) + '.txt'
+            #     new_text_box_list2 = [str(x) + '\n' for x in text_box_list2]
+            #     with open(_file_path, 'w') as f:
+            #         f.writelines(new_text_box_list2)
+            #     log('save b_table txt success!')
 
         log('save_b_table cost: ' + str(time.time()-_start_time))
 
+    def table_textbox_split(image_np2, connect_textbox_list, textbox_list):
+        """
+        两个单元格里的文本被ocr识别为一个,需分开才能准确放进表格
+
+        :return:
+        """
+        split_bbox_list = []
+        split_text_list = []
+        splited_textbox_list = []
+        for textbox in connect_textbox_list:
+            bbox = textbox.bbox
+            bbox = [[bbox[0], bbox[1]], [], [bbox[2], bbox[3]], []]
+            sub_image_np = image_np2[int(bbox[0][1]):int(bbox[2][1]), int(bbox[0][0]):int(bbox[2][0]), :]
+            split_index_list = []
+            # 从左到右遍历img
+            for i in range(5, sub_image_np.shape[1]-5):
+                # 找表格分割线,这一列都为黑色像素
+                if np.where(sub_image_np[:, i, 0] < 200)[0].size >= sub_image_np.shape[0]:
+                    split_index_list.append(i)
+
+            # 判断两线之间宽度,去重
+            if len(split_index_list) > 1:
+                last_index = split_index_list[0]
+                temp_list = []
+                delete_list = []
+                for index in split_index_list[1:]:
+                    if index in delete_list:
+                        continue
+                    if index - last_index <= 5:
+                        delete_list.append(index)
+                    else:
+                        last_index = index
+                    temp_list.append(last_index)
+                split_index_list = temp_list
+
+            # n条以上分割线,有问题
+            if len(split_index_list) == 0 or len(split_index_list) >= 2:
+                print('len(split_index_list)', len(split_index_list), split_index_list)
+                continue
+            else:
+                # 根据index拆开图片,重新ocr
+                split_index_list.insert(0, 0)
+                print('split_index_list1', split_index_list)
+                for _i, index in enumerate(split_index_list):
+                    if _i == len(split_index_list) - 1:
+                        split_image_np = sub_image_np[:, index:, :]
+                        split_bbox_list.append([[bbox[0][0]+index, bbox[0][1]], [], [bbox[2][0], bbox[2][1]], []])
+                    else:
+                        next_index = split_index_list[_i+1]
+                        split_image_np = sub_image_np[:, index:next_index, :]
+                        split_bbox_list.append([[bbox[0][0]+index, bbox[0][1]], [], [bbox[0][0]+next_index, bbox[2][1]], []])
+
+                    # ocr
+                    split_image_bytes = np2bytes(split_image_np)
+                    text_list2, bbox_list2 = from_ocr_interface(split_image_bytes, is_table=True, only_rec=True)
+                    print('text_list2', text_list2)
+                    print('bbox_list2', split_bbox_list)
+                    if judge_error_code(text_list2):
+                        text2 = ''
+                    else:
+                        text2 = text_list2[0]
+                    split_text_list.append(text2)
+                splited_textbox_list.append(textbox)
+
+        if split_text_list and split_bbox_list:
+            split_textbox_list = get_text_box_obj(split_text_list, split_bbox_list)
+            for tb in splited_textbox_list:
+                if tb in textbox_list:
+                    textbox_list.remove(tb)
+            textbox_list += split_textbox_list
+
+        return textbox_list
+
     log("into image_preprocess")
     try:
         if image_np is None:
@@ -380,7 +481,7 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
                     return text_list
 
                 # 判断ocr识别是否正确
-                if ocr_cant_read(text_list, box_list) and not idc_flag:
+                if ocr_cant_read(text_list, box_list) and not idc_flag and False:
                     # 方向分类
                     image_np = idc_process(image_np)
                     # cv2.imshow("idc_process", image_np)
@@ -408,20 +509,20 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
                 text_box_list = get_text_box_obj(text_list, box_list)
 
                 # 表格生成
-                text_box_list, table_list, obj_in_table_list = table_process(line_list, text_box_list)
+                text_box_list, table_list, obj_in_table_list = table_process(line_list, text_box_list, image_np)
                 if judge_error_code(table_list):
                     return table_list
 
                 # 无边框表格识别
                 start_time = time.time()
-                text_box_list, table_list, obj_in_table_list = botr_process(image_np, table_list,
-                                                                            text_list, box_list,
-                                                                            text_box_list,
-                                                                            obj_in_table_list,
-                                                                            b_table_from_text,
-                                                                            pdf_obj_list,
-                                                                            pdf_layout_size,
-                                                                            )
+                text_box_list, b_table_list, b_obj_in_table_list = botr_process(image_np, table_list,
+                                                                                text_list, box_list,
+                                                                                text_box_list,
+                                                                                obj_in_table_list,
+                                                                                b_table_from_text,
+                                                                                pdf_obj_list,
+                                                                                pdf_layout_size,
+                                                                                )
                 log('botr process cost: ' + str(time.time()-start_time))
 
                 # 合并非表格的同一行TextBox
@@ -430,12 +531,17 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
                 # 对象生成
                 obj_list = []
                 for table in table_list:
-                    obj_list.append(_Table(table["table"], table["bbox"]))
+                    _table = _Table(table["table"], table["bbox"])
+                    obj_list.append(_table)
+                for table in b_table_list:
+                    _table = _Table(table["table"], table["bbox"])
+                    obj_list.append(_table)
+                    _table.y += 10000
                 for text_box in text_box_list:
                     if text_box not in obj_in_table_list:
                         obj_list.append(_Sentence(text_box.get_text(), text_box.bbox))
 
-                # 修正y
+                # 多图修正y
                 if len(image_np_list) > 1:
                     list_y = []
                     for obj in obj_list:
@@ -464,7 +570,7 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
             text_box_list = get_text_box_obj(text_list, box_list)
 
             # 表格生成
-            text_box_list, table_list, obj_in_table_list = table_process(line_list, text_box_list)
+            text_box_list, table_list, obj_in_table_list = table_process(line_list, text_box_list, image_np)
             if judge_error_code(table_list):
                 return table_list
 
@@ -485,8 +591,10 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False,
 
             # 对象生成
             obj_list = []
+            # print('table_list', table_list)
             for table in table_list:
-                obj_list.append(_Table(table["table"], table["bbox"]))
+                _table = _Table(table["table"], table["bbox"])
+                obj_list.append(_table)
             for text_box in text_box_list:
                 if text_box not in obj_in_table_list:
                     obj_list.append(_Sentence(text_box.get_text(), text_box.bbox))
@@ -656,7 +764,7 @@ def image_slice_new(image_np):
 
         # 分割点距离不能太远
         cut_h = zero_h_after.tolist()[0]
-        print("cut_h", cut_h)
+        # print("cut_h", cut_h)
         if abs(h - cut_h) <= max_distance:
             image_list.append(image_origin[last_h:cut_h, :, :])
             last_h = cut_h

+ 17 - 15
format_convert/convert_need_interface.py

@@ -178,7 +178,7 @@ def from_office_interface(src_path, dest_path, target_format, retry_times=1, fro
         return [-1]
 
 
-def from_ocr_interface(image_stream, is_table=False, from_remote=FROM_REMOTE):
+def from_ocr_interface(image_stream, is_table=False, only_rec=False, from_remote=FROM_REMOTE):
     log("into from_ocr_interface")
     try:
         base64_stream = base64.b64encode(image_stream)
@@ -201,16 +201,18 @@ def from_ocr_interface(image_stream, is_table=False, from_remote=FROM_REMOTE):
                         return ip_port
                     _url = ip_port + "/ocr"
                     r = json.loads(request_post(_url, {"data": base64_stream,
-                                                       "md5": _global.get("md5")},
+                                                       "md5": _global.get("md5"),
+                                                       "only_rec": only_rec
+                                                       },
                                                 time_out=60))
                     log("get interface return")
                     if type(r) == list:
                         # 接口连不上换个端口重试
                         if retry_times_1 <= 1:
-                            if is_table:
-                                return r, r
-                            else:
-                                return r
+                            # if is_table:
+                            return r, r
+                            # else:
+                            #     return r
                         else:
                             retry_times_1 -= 1
                             log("retry post ocr_interface... left times " + str(retry_times_1))
@@ -222,17 +224,17 @@ def from_ocr_interface(image_stream, is_table=False, from_remote=FROM_REMOTE):
                 if globals().get("global_ocr_model") is None:
                     print("=========== init ocr model ===========")
                     globals().update({"global_ocr_model": OcrModels().get_model()})
-                r = ocr(data=base64_stream, ocr_model=globals().get("global_ocr_model"))
+                r = ocr(data=base64_stream, ocr_model=globals().get("global_ocr_model"), only_rec=only_rec)
         except TimeoutError:
-            if is_table:
-                return [-5], [-5]
-            else:
-                return [-5]
+            # if is_table:
+            return [-5], [-5]
+            # else:
+            #     return [-5]
         except requests.exceptions.ConnectionError as e:
-            if is_table:
-                return [-2], [-2]
-            else:
-                return [-2]
+            # if is_table:
+            return [-2], [-2]
+            # else:
+            #     return [-2]
 
         _dict = r
         text_list = eval(_dict.get("text"))

+ 97 - 72
format_convert/convert_pdf.py

@@ -657,13 +657,35 @@ def read_pdfplumber(path, laparams):
 
 
 class PDFConvert:
-    def __init__(self, path, unique_type_dir):
+    def __init__(self, path, unique_type_dir, need_page_no):
         self._doc = _Document(path)
         self.path = path
         self.unique_type_dir = unique_type_dir
         if not os.path.exists(self.unique_type_dir):
             os.mkdir(self.unique_type_dir)
 
+        # 指定提取的页码范围
+        self.need_page_no = need_page_no
+        self.start_page_no = None
+        self.end_page_no = None
+        # 默认使用limit_page_cnt控制,前10页后10页
+        if self.need_page_no is None:
+            self.limit_page_cnt = 20
+        else:
+            # 使用start_page_no,end_page_no范围控制,例如2,5
+            ss = self.need_page_no.split(',')
+            if len(ss) != 2:
+                self._doc.error_code = [-14]
+            else:
+                self.start_page_no = int(ss[0])
+                self.end_page_no = int(ss[-1])
+                if self.end_page_no == -1:
+                    self.end_page_no = 1000000
+                self.start_page_no -= 1
+                self.end_page_no -= 1
+                if self.end_page_no <= self.start_page_no or self.start_page_no < 0 or self.end_page_no < -1:
+                    self._doc.error_code = [-14]
+
         self.packages = ["pdfminer", "PyMuPDF", "PyPDF2", "pdfplumber"]
         self.has_init_pdf = [0] * len(self.packages)
 
@@ -716,7 +738,7 @@ class PDFConvert:
             traceback.print_exc()
             self._doc.error_code = [-3]
 
-    def convert(self):
+    def convert(self, limit_page_cnt=20):
         if self.has_init_pdf[0] == 0:
             self.init_package("pdfminer")
         if self._doc.error_code is not None:
@@ -751,18 +773,23 @@ class PDFConvert:
         page_count = len(pages)
         page_no = 0
         for page in pages:
+            # 指定pdf页码
+            if self.start_page_no is not None and self.end_page_no is not None:
+                if page_count < self.end_page_no:
+                    self.end_page_no = page_count
+                if page_no < self.start_page_no or page_no >= self.end_page_no:
+                    page_no += 1
+                    continue
             # 限制pdf页数,只取前后各10页
-            if page_count > 20:
-                if 10 <= page_no < page_count - 10:
+            else:
+                if page_count > limit_page_cnt and int(limit_page_cnt/2) <= page_no < page_count - int(limit_page_cnt/2):
                     page_no += 1
                     continue
 
-            self._page = _Page(page, page_no)
             # 解析单页
+            self._page = _Page(page, page_no)
             self.convert_page(page, page_no)
 
-            # print("+"*30, page.resources)
-
             if self._doc.error_code is None and self._page.error_code is not None:
                 if self._page.error_code[0] in [-4, -3, 0]:
                     page_no += 1
@@ -1164,7 +1191,7 @@ class PDFConvert:
         return lt_line_list
 
     def recognize_text(self, layout, page_no, lt_text_list, lt_line_list):
-        list_tables, filter_objs, _ = self.lt.recognize_table(lt_text_list, lt_line_list)
+        list_tables, filter_objs, _, connect_textbox_list = self.lt.recognize_table(lt_text_list, lt_line_list)
         self._page.in_table_objs = filter_objs
 
         # print("=======text_len:%d:filter_len:%d"%(len(lt_text_list),len(filter_objs)))
@@ -1270,14 +1297,6 @@ class PDFConvert:
         return is_b_table_flag
 
     def convert_page(self, page, page_no):
-        # pdf page.annots为None,不经过get_layout,直接ocr
-        # if page.annots is None:
-        #     lt_image_list = []
-        #     lt_text_list = []
-        #     # 设置只有图片,可跳到ocr
-        #     only_image = 1
-        #     image_count = 1
-        # else:
         layout = self.get_layout(page, page_no)
         if self._doc.error_code is not None:
             return
@@ -1286,75 +1305,78 @@ class PDFConvert:
             return
 
         # 判断该页的对象类型,并存储
-        # only_image = 1
-        # image_count = 0
         lt_text_list = []
         lt_image_list = []
         for x in layout:
             if isinstance(x, (LTTextBoxHorizontal, LTTextBoxVertical)):
-                # only_image = 0
                 lt_text_list.append(x)
             if isinstance(x, LTFigure):
                 for y in x:
                     if isinstance(y, LTImage):
+                        # 小的图忽略
+                        if y.width <= 300 and y.height <= 300:
+                            continue
+                        # 图的width超过layout width,很大可能是水印
+                        if y.width > layout.width + 20:
+                            continue
                         lt_image_list.append(y)
-                        # image_count += 1
         lt_text_list = self.delete_water_mark(lt_text_list, layout.bbox, 15)
         log("convert_pdf page " + str(page_no))
         log("len(lt_image_list), len(lt_text_list) " + str(len(lt_image_list)) + " " + str(len(lt_text_list)))
+        log('layout.width, layout.height' + str(layout.width) + str(layout.height))
 
         # 若只有文本且图片数为0,直接提取文字及表格
         # if only_image == 0 and image_count == 0:
-        if len(lt_image_list) == 0 and len(lt_text_list) > 0:
-            # PDFPlumber
-            if self.has_init_pdf[3] == 0:
-                self.init_package("pdfplumber")
-            if self._doc.error_code is not None:
-                self._doc.error_code = None
-                log("init pdfplumber failed! try pymupdf...")
-                # 调用pdfplumber获取pdf图片报错,则使用pypdf2将pdf转html
-                page_image = self.get_page_image(page_no)
-                if judge_error_code(page_image):
-                    self._page.error_code = page_image
-                else:
-                    _image = _Image(page_image[1], page_image[0])
-                    self._page.add_child(_image)
-                return
-
-            if not self.is_text_legal(lt_text_list, page_no):
-                return
-
-            # 根据text规律,判断该页是否可能有无边框表格
-            start_time = time.time()
-            if self.judge_b_table(lt_text_list):
-                page_image = self.get_page_image(page_no)
-                if judge_error_code(page_image):
-                    self._page.error_code = page_image
-                else:
-                    _image = _Image(page_image[1], page_image[0])
-                    _image.is_from_pdf = True
-                    _image.b_table_from_text = True
-                    _image.b_table_text_obj_list = lt_text_list
-                    _image.b_table_layout_size = (layout.width, layout.height)
-                    self._page.add_child(_image)
-                log('convert_pdf judge_b_table set image cost: ' + str(time.time()-start_time))
-
-            try:
-                lt_line_list = self.get_page_lines(layout, page_no)
-            except:
-                traceback.print_exc()
-                lt_line_list = []
-                self._page.error_code = [-13]
-            try:
-                # lt_line_list = self.get_text_lines(page,page_no)
-                self.recognize_text(layout, page_no, lt_text_list, lt_line_list)
-            except:
-                traceback.print_exc()
-                self._page.error_code = [-8]
+        # if len(lt_image_list) == 0 and len(lt_text_list) > 0:
+        #     # PDFPlumber
+        #     if self.has_init_pdf[3] == 0:
+        #         self.init_package("pdfplumber")
+        #     if self._doc.error_code is not None:
+        #         self._doc.error_code = None
+        #         log("init pdfplumber failed! try pymupdf...")
+        #         # 调用pdfplumber获取pdf图片报错,则使用pypdf2将pdf转html
+        #         page_image = self.get_page_image(page_no)
+        #         if judge_error_code(page_image):
+        #             self._page.error_code = page_image
+        #         else:
+        #             _image = _Image(page_image[1], page_image[0])
+        #             self._page.add_child(_image)
+        #         return
+        #
+        #     if not self.is_text_legal(lt_text_list, page_no):
+        #         return
+        #
+        #     # 根据text规律,判断该页是否可能有无边框表格
+        #     start_time = time.time()
+        #     if self.judge_b_table(lt_text_list):
+        #         page_image = self.get_page_image(page_no)
+        #         if judge_error_code(page_image):
+        #             self._page.error_code = page_image
+        #         else:
+        #             _image = _Image(page_image[1], page_image[0])
+        #             _image.is_from_pdf = True
+        #             _image.b_table_from_text = True
+        #             _image.b_table_text_obj_list = lt_text_list
+        #             _image.b_table_layout_size = (layout.width, layout.height)
+        #             self._page.add_child(_image)
+        #         log('convert_pdf judge_b_table set image cost: ' + str(time.time()-start_time))
+        #
+        #     try:
+        #         lt_line_list = self.get_page_lines(layout, page_no)
+        #     except:
+        #         traceback.print_exc()
+        #         lt_line_list = []
+        #         self._page.error_code = [-13]
+        #     try:
+        #         # lt_line_list = self.get_text_lines(page,page_no)
+        #         self.recognize_text(layout, page_no, lt_text_list, lt_line_list)
+        #     except:
+        #         traceback.print_exc()
+        #         self._page.error_code = [-8]
 
         # 若该页图片数量过多,或无文本,则直接ocr整页识别
         # elif image_count > 3 or only_image == 1:
-        elif len(lt_image_list) > 3 or len(lt_text_list) == 0:
+        if len(lt_image_list) > 3 or len(lt_text_list) == 0:
             page_image = self.get_page_image(page_no)
             if judge_error_code(page_image):
                 self._page.error_code = page_image
@@ -1365,20 +1387,17 @@ class PDFConvert:
 
         # 正常读取该页对象
         else:
-
             # 图表对象
             for image in lt_image_list:
                 try:
-                    # print("pdf2text LTImage size", page_no, image.width, image.height)
+                    print("pdf2text LTImage size", page_no, image.width, image.height)
                     image_stream = image.stream.get_data()
                     # 小的图忽略
                     if image.width <= 300 and image.height <= 300:
                         continue
                     # 查看提取的图片高宽,太大则用pdf输出图进行ocr识别
                     img_test = Image.open(io.BytesIO(image_stream))
-                    # img_test.show()
                     if image.height >= 1000 and image.width >= 1000:
-                        # print("pdf2text LTImage stream output size", img_test.size)
                         page_image = self.get_page_image(page_no)
                         if judge_error_code(page_image):
                             self._page.error_code = page_image
@@ -1415,9 +1434,11 @@ class PDFConvert:
                 else:
                     _image = _Image(page_image[1], page_image[0])
                     _image.is_from_pdf = True
+                    _image.b_table_from_text = True
+                    _image.b_table_text_obj_list = lt_text_list
+                    _image.b_table_layout_size = (layout.width, layout.height)
                     self._page.add_child(_image)
 
-            # lt_line_list = self.get_text_lines(page, page_no)
             try:
                 lt_line_list = self.get_page_lines(layout, page_no)
             except:
@@ -1635,7 +1656,11 @@ class PDFConvert:
         return [html_str]
 
     def get_html(self):
+        if self._doc.error_code is not None:
+            return self._doc.error_code
+
         self.convert()
+
         if self._doc.error_code is not None:
             return self._doc.error_code
         html = self._doc.get_html(return_list=True)

+ 8 - 7
format_convert/convert_test.py

@@ -14,7 +14,7 @@ from format_convert.utils import get_platform, request_post, get_md5_from_bytes
 from format_convert.convert import to_html
 
 
-def test_one(p, from_remote=False):
+def test_one(p, page_no_range=None, from_remote=False):
     start_time = time.time()
     with open(p, "rb") as f:
         file_bytes = f.read()
@@ -22,13 +22,13 @@ def test_one(p, from_remote=False):
 
     _md5 = get_md5_from_bytes(file_bytes)
 
-    data = {"file": file_base64, "type": p.split(".")[-1], "filemd5": 100}
+    data = {"file": file_base64, "type": p.split(".")[-1], "filemd5": 100, 'page_no': page_no_range}
     if from_remote:
-        _url = 'http://121.46.18.113:15010/convert'
+        # _url = 'http://121.46.18.113:15010/convert'
         # _url = 'http://192.168.2.103:15010/convert'
         # _url = 'http://192.168.2.102:15011/convert'
         # _url = 'http://172.16.160.65:15010/convert'
-        # _url = 'http://127.0.0.1:15010/convert'
+        _url = 'http://127.0.0.1:15010/convert'
         result = json.loads(request_post(_url, data, time_out=10000))
         text_str = ""
         for t in result.get("result_html"):
@@ -59,18 +59,19 @@ def test_duplicate(path_list, process_no=None):
 
 if __name__ == '__main__':
     if get_platform() == "Windows":
+        # file_path = "C:/Users/Administrator/Desktop/2.png"
         # file_path = "C:/Users/Administrator/Desktop/test_xls/merge_cell.xlsx"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_Interface/20210609202634853485.xlsx"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_ODPS/1624325845476.pdf"
         # file_path = "C:/Users/Administrator/Downloads/20210508190133924ba.pdf"
         # file_path = "C:/Users/Administrator/Desktop/test_doc/error8.doc"
         # file_path = "C:/Users/Administrator/Desktop/test_image/error10.png"
-        file_path = "C:/Users/Administrator/Desktop/test_b_table/error1.png"
-        # file_path = "C:/Users/Administrator/Desktop/test_pdf/普通error/error3.pdf"
+        # file_path = "C:/Users/Administrator/Desktop/test_b_table/error1.png"
+        file_path = "C:/Users/Administrator/Desktop/test_pdf/error1.pdf"
         # file_path = "C:/save_b_table/0-0895e32470613dd7be1139eefd1342c4.png"
     else:
         file_path = "1660296734009.pdf"
-    test_one(file_path, from_remote=True)
+    test_one(file_path, page_no_range='13,14', from_remote=True)
 
     # paths = glob("C:/Users/Administrator/Desktop/test_image/*")
     # for file_path in paths:

+ 235 - 214
format_convert/utils.py

@@ -38,7 +38,7 @@ if get_platform() == "Linux":
 import math
 
 
-def judge_error_code(_list, code=[0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13]):
+def judge_error_code(_list, code=[0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14]):
     """
     [0] : continue
     [-1]: 逻辑处理错误
@@ -54,6 +54,7 @@ def judge_error_code(_list, code=[0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -1
     [-11]: 新接口idc、isr、atc报错
     [-12]: 表格跨页连接报错
     [-13]: pdf表格线处理报错
+    [-14]: 指定页码报错
     """
     for c in code:
         if isinstance(_list, list) and _list == [c]:
@@ -334,9 +335,11 @@ def slash_replace(_str, reverse=False):
 
 
 class LineTable:
-    def recognize_table(self, list_textbox, list_line, sourceP_LB=True):
+    def recognize_table(self, list_textbox, list_line, sourceP_LB=True, splited=False, from_pdf=False):
         self.list_line = list_line
         self.list_crosspoints = self.recognize_crosspoints(list_line)
+        self.from_pdf = from_pdf
+        self.splited = splited
 
         # 聚类
         cluster_crosspoints = []
@@ -382,73 +385,75 @@ class LineTable:
         list_tables = []
         for l_rect in list_l_rect:
             _ta = self.rect2table(list_textbox, l_rect, in_objs, sourceP_LB=sourceP_LB)
+            if self.connect_bbox_list:
+                return [], [], [], self.connect_bbox_list
             if _ta:
                 list_tables.append(_ta)
         # 展示表格及文字
         # self._plot(list_line, list_textbox)
-        return list_tables, in_objs, list_l_rect
+        return list_tables, in_objs, list_l_rect, []
 
-    def recognize_table_by_rect(self, list_textbox, list_rect, margin=2):
-
-        dump_margin = 5
-        list_rect_tmp = []
-        # 去重
-        for _rect in list_rect:
-            if (_rect.bbox[3] - _rect.bbox[1] < 10) or (abs(_rect.bbox[2] - _rect.bbox[0]) < 5):
-                continue
-            _find = False
-            for _tmp in list_rect_tmp:
-                for i in range(4):
-                    if abs(_rect.bbox[i] - _tmp.bbox[i]) < dump_margin:
-                        pass
-                    else:
-                        _find = False
-                        break
-                    if i == 3:
-                        _find = True
-                if _find:
-                    break
-            if not _find:
-                list_rect_tmp.append(_rect)
-
-        # print("=====",len(list_rect),len(list_rect_tmp))
-        # print(list_rect_tmp)
-        # from matplotlib import pyplot as plt
-        # plt.figure()
-        # for _rect in list_rect_tmp:
-        #     x0,y0,x1,y1 = _rect.bbox
-        #     plt.boxplot(_rect.bbox)
-        # plt.show()
-
-        cluster_rect = []
-        for _rect in list_rect:
-            _find = False
-            for cr in cluster_rect:
-                for cr_rect in cr:
-                    if abs((cr_rect.bbox[2] - cr_rect.bbox[0] + _rect.bbox[2] - _rect.bbox[0]) - (
-                            max(cr_rect.bbox[2], _rect.bbox[2]) - min(cr_rect.bbox[0], _rect.bbox[0]))) < margin:
-                        _find = True
-                        cr.append(_rect)
-                        break
-                    elif abs((cr_rect.bbox[3] - cr_rect.bbox[1] + _rect.bbox[3] - _rect.bbox[1]) - (
-                            max(cr_rect.bbox[3], _rect.bbox[3]) - min(cr_rect.bbox[1], _rect.bbox[1]))) < margin:
-                        _find = True
-                        cr.append(_rect)
-                        break
-                if _find:
-                    break
-            if not _find:
-                cluster_rect.append([_rect])
-
-        list_l_rect = cluster_rect
-
-        in_objs = set()
-        list_tables = []
-        for l_rect in list_l_rect:
-            _ta = self.rect2table(list_textbox, l_rect, in_objs)
-            if _ta:
-                list_tables.append(_ta)
-        return list_tables, in_objs, list_l_rect
+    # def recognize_table_by_rect(self, list_textbox, list_rect, margin=2):
+    #
+    #     dump_margin = 5
+    #     list_rect_tmp = []
+    #     # 去重
+    #     for _rect in list_rect:
+    #         if (_rect.bbox[3] - _rect.bbox[1] < 10) or (abs(_rect.bbox[2] - _rect.bbox[0]) < 5):
+    #             continue
+    #         _find = False
+    #         for _tmp in list_rect_tmp:
+    #             for i in range(4):
+    #                 if abs(_rect.bbox[i] - _tmp.bbox[i]) < dump_margin:
+    #                     pass
+    #                 else:
+    #                     _find = False
+    #                     break
+    #                 if i == 3:
+    #                     _find = True
+    #             if _find:
+    #                 break
+    #         if not _find:
+    #             list_rect_tmp.append(_rect)
+    #
+    #     # print("=====",len(list_rect),len(list_rect_tmp))
+    #     # print(list_rect_tmp)
+    #     # from matplotlib import pyplot as plt
+    #     # plt.figure()
+    #     # for _rect in list_rect_tmp:
+    #     #     x0,y0,x1,y1 = _rect.bbox
+    #     #     plt.boxplot(_rect.bbox)
+    #     # plt.show()
+    #
+    #     cluster_rect = []
+    #     for _rect in list_rect:
+    #         _find = False
+    #         for cr in cluster_rect:
+    #             for cr_rect in cr:
+    #                 if abs((cr_rect.bbox[2] - cr_rect.bbox[0] + _rect.bbox[2] - _rect.bbox[0]) - (
+    #                         max(cr_rect.bbox[2], _rect.bbox[2]) - min(cr_rect.bbox[0], _rect.bbox[0]))) < margin:
+    #                     _find = True
+    #                     cr.append(_rect)
+    #                     break
+    #                 elif abs((cr_rect.bbox[3] - cr_rect.bbox[1] + _rect.bbox[3] - _rect.bbox[1]) - (
+    #                         max(cr_rect.bbox[3], _rect.bbox[3]) - min(cr_rect.bbox[1], _rect.bbox[1]))) < margin:
+    #                     _find = True
+    #                     cr.append(_rect)
+    #                     break
+    #             if _find:
+    #                 break
+    #         if not _find:
+    #             cluster_rect.append([_rect])
+    #
+    #     list_l_rect = cluster_rect
+    #
+    #     in_objs = set()
+    #     list_tables = []
+    #     for l_rect in list_l_rect:
+    #         _ta = self.rect2table(list_textbox, l_rect, in_objs)
+    #         if _ta:
+    #             list_tables.append(_ta)
+    #     return list_tables, in_objs, list_l_rect
 
     def recognize_crosspoints(self, list_line, fixLine=True):
 
@@ -577,43 +582,43 @@ class LineTable:
         # print("points num",len(list_crosspoints))
         return list_crosspoints
 
-    def recognize_rect(self, _page):
-        list_line = []
-        for _obj in _page._objs:
-            if isinstance(_obj, (LTLine)):
-                list_line.append(_obj)
-        list_crosspoints = self.recognize_crosspoints(list_line)
-
-        # 聚类
-        cluster_crosspoints = []
-        for _point in list_crosspoints:
-            cluster_crosspoints.append({"lines": _point.get("lines"), "points": [_point]})
-        while 1:
-            _find = False
-            new_cluster_crosspoints = []
-            for l_point in cluster_crosspoints:
-                _flag = False
-                for l_n_point in new_cluster_crosspoints:
-                    line1 = l_point.get("lines")
-                    line2 = l_n_point.get("lines")
-                    if len(line1 & line2) > 0:
-                        _find = True
-                        _flag = True
-                        l_n_point["lines"] = line1.union(line2)
-                        l_n_point["points"].extend(l_point["points"])
-                if not _flag:
-                    new_cluster_crosspoints.append({"lines": l_point.get("lines"), "points": l_point.get("points")})
-            cluster_crosspoints = new_cluster_crosspoints
-            if not _find:
-                break
-        # print(len(cluster_crosspoints))
-
-        list_l_rect = []
-        for table_crosspoint in cluster_crosspoints:
-            list_rect = self.crosspoint2rect(table_crosspoint.get("points"))
-            list_l_rect.append(list_rect)
-
-        return list_l_rect
+    # def recognize_rect(self, _page):
+    #     list_line = []
+    #     for _obj in _page._objs:
+    #         if isinstance(_obj, (LTLine)):
+    #             list_line.append(_obj)
+    #     list_crosspoints = self.recognize_crosspoints(list_line)
+    #
+    #     # 聚类
+    #     cluster_crosspoints = []
+    #     for _point in list_crosspoints:
+    #         cluster_crosspoints.append({"lines": _point.get("lines"), "points": [_point]})
+    #     while 1:
+    #         _find = False
+    #         new_cluster_crosspoints = []
+    #         for l_point in cluster_crosspoints:
+    #             _flag = False
+    #             for l_n_point in new_cluster_crosspoints:
+    #                 line1 = l_point.get("lines")
+    #                 line2 = l_n_point.get("lines")
+    #                 if len(line1 & line2) > 0:
+    #                     _find = True
+    #                     _flag = True
+    #                     l_n_point["lines"] = line1.union(line2)
+    #                     l_n_point["points"].extend(l_point["points"])
+    #             if not _flag:
+    #                 new_cluster_crosspoints.append({"lines": l_point.get("lines"), "points": l_point.get("points")})
+    #         cluster_crosspoints = new_cluster_crosspoints
+    #         if not _find:
+    #             break
+    #     # print(len(cluster_crosspoints))
+    #
+    #     list_l_rect = []
+    #     for table_crosspoint in cluster_crosspoints:
+    #         list_rect = self.crosspoint2rect(table_crosspoint.get("points"))
+    #         list_l_rect.append(list_rect)
+    #
+    #     return list_l_rect
 
     def crosspoint2rect(self, list_crosspoint, margin=10):
 
@@ -682,6 +687,10 @@ class LineTable:
                 set_bbox.add(_bbox)
         list_rect = tmp_rect
 
+        # _l = [x.get('point') for x in list_crosspoint]
+        # _l.sort(key=lambda x: (x[0], x[1]))
+        # print('list_crosspoint', _l)
+        # print('list_rect', list_rect)
         # import cv2
         # import numpy as np
         # import random
@@ -761,113 +770,113 @@ class LineTable:
                                 "top": top, "buttom": buttom, "lines": set([line1_key, line2_key]),
                                 "p_lines": [line1, line2]}
 
-    def unionTable(self, list_table, fixspan=True, margin=2):
-        set_x = set()
-        set_y = set()
-
-        list_cell = []
-        for _t in list_table:
-            for _line in _t:
-                list_cell.extend(_line)
-
-        clusters_rects = []
-        # 根据y1聚类
-        set_id = set()
-        list_cell_dump = []
-        for _cell in list_cell:
-            _id = id(_cell)
-            if _id in set_id:
-                continue
-            set_id.add(_id)
-            list_cell_dump.append(_cell)
-        list_cell = list_cell_dump
-        list_cell.sort(key=lambda x: x.get("bbox")[3])
-        for _rect in list_cell:
-            _y0 = _rect.get("bbox")[3]
-            _find = False
-            for l_cr in clusters_rects:
-                if abs(l_cr[0].get("bbox")[3] - _y0) < 2:
-                    _find = True
-                    l_cr.append(_rect)
-                    break
-            if not _find:
-                clusters_rects.append([_rect])
-
-        clusters_rects.sort(key=lambda x: x[0].get("bbox")[3], reverse=True)
-        for l_cr in clusters_rects:
-            l_cr.sort(key=lambda x: x.get("bbox")[0])
-
-        # print("=============:")
-        # for l_r in clusters_rects:
-        #     print(len(l_r))
-
-        for _line in clusters_rects:
-            for _rect in _line:
-                (x0, y0, x1, y1) = _rect.get("bbox")
-                set_x.add(x0)
-                set_x.add(x1)
-                set_y.add(y0)
-                set_y.add(y1)
-        if len(set_x) == 0 or len(set_y) == 0:
-            return
-        list_x = list(set_x)
-        list_y = list(set_y)
-
-        list_x.sort(key=lambda x: x)
-        list_y.sort(key=lambda x: x, reverse=True)
-        _table = []
-        line_i = 0
-        for _line in clusters_rects:
-
-            table_line = []
-            cell_i = 0
-            for _rect in _line:
-                (x0, y0, x1, y1) = _rect.get("bbox")
-                _cell = {"bbox": (x0, y0, x1, y1), "rect": _rect.get("rect"),
-                         "rowspan": self.getspan(list_y, y0, y1, margin),
-                         "columnspan": self.getspan(list_x, x0, x1, margin), "text": _rect.get("text", "")}
-                table_line.append(_cell)
-
-                cell_i += 1
-            line_i += 1
-            _table.append(table_line)
-
-        # print("=====================>>")
-        # for _line in _table:
-        #     for _cell in _line:
-        #         print(_cell,end="\t")
-        #     print("\n")
-        # print("=====================>>")
-
-        # print(_table)
-        if fixspan:
-            for _line in _table:
-                extend_line = []
-                for c_i in range(len(_line)):
-                    _cell = _line[c_i]
-                    if _cell.get("columnspan") > 1:
-                        _cospan = _cell.get("columnspan")
-                        _cell["columnspan"] = 1
-                        for i in range(1, _cospan):
-                            extend_line.append({"index": c_i + 1, "cell": _cell})
-                extend_line.sort(key=lambda x: x["index"], reverse=True)
-                for _el in extend_line:
-                    _line.insert(_el["index"], _el["cell"])
-            for l_i in range(len(_table)):
-                _line = _table[l_i]
-                for c_i in range(len(_line)):
-                    _cell = _line[c_i]
-                    if _cell.get("rowspan") > 1:
-                        _rospan = _cell.get("rowspan")
-                        _cell["rowspan"] = 1
-                        for i in range(1, _rospan):
-                            _table[l_i + i].insert(c_i, _cell)
-
-        table_bbox = (_table[0][0].get("bbox")[0], _table[0][0].get("bbox")[1], _table[-1][-1].get("bbox")[2],
-                      _table[-1][-1].get("bbox")[3])
-
-        ta = {"bbox": table_bbox, "table": _table}
-        return ta
+    # def unionTable(self, list_table, fixspan=True, margin=2):
+    #     set_x = set()
+    #     set_y = set()
+    #
+    #     list_cell = []
+    #     for _t in list_table:
+    #         for _line in _t:
+    #             list_cell.extend(_line)
+    #
+    #     clusters_rects = []
+    #     # 根据y1聚类
+    #     set_id = set()
+    #     list_cell_dump = []
+    #     for _cell in list_cell:
+    #         _id = id(_cell)
+    #         if _id in set_id:
+    #             continue
+    #         set_id.add(_id)
+    #         list_cell_dump.append(_cell)
+    #     list_cell = list_cell_dump
+    #     list_cell.sort(key=lambda x: x.get("bbox")[3])
+    #     for _rect in list_cell:
+    #         _y0 = _rect.get("bbox")[3]
+    #         _find = False
+    #         for l_cr in clusters_rects:
+    #             if abs(l_cr[0].get("bbox")[3] - _y0) < 2:
+    #                 _find = True
+    #                 l_cr.append(_rect)
+    #                 break
+    #         if not _find:
+    #             clusters_rects.append([_rect])
+    #
+    #     clusters_rects.sort(key=lambda x: x[0].get("bbox")[3], reverse=True)
+    #     for l_cr in clusters_rects:
+    #         l_cr.sort(key=lambda x: x.get("bbox")[0])
+    #
+    #     # print("=============:")
+    #     # for l_r in clusters_rects:
+    #     #     print(len(l_r))
+    #
+    #     for _line in clusters_rects:
+    #         for _rect in _line:
+    #             (x0, y0, x1, y1) = _rect.get("bbox")
+    #             set_x.add(x0)
+    #             set_x.add(x1)
+    #             set_y.add(y0)
+    #             set_y.add(y1)
+    #     if len(set_x) == 0 or len(set_y) == 0:
+    #         return
+    #     list_x = list(set_x)
+    #     list_y = list(set_y)
+    #
+    #     list_x.sort(key=lambda x: x)
+    #     list_y.sort(key=lambda x: x, reverse=True)
+    #     _table = []
+    #     line_i = 0
+    #     for _line in clusters_rects:
+    #
+    #         table_line = []
+    #         cell_i = 0
+    #         for _rect in _line:
+    #             (x0, y0, x1, y1) = _rect.get("bbox")
+    #             _cell = {"bbox": (x0, y0, x1, y1), "rect": _rect.get("rect"),
+    #                      "rowspan": self.getspan(list_y, y0, y1, margin),
+    #                      "columnspan": self.getspan(list_x, x0, x1, margin), "text": _rect.get("text", "")}
+    #             table_line.append(_cell)
+    #
+    #             cell_i += 1
+    #         line_i += 1
+    #         _table.append(table_line)
+    #
+    #     # print("=====================>>")
+    #     # for _line in _table:
+    #     #     for _cell in _line:
+    #     #         print(_cell,end="\t")
+    #     #     print("\n")
+    #     # print("=====================>>")
+    #
+    #     # print(_table)
+    #     if fixspan:
+    #         for _line in _table:
+    #             extend_line = []
+    #             for c_i in range(len(_line)):
+    #                 _cell = _line[c_i]
+    #                 if _cell.get("columnspan") > 1:
+    #                     _cospan = _cell.get("columnspan")
+    #                     _cell["columnspan"] = 1
+    #                     for i in range(1, _cospan):
+    #                         extend_line.append({"index": c_i + 1, "cell": _cell})
+    #             extend_line.sort(key=lambda x: x["index"], reverse=True)
+    #             for _el in extend_line:
+    #                 _line.insert(_el["index"], _el["cell"])
+    #         for l_i in range(len(_table)):
+    #             _line = _table[l_i]
+    #             for c_i in range(len(_line)):
+    #                 _cell = _line[c_i]
+    #                 if _cell.get("rowspan") > 1:
+    #                     _rospan = _cell.get("rowspan")
+    #                     _cell["rowspan"] = 1
+    #                     for i in range(1, _rospan):
+    #                         _table[l_i + i].insert(c_i, _cell)
+    #
+    #     table_bbox = (_table[0][0].get("bbox")[0], _table[0][0].get("bbox")[1], _table[-1][-1].get("bbox")[2],
+    #                   _table[-1][-1].get("bbox")[3])
+    #
+    #     ta = {"bbox": table_bbox, "table": _table}
+    #     return ta
 
     # 获取点阵
     def getSpanLocation(self, _list, x0, x1, margin):
@@ -1018,6 +1027,7 @@ class LineTable:
             for _cell in table_line:
                 list_cells.append({"cell": _cell, "inbox_textbox_list": []})
 
+        self.connect_bbox_list = []
         for textbox in list_textbox:
             list_iou = []
             for _d in list_cells:
@@ -1030,6 +1040,13 @@ class LineTable:
                 list_cells[max_iou_index]["inbox_textbox_list"].append(textbox)
                 in_objs.add(textbox)
 
+            if not self.from_pdf and not self.splited:
+                # 多个iou大于0.3的,可能是ocr将两个文本合成一个了
+                iou_index_list = np.where(np.array(list_iou) >= 0.3)[0].tolist()
+                if len(iou_index_list) >= 2:
+                    print('len(iou_index_list) >= 2 textbox', textbox)
+                    self.connect_bbox_list.append(textbox)
+
         has_matched_box_list = []
         for _d in list_cells:
             _cell = _d["cell"]
@@ -1222,6 +1239,10 @@ class LineTable:
 
         self.feedText2table(_table, list_textbox, in_objs, sourceP_LB)
 
+        # feedText2table后,有textbox符合多个单元格iou的,可能是文本错误连接了,需拆开
+        if self.connect_bbox_list:
+            return {}
+
         table_bbox = (_table[0][0].get("bbox")[0],
                       _table[0][0].get("bbox")[1],
                       _table[-1][-1].get("bbox")[2],
@@ -1317,7 +1338,7 @@ def combine_object(obj_list, threshold=5):
     from format_convert.convert_tree import _Sentence
     sentence_list = []
     for obj in obj_list:
-        if isinstance(obj, _Sentence):
+        if isinstance(obj, _Sentence) and not obj.is_html:
             obj.content = re.sub("\s", "", obj.content)
             sentence_list.append(obj)
     sentence_list.sort(key=lambda x: (x.y, x.x))
@@ -1378,16 +1399,16 @@ def request_post(url, param, time_out=1000, use_zlib=False):
                 text = result.text
                 break
             else:
-                print('result.status_code', result.status_code)
-                print('result.text', result.text)
+                # print('result.status_code', result.status_code)
+                # print('result.text', result.text)
                 fails += 1
                 continue
         except socket.timeout:
             fails += 1
-            print('timeout! fail times:', fails)
+            # print('timeout! fail times:', fails)
         except:
             fails += 1
-            print('fail! fail times:', fails)
+            # print('fail! fail times:', fails)
             traceback.print_exc()
     return text
 
@@ -1889,7 +1910,7 @@ def remove_red_seal(image_np):
 def pil_resize(image_np, height, width):
     # limit pixels 89478485
     if image_np.shape[0] * image_np.shape[1] * image_np.shape[2] >= 89478485:
-        print("image too large, limit 89478485 pixels", image_np.shape)
+        # print("image too large, limit 89478485 pixels", image_np.shape)
         ratio = image_np.shape[0] / image_np.shape[1]
         if image_np.shape[0] >= image_np.shape[1]:
             image_np = cv2.resize(image_np, (int(3000 / ratio), 3000), interpolation=cv2.INTER_AREA)

+ 133 - 0
isr/remove_pdf_seal.py

@@ -0,0 +1,133 @@
+import base64
+import json
+import os
+import re
+from glob import glob
+import PyPDF2
+import cv2
+import fitz
+import six
+from PIL import Image
+from PyPDF2 import PdfFileMerger
+
+from format_convert.convert_pdf import PDFConvert
+from format_convert.utils import pil2np
+from isr.isr_interface import isr, IsrModels, remove_seal
+from isr.post_process import get_seal_part, replace_seal_part
+
+file_path = r'D:\BIDI_DOC\比地_文档\方案.pdf'
+output_path = '../format_convert/temp/a1/'
+output_pdf_path = r'D:\BIDI_DOC\比地_文档\方案_去印章.pdf'
+
+
+def get_pdf_image():
+    obj = PDFConvert(file_path, output_path)
+    obj.convert(limit_page_cnt=1000)
+
+
+def use_isr():
+    isr_yolo_model, isr_model = IsrModels().get_model()
+
+    paths = glob(output_path + '*')
+    for p in paths:
+        print('p', p)
+        with open(p, 'rb') as f:
+            img_bytes = f.read()
+
+        img_base64 = base64.b64encode(img_bytes)
+        result = isr(img_base64, isr_yolo_model, isr_model)
+        img_new = result.get('image')
+        if isinstance(img_new, list):
+            img_new = cv2.imread(p)
+
+        name = p.split(os.sep)[-1]
+        new_p = output_path + 'new_' + name
+        print('new_p', new_p)
+        cv2.imwrite(new_p, img_new)
+
+
+def base64_to_pil(string):
+    try:
+        # my own train data
+        string = bytes(string, 'utf-8')
+        base64_data = base64.b64decode(string)
+        # with open('temp.jpg', 'wb') as f:
+        #     f.write(base64_data)
+        # print("base64_to_PIL")
+        buf = six.BytesIO()
+        buf.write(base64_data)
+        buf.seek(0)
+        img = Image.open(buf).convert('RGB')
+        return img
+    except Exception as e:
+        print(e)
+        return None
+
+
+def manual_detect_re_isr():
+    isr_yolo_model, isr_model = IsrModels().get_model()
+
+    paths = glob(output_path + '*.json')
+    for p in paths:
+        print(p)
+        with open(p) as f:
+            json_data = json.loads(f.read())
+        shapes = json_data.get('shapes')
+        image_data = json_data.get('imageData')
+        image_pil = base64_to_pil(image_data)
+        image_np = pil2np(image_pil)
+        image_path = output_path + json_data.get('imagePath')
+        lines = []
+        boxes = []
+        for shape in shapes:
+            lines.append(shape['points'])
+            [x0, y0], [x1, y1] = shape['points']
+            box = [[int(x0), int(y0)], [int(x1), int(y1)]]
+            # label = shape['label']
+            boxes.append(box)
+
+        part_list = get_seal_part(image_np, boxes, [])
+        new_part_list = []
+        for part in part_list:
+            part_remove = remove_seal(part, isr_model)
+            new_part_list.append(part_remove)
+        img_replace = replace_seal_part(image_np, new_part_list, boxes)
+        cv2.imwrite(image_path, img_replace)
+
+
+def image_to_pdf():
+    paths = glob(output_path + 'new_*.png')
+    image_index_list = []
+    for p in paths:
+        print(p)
+        name = p.split(os.sep)[-1]
+        index = int(re.search('\d+', name).group())
+        image_index_list.append([p, index])
+
+        # 去除顶部logo
+        image_np = cv2.imread(p)
+        image_np[:145, :, :] = 255
+        cv2.imwrite(p, image_np)
+    image_index_list.sort(key=lambda x: x[1])
+
+    pdf_doc = fitz.open()
+
+    for p, index in image_index_list:
+        print('index', index)
+        # 用fitz读图片
+        img = fitz.open(p)
+        # 将图片转换为图片pdf
+        img_pdf_page = img.convert_to_pdf()
+        # 用fitz读图片pdf
+        img_pdf_page = fitz.open("pdf", img_pdf_page)
+        # 将图片pdf插入到主pdf中
+        pdf_doc.insert_pdf(img_pdf_page, from_page=0, to_page=len(img_pdf_page) - 1)
+        img.close()
+
+    # 将结果写入新的PDF文件
+    pdf_doc.save(output_pdf_path)
+    pdf_doc.close()
+
+
+if __name__ == '__main__':
+    image_to_pdf()

+ 25 - 76
ocr/ocr_interface.py

@@ -35,13 +35,16 @@ def _ocr():
             return json.dumps({"text": str([-9]), "bbox": str([-9])})
         data = request.form.get("data")
         _md5 = request.form.get("md5")
+        only_rec = request.form.get("only_rec")
+        if only_rec is None:
+            only_rec = False
         _global.update({"md5": _md5})
         ocr_model = globals().get("global_ocr_model")
         if ocr_model is None:
             log("----------- init ocr_model ------------")
             ocr_model = OcrModels().get_model()
             globals().update({"global_ocr_model": ocr_model})
-        text = ocr(data, ocr_model)
+        text = ocr(data, ocr_model, only_rec)
         return json.dumps(text)
     except TimeoutError:
         return json.dumps({"text": str([-5]), "bbox": str([-5])})
@@ -52,50 +55,45 @@ def _ocr():
         log("ocr interface finish time " + str(time.time()-start_time))
 
 
-def ocr(data, ocr_model):
+def ocr(data, ocr_model, only_rec=False):
     log("into ocr_interface ocr")
     try:
         img_data = base64.b64decode(data)
-        text = picture2text(img_data, ocr_model)
+        text = picture2text(img_data, ocr_model, only_rec)
         return text
     except TimeoutError:
         return {"text": str([-5]), "bbox": str([-5])}
 
 
-def picture2text(img_data, ocr_model):
+def picture2text(img_data, ocr_model, only_rec=False):
     log("into ocr_interface picture2text")
     try:
         # 二进制数据流转np.ndarray [np.uint8: 8位像素]
         img = bytes2np(img_data)
 
         # 预测
-        results = ocr_model.ocr(img, det=True, rec=True, cls=False)
+        if only_rec:
+            results = ocr_model.ocr(img, det=False, rec=True, cls=False)
+        else:
+            results = ocr_model.ocr(img, det=True, rec=True, cls=False)
 
         # 循环每张图片识别结果
         text_list = []
         bbox_list = []
-        for line in results:
-            text_list.append(line[-1][0])
-            bbox_list.append(line[0])
-
-        # 查看bbox
-        # img = np.zeros((img.shape[1], img.shape[0]), np.uint8)
-        # img.fill(255)
-        # for box in bbox_list:
-        #     print(box)
-        #     cv2.rectangle(img, (int(box[0][0]), int(box[0][1])),
-        #                   (int(box[2][0]), int(box[2][1])), (0, 0, 255), 1)
-        # cv2.imshow("bbox", img)
-        # cv2.waitKey(0)
-
-        # log("ocr model use time: " + str(time.time()-start_time))
+        if only_rec:
+            text_list = [results[0][0]]
+            bbox_list = []
+        else:
+            for line in results:
+                text_list.append(line[-1][0])
+                bbox_list.append(line[0])
         return {"text": str(text_list), "bbox": str(bbox_list)}
 
     except TimeoutError:
         raise TimeoutError
-    except Exception as e:
+    except Exception:
         log("picture2text error!")
-        print("picture2text", traceback.print_exc())
+        traceback.print_exc()
         return {"text": str([]), "bbox": str([])}
 
 
@@ -133,74 +131,25 @@ class OcrModels:
 
 
 def test_ocr_model(from_remote=True):
-    file_path = "C:/Users/Administrator/Desktop/error2.png"
+    file_path = "C:/Users/Administrator/Desktop/2.png"
     with open(file_path, "rb") as f:
         file_bytes = f.read()
     file_base64 = base64.b64encode(file_bytes)
     _md5 = get_md5_from_bytes(file_bytes)[0]
-
+    only_rec = False
     _global._init()
     _global.update({"port": 15010, "md5": _md5})
 
     if from_remote:
-        file_json = {"data": file_base64, "md5": _md5}
+        file_json = {"data": file_base64, "md5": _md5, 'only_rec': only_rec}
         # _url = "http://192.168.2.102:17000/ocr"
         _url = "http://127.0.0.1:17000/ocr"
         print(json.loads(request_post(_url, file_json)))
     else:
         ocr_model = OcrModels().get_model()
-        result = ocr(file_base64, ocr_model)
+        result = ocr(file_base64, ocr_model, only_rec=only_rec)
         print(result)
 
 
 if __name__ == '__main__':
-    # if len(sys.argv) == 2:
-    #     port = int(sys.argv[1])
-    # elif len(sys.argv) == 3:
-    #     port = int(sys.argv[1])
-    #     using_gpu_index = int(sys.argv[2])
-    # else:
-    #     port = 17000
-    #     using_gpu_index = 0
-    # _global._init()
-    # _global.update({"port": str(port)})
-    # globals().update({"port": str(port)})
-    #
-    # # ip = get_intranet_ip()
-    # # logging.basicConfig(level=logging.INFO,
-    # #                     format='%(asctime)s - %(name)s - %(levelname)s - '
-    # #                            + ip + ' - ' + str(port) + ' - %(message)s')
-    #
-    # os.environ['CUDA_VISIBLE_DEVICES'] = str(using_gpu_index)
-    #
-    # # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
-    # app.run()
-    # log("OCR running "+str(port))
-
-    # test_ocr_model(False)
-    #
-    # log("OCR running")
-    file_path = "C:/Users/Administrator/Desktop/test_image/error3.png"
-
-    with open(file_path, "rb") as f:
-        file_bytes = f.read()
-    file_base64 = base64.b64encode(file_bytes)
-
-    ocr_model = OcrModels().get_model()
-    result = ocr(file_base64, ocr_model)
-
-    text_list = eval(result.get("text"))
-    box_list = eval(result.get("bbox"))
-    from format_convert.utils import ocr_cant_read
-    print(ocr_cant_read(text_list, box_list))
-    print(text_list)
-    #
-    # new_list = []
-    # for i in range(len(text_list)):
-    #     new_list.append([text_list[i], box_list[i]])
-    #
-    # # print(new_list[0][1])
-    # new_list.sort(key=lambda x: (x[1][1][0], x[1][0][0]))
-    #
-    # for t in new_list:
-    #     print(t[0])
+    test_ocr_model(False)

+ 2 - 6
otr/otr_interface.py

@@ -41,7 +41,8 @@ from format_convert.utils import request_post, judge_error_code, get_intranet_ip
 from otr.table_line import get_points, get_split_line, get_points_row, \
     get_points_col, \
     delete_close_points, fix_outline, get_bbox, get_outline_point, delete_contain_bbox, points_to_line, \
-    fix_inner, merge_line, fix_corner, delete_outline, table_net, table_line
+    fix_inner, merge_line, fix_corner, delete_outline, table_net
+from otr.table_line_new import table_line
 from format_convert import _global
 
 
@@ -72,10 +73,6 @@ def _otr():
         _global.update({"md5": _md5})
 
         list_lines = otr(data, otr_model, is_from_pdf)
-        # if is_from_pdf:
-        #     list_lines = line_detect(img_data, otr_model, prob=0.2)
-        # else:
-        #     list_lines = line_detect(img_data, otr_model, prob=0.5)
         return json.dumps(list_lines)
     except TimeoutError:
         return json.dumps({"list_line": str([-5])})
@@ -90,7 +87,6 @@ def otr(data, otr_model, is_from_pdf):
     log("into otr_interface otr")
     try:
         img_data = base64.b64decode(data)
-        # points_and_lines = pool.apply(table_detect, (img_data,))
         if is_from_pdf:
             list_lines = line_detect(img_data, otr_model, prob=0.2)
         else:

+ 38 - 1
otr/table_line.py

@@ -958,7 +958,7 @@ def table_postprocess(img_new, pred, prob=0.2, is_test=0):
         return [-1]
 
 
-def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
+def table_line(img, model, size=(512, 1024), prob=0.2, is_test=1):
     log("into table_line, prob is " + str(prob))
     sizew, sizeh = size
     img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)
@@ -1011,6 +1011,10 @@ def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
             list_rows.append(line)
     log("divide rows and cols " + str(time.time() - start_time))
 
+    # 两种线都需要存在,否则跳过
+    if not list_cols or not list_rows:
+        return []
+
     # 合并错开线
     start_time = time.time()
     list_rows = merge_line(list_rows, axis=0)
@@ -1087,6 +1091,9 @@ def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
     # mat_plot(list_rows + list_cols, "fix_corner", is_test)
     # log("fix_corner " + str(time.time() - start_time))
 
+    # 验证轮廓的4个交点
+    cross_points = fix_4_points(cross_points)
+
     # 修复内部缺线
     start_time = time.time()
     list_rows, list_cols = fix_inner(list_rows, list_cols, cross_points, split_y)
@@ -1763,6 +1770,9 @@ def fix_inner(row_lines, col_lines, points, split_y):
                             min_distance[i] = abs(point[axis] - min_assist_line[i][axis])
                             min_col_point[i] = point
 
+            if len(min_col_point[0]) == 0:
+                return []
+
             # print("min_col_point", min_col_point)
             # print("min_assist_line", min_assist_line)
             # print("line1_point", line1_point)
@@ -3751,6 +3761,33 @@ def delete_contain_bbox(bboxes):
     return bboxes
 
 
+def fix_4_points(cross_points):
+    cross_points.sort(key=lambda x: (x[0], x[1]))
+    left_up_p = cross_points[0]
+    right_down_p = cross_points[-1]
+    cross_points.sort(key=lambda x: (-x[0], x[1]))
+    right_up_p = cross_points[0]
+    left_down_p = cross_points[-1]
+    print('left_up_p', left_up_p, 'left_down_p', left_down_p)
+    print('right_up_p', right_up_p, 'right_down_p', right_down_p)
+
+    min_x = min(left_up_p[0], left_down_p[0], right_down_p[0], right_up_p[0])
+    max_x = max(left_up_p[0], left_down_p[0], right_down_p[0], right_up_p[0])
+    min_y = min(left_up_p[1], left_down_p[1], right_down_p[1], right_up_p[1])
+    max_y = max(left_up_p[1], left_down_p[1], right_down_p[1], right_up_p[1])
+
+    if left_up_p[0] != min_x or left_up_p[1] != min_y:
+        log('轮廓左上角交点有问题')
+    if left_down_p[0] != min_x or left_down_p[1] != max_y:
+        log('轮廓左下角交点有问题')
+        cross_points.append([min_x, max_y])
+    if right_up_p[0] != max_x or right_up_p[1] != min_y:
+        log('轮廓右上角交点有问题')
+    if right_down_p[0] != max_x or right_down_p[1] != max_y:
+        log('轮廓右下角交点有问题')
+    return cross_points
+
+
 if __name__ == '__main__':
     print()
     path = "C:\\Users\\Administrator\\Downloads\\pred.pk"

+ 1057 - 0
otr/table_line_new.py

@@ -0,0 +1,1057 @@
+import copy
+import time
+import traceback
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+from format_convert.utils import log, pil_resize
+
+
+def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
+    log("into table_line, prob is " + str(prob))
+
+    # resize
+    w, h = size
+    img_new = pil_resize(img, h, w)
+    img_show = copy.deepcopy(img_new)
+
+    # predict
+    start_time = time.time()
+    pred = model.predict(np.array([img_new]))
+    pred = pred[0]
+    log("otr model predict time " + str(time.time() - start_time))
+
+    # show
+    show(pred, title='pred', prob=prob, mode=1, is_test=is_test)
+
+    # 根据点获取线
+    start_time = time.time()
+    line_list = points2lines(pred, False, prob=prob)
+    log("points2lines " + str(time.time() - start_time))
+    if not line_list:
+        return []
+    show(line_list, title="points2lines", mode=2, is_test=is_test)
+
+    # 清除短线
+    start_time = time.time()
+    line_list = delete_short_lines(line_list, img_new.shape)
+    show(line_list, title="delete_short_lines", mode=2, is_test=is_test)
+    log("delete_short_lines " + str(time.time() - start_time))
+
+    # 分成横竖线
+    start_time = time.time()
+    row_line_list = []
+    col_line_list = []
+    for line in line_list:
+        if line[0] == line[2]:
+            col_line_list.append(line)
+        elif line[1] == line[3]:
+            row_line_list.append(line)
+    log("divide rows and cols " + str(time.time() - start_time))
+
+    # 两种线都需要存在,否则跳过
+    if not row_line_list or not col_line_list:
+        return []
+
+    # 合并错开线
+    start_time = time.time()
+    row_line_list = merge_line(row_line_list, axis=0)
+    col_line_list = merge_line(col_line_list, axis=1)
+    show(row_line_list + col_line_list, title="merge_line", mode=2, is_test=is_test)
+    log("merge_line " + str(time.time() - start_time))
+
+    # 计算交点
+    cross_points = get_points(row_line_list, col_line_list, (img_new.shape[0], img_new.shape[1]))
+    if not cross_points:
+        return []
+    # 删除无交点线 需重复两次才删的干净
+    row_line_list, col_line_list = delete_single_lines(row_line_list, col_line_list, cross_points)
+    cross_points = get_points(row_line_list, col_line_list, (img_new.shape[0], img_new.shape[1]))
+    row_line_list, col_line_list = delete_single_lines(row_line_list, col_line_list, cross_points)
+    if not row_line_list or not col_line_list:
+        return []
+
+    # 多个表格分割线,获取多个表格区域
+    start_time = time.time()
+    split_lines, split_y = get_split_line(cross_points, col_line_list, img_new)
+    area_row_line_list, area_col_line_list, area_point_list = get_split_area(split_y, row_line_list, col_line_list, cross_points)
+    log("get_split_area " + str(time.time() - start_time))
+
+    # 根据区域循环
+    need_split_flag = False
+    for i in range(len(area_point_list)):
+        sub_row_line_list = area_row_line_list[i]
+        sub_col_line_list = area_col_line_list[i]
+        sub_point_list = area_point_list[i]
+
+        # 修复边框
+        start_time = time.time()
+        new_rows, new_cols, long_rows, long_cols = fix_outline(img_new,
+                                                               sub_row_line_list,
+                                                               sub_col_line_list,
+                                                               sub_point_list)
+
+        # 如有补线
+        if new_rows or new_cols:
+            # 连接至补线的延长线
+            if long_rows:
+                sub_row_line_list = long_rows
+            if long_cols:
+                sub_col_line_list = long_cols
+            # 新的补线
+            if new_rows:
+                sub_row_line_list += new_rows
+            if new_cols:
+                sub_col_line_list += new_cols
+            need_split_flag = True
+            area_row_line_list[i] = sub_row_line_list
+            area_col_line_list[i] = sub_col_line_list
+
+    row_line_list = [y for x in area_row_line_list for y in x]
+    col_line_list = [y for x in area_col_line_list for y in x]
+
+    if need_split_flag:
+        # 修复边框后重新计算交点
+        cross_points = get_points(row_line_list, col_line_list, (img_new.shape[0], img_new.shape[1]))
+        split_lines, split_y = get_split_line(cross_points, col_line_list, img_new)
+        area_row_line_list, area_col_line_list, area_point_list = get_split_area(split_y, row_line_list, col_line_list, cross_points)
+
+    show(cross_points, title="get_points", img=img_show, mode=4, is_test=is_test)
+    show(split_lines, title="split_lines", img=img_show, mode=3, is_test=is_test)
+    show(row_line_list + col_line_list, title="fix_outline", mode=2, is_test=is_test)
+    log("fix_outline " + str(time.time() - start_time))
+
+    # 根据区域循环
+    for i in range(len(area_point_list)):
+        sub_row_line_list = area_row_line_list[i]
+        sub_col_line_list = area_col_line_list[i]
+        sub_point_list = area_point_list[i]
+
+        # 验证轮廓的4个交点
+        sub_row_line_list, sub_col_line_list = fix_4_points(sub_point_list, sub_row_line_list, sub_col_line_list)
+
+        # 把四个边线在加一次
+        sub_point_list = get_points(sub_row_line_list, sub_col_line_list, (img_new.shape[0], img_new.shape[1]))
+        sub_row_line_list, sub_col_line_list = add_outline(sub_point_list, sub_row_line_list, sub_col_line_list)
+
+        # 修复内部缺线
+        start_time = time.time()
+        sub_row_line_list, sub_col_line_list = fix_inner(sub_row_line_list, sub_col_line_list, sub_point_list)
+        log("fix_inner " + str(time.time() - start_time))
+        show(sub_row_line_list + sub_col_line_list, title="fix_inner1", mode=2, is_test=is_test)
+
+        # 合并错开
+        start_time = time.time()
+        sub_row_line_list = merge_line(sub_row_line_list, axis=0)
+        sub_col_line_list = merge_line(sub_col_line_list, axis=1)
+        log("merge_line " + str(time.time() - start_time))
+        show(sub_row_line_list + sub_col_line_list, title="merge_line", mode=2, is_test=is_test)
+
+        # 修复内部线后重新计算交点
+        start_time = time.time()
+        cross_points = get_points(sub_row_line_list, sub_col_line_list, (img_new.shape[0], img_new.shape[1]))
+        show(cross_points, title="get_points3", img=img_show, mode=4, is_test=is_test)
+
+        # 消除线突出,获取标准的线
+        area_row_line_list[i], area_col_line_list[i] = get_standard_lines(sub_row_line_list, sub_col_line_list)
+        show(area_row_line_list[i] + area_col_line_list[i], title="get_standard_lines", mode=2, is_test=is_test)
+
+    row_line_list = [y for x in area_row_line_list for y in x]
+    col_line_list = [y for x in area_col_line_list for y in x]
+
+    line_list = row_line_list + col_line_list
+    # 打印处理后线
+    show(line_list, title="all", img=img_show, mode=5, is_test=is_test)
+    log("otr postprocess table_line " + str(time.time() - start_time))
+    return line_list
+
+
+def show(pred_or_lines, title='', prob=0.2, img=None, mode=1, is_test=0):
+    if not is_test:
+        return
+
+    if mode == 1:
+        plt.figure()
+        plt.title(title)
+
+        _array = []
+        for _h in range(len(pred_or_lines)):
+            _line = []
+            for _w in range(len(pred_or_lines[_h])):
+                _prob = pred_or_lines[_h][_w]
+                if _prob[0] > prob:
+                    _line.append((0, 0, 255))
+                elif _prob[1] > prob:
+                    _line.append((255, 0, 0))
+                else:
+                    _line.append((255, 255, 255))
+            _array.append(_line)
+        # plt.axis('off')
+        plt.imshow(np.array(_array))
+        plt.show()
+
+    elif mode == 2:
+        plt.figure()
+        plt.title(title)
+        for _line in pred_or_lines:
+            x0, y0, x1, y1 = _line
+            plt.plot([x0, x1], [y0, y1])
+        plt.show()
+
+    elif mode == 3:
+        for _line in pred_or_lines:
+            x0, y0 = _line[0]
+            x1, y1 = _line[1]
+            cv2.line(img, [int(x0), int(y0)], [int(x1), int(y1)], (0, 0, 255), 2)
+        cv2.namedWindow(title, cv2.WINDOW_NORMAL)
+        cv2.imshow(title, img)
+        cv2.waitKey(0)
+
+    elif mode == 4:
+        for point in pred_or_lines:
+            point = [int(x) for x in point]
+            cv2.circle(img, (point[0], point[1]), 1, (0, 255, 0), 2)
+        cv2.namedWindow(title, cv2.WINDOW_NORMAL)
+        cv2.imshow(title, img)
+        cv2.waitKey(0)
+
+    elif mode == 5:
+        for _line in pred_or_lines:
+            x0, y0, x1, y1 = _line
+            cv2.line(img, [int(x0), int(y0)], [int(x1), int(y1)], (0, 255, 0), 2)
+        cv2.namedWindow(title, cv2.WINDOW_NORMAL)
+        cv2.imshow(title, img)
+        cv2.waitKey(0)
+
+
+def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=8, padding=3, min_len=10,
+                 cell_width=13):
+    _time = time.time()
+
+    log("starting points2lines")
+    height = len(pred)
+    width = len(pred[0])
+
+    _sum = list(np.sum(np.array((pred[..., 0] > prob)).astype(int), axis=1))
+
+    h_index = -1
+    h_lines = []
+    v_lines = []
+    _step = line_width
+    while 1:
+        h_index += 1
+        if h_index >= height:
+            break
+        w_index = -1
+        if sourceP_LB:
+            h_i = height - 1 - h_index
+        else:
+            h_i = h_index
+        _start = None
+        if _sum[h_index] < min_len:
+            continue
+        last_back = 0
+        while 1:
+            if w_index >= width:
+                if _start is not None:
+                    _end = w_index - 1
+                    _bbox = [_start, h_i, _end, h_i]
+                    _dict = {"bbox": _bbox}
+                    h_lines.append(_dict)
+                    _start = None
+                break
+            _h, _v = pred[h_i][w_index]
+            if _h > prob:
+                if _start is None:
+                    _start = w_index
+                w_index += _step
+            else:
+                if _start is not None:
+                    _end = w_index - 1
+                    _bbox = [_start, h_i, _end, h_i]
+                    _dict = {"bbox": _bbox}
+                    h_lines.append(_dict)
+                    _start = None
+
+                w_index -= _step // 2
+                if w_index <= last_back:
+                    w_index = last_back + _step // 2
+                last_back = w_index
+
+    log("starting points2lines 1")
+    w_index = -1
+
+    _sum = list(np.sum(np.array((pred[..., 1] > prob)).astype(int), axis=0))
+    _step = line_width
+    while 1:
+        w_index += 1
+        if w_index >= width:
+            break
+        if _sum[w_index] < min_len:
+            continue
+        h_index = -1
+        _start = None
+        last_back = 0
+        list_test = []
+        list_lineprob = []
+        while 1:
+            if h_index >= height:
+                if _start is not None:
+                    _end = last_h
+                    _bbox = [w_index, _start, w_index, _end]
+                    _dict = {"bbox": _bbox}
+                    v_lines.append(_dict)
+                    _start = None
+                    list_test.append(_dict)
+                break
+            if sourceP_LB:
+                h_i = height - 1 - h_index
+            else:
+                h_i = h_index
+
+            _h, _v = pred[h_index][w_index]
+            list_lineprob.append((h_index, _v))
+            if _v > prob:
+                if _start is None:
+                    _start = h_i
+                h_index += _step
+            else:
+                if _start is not None:
+                    _end = last_h
+                    _bbox = [w_index, _start, w_index, _end]
+                    _dict = {"bbox": _bbox}
+                    v_lines.append(_dict)
+                    _start = None
+                    list_test.append(_dict)
+
+                h_index -= _step // 2
+                if h_index <= last_back:
+                    h_index = last_back + _step // 2
+                last_back = h_index
+
+            last_h = h_i
+    log("starting points2lines 2")
+
+    for _line in h_lines:
+        _bbox = _line["bbox"]
+        _bbox = [max(_bbox[0] - 2, 0), (_bbox[1] + _bbox[3]) / 2, _bbox[2] + 2, (_bbox[1] + _bbox[3]) / 2]
+        _line["bbox"] = _bbox
+
+    for _line in v_lines:
+        _bbox = _line["bbox"]
+        _bbox = [(_bbox[0] + _bbox[2]) / 2, max(_bbox[1] - 2, 0), (_bbox[0] + _bbox[2]) / 2, _bbox[3] + 2]
+        _line["bbox"] = _bbox
+
+    h_lines = lines_cluster(h_lines, line_width=line_width)
+    v_lines = lines_cluster(v_lines, line_width=line_width)
+
+    list_line = []
+    for _line in h_lines:
+        _bbox = _line["bbox"]
+        _bbox = [max(_bbox[0] - 1, 0), (_bbox[1] + _bbox[3]) / 2, _bbox[2] + 1, (_bbox[1] + _bbox[3]) / 2]
+        list_line.append(_bbox)
+    for _line in v_lines:
+        _bbox = _line["bbox"]
+        _bbox = [(_bbox[0] + _bbox[2]) / 2, max(_bbox[1] - 1, 0), (_bbox[0] + _bbox[2]) / 2, _bbox[3] + 1]
+        list_line.append(_bbox)
+
+    log("points2lines cost %.2fs" % (time.time() - _time))
+
+    # import matplotlib.pyplot as plt
+    # plt.figure()
+    # for _line in list_line:
+    #     x0,y0,x1,y1 = _line
+    #     plt.plot([x0,x1],[y0,y1])
+    # for _line in list_line:
+    #     x0,y0,x1,y1 = _line.bbox
+    #     plt.plot([x0,x1],[y0,y1])
+    # for point in list_crosspoints:
+    #     plt.scatter(point.get("point")[0],point.get("point")[1])
+    # plt.show()
+
+    return list_line
+
+
+def lines_cluster(list_lines, line_width):
+    after_len = 0
+    prelength = len(list_lines)
+    append_width = line_width // 2
+    while 1:
+        c_lines = []
+        first_len = after_len
+
+        for _line in list_lines:
+            bbox = _line["bbox"]
+            _find = False
+            for c_l_i in range(len(c_lines)):
+                c_l = c_lines[len(c_lines) - c_l_i - 1]
+                bbox1 = c_l["bbox"]
+                bboxa = [max(0, bbox[0] - append_width), max(0, bbox[1] - append_width), bbox[2] + append_width,
+                         bbox[3] + append_width]
+                bboxb = [max(0, bbox1[0] - append_width), max(0, bbox1[1] - append_width), bbox1[2] + append_width,
+                         bbox1[3] + append_width]
+
+                _iou = getIOU(bboxa, bboxb)
+                if _iou > 0:
+                    new_bbox = [min(bbox[0], bbox[2], bbox1[0], bbox1[2]), min(bbox[1], bbox[3], bbox1[1], bbox1[3]),
+                                max(bbox[0], bbox[2], bbox1[0], bbox1[2]), max(bbox[1], bbox[3], bbox1[1], bbox1[3])]
+                    _find = True
+                    c_l["bbox"] = new_bbox
+                    break
+            if not _find:
+                c_lines.append(_line)
+        after_len = len(c_lines)
+        if first_len == after_len:
+            break
+        list_lines = c_lines
+    log("cluster lines from %d to %d" % (prelength, len(list_lines)))
+    return c_lines
+
+
+def getIOU(bbox0, bbox1):
+    width = abs(max(bbox0[2], bbox1[2]) - min(bbox0[0], bbox1[0])) - (
+            abs(bbox0[2] - bbox0[0]) + abs(bbox1[2] - bbox1[0]))
+    height = abs(max(bbox0[3], bbox1[3]) - min(bbox0[1], bbox1[1])) - (
+            abs(bbox0[3] - bbox0[1]) + abs(bbox1[3] - bbox1[1]))
+    if width <= 0 and height <= 0:
+        iou = abs(width * height / min(abs((bbox0[2] - bbox0[0]) * (bbox0[3] - bbox0[1])),
+                                       abs((bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]))))
+        # print("getIOU", iou)
+        return iou + 0.1
+    return 0
+
+
+def delete_short_lines(list_lines, image_shape, scale=100):
+    # 排除太短的线
+    x_min_len = max(5, int(image_shape[0] / scale))
+    y_min_len = max(5, int(image_shape[1] / scale))
+    new_list_lines = []
+    for line in list_lines:
+        if line[0] == line[2]:
+            if abs(line[3] - line[1]) >= y_min_len:
+                # print("y_min_len", abs(line[3] - line[1]), y_min_len)
+                new_list_lines.append(line)
+        else:
+            if abs(line[2] - line[0]) >= x_min_len:
+                # print("x_min_len", abs(line[2] - line[0]), x_min_len)
+                new_list_lines.append(line)
+    return new_list_lines
+
+
+def delete_single_lines(row_line_list, col_line_list, point_list):
+    new_col_line_list = []
+    min_point_cnt = 2
+    for line in col_line_list:
+        p_cnt = 0
+        for p in point_list:
+            # if line[0] == p[0] and line[1] <= p[1] <= line[3]:
+            if line[0] == p[0]:
+                p_cnt += 1
+            if p_cnt >= min_point_cnt:
+                new_col_line_list.append(line)
+                break
+    new_row_line_list = []
+    for line in row_line_list:
+        p_cnt = 0
+        for p in point_list:
+            # if line[1] == p[1] and line[0] <= p[0] <= line[2]:
+            if line[1] == p[1]:
+                p_cnt += 1
+            if p_cnt >= min_point_cnt:
+                new_row_line_list.append(line)
+                break
+    return new_row_line_list, new_col_line_list
+
+
+def merge_line(lines, axis, threshold=5):
+    """
+    解决模型预测一条直线错开成多条直线,合并成一条直线
+
+    :param lines: 线条列表
+    :param axis: 0:横线 1:竖线
+    :param threshold: 两条线间像素差阈值
+    :return: 合并后的线条列表
+    """
+    # 任意一条line获取该合并的line,横线往下找,竖线往右找
+    lines.sort(key=lambda x: (x[axis], x[1 - axis]))
+    merged_lines = []
+    used_lines = []
+    for line1 in lines:
+        if line1 in used_lines:
+            continue
+
+        merged_line = [line1]
+        used_lines.append(line1)
+        for line2 in lines:
+            if line2 in used_lines:
+                continue
+
+            if line1[1 - axis] - threshold <= line2[1 - axis] <= line1[1 - axis] + threshold:
+                # 计算基准长度
+                min_axis = 10000
+                max_axis = 0
+                for line3 in merged_line:
+                    if line3[axis] < min_axis:
+                        min_axis = line3[axis]
+                    if line3[axis + 2] > max_axis:
+                        max_axis = line3[axis + 2]
+                # 判断两条线有无交集
+                if min_axis <= line2[axis] <= max_axis \
+                        or min_axis <= line2[axis + 2] <= max_axis:
+                    merged_line.append(line2)
+                    used_lines.append(line2)
+        if merged_line:
+            merged_lines.append(merged_line)
+
+    # 合并line
+    result_lines = []
+    for merged_line in merged_lines:
+        # 获取line宽的平均值
+        axis_average = 0
+        for line in merged_line:
+            axis_average += line[1 - axis]
+        axis_average = int(axis_average / len(merged_line))
+
+        # 获取最长line两端
+        merged_line.sort(key=lambda x: (x[axis]))
+        axis_start = merged_line[0][axis]
+        merged_line.sort(key=lambda x: (x[axis + 2]))
+        axis_end = merged_line[-1][axis + 2]
+
+        if axis:
+            result_lines.append([axis_average, axis_start, axis_average, axis_end])
+        else:
+            result_lines.append([axis_start, axis_average, axis_end, axis_average])
+    return result_lines
+
+
+def get_points(row_lines, col_lines, image_size):
+    # 创建空图
+    row_img = np.zeros(image_size, np.uint8)
+    col_img = np.zeros(image_size, np.uint8)
+
+    # 画线
+    threshold = 5
+    for row in row_lines:
+        cv2.line(row_img, (int(row[0] - threshold), int(row[1])), (int(row[2] + threshold), int(row[3])), (255, 255, 255), 1)
+    for col in col_lines:
+        cv2.line(col_img, (int(col[0]), int(col[1] - threshold)), (int(col[2]), int(col[3] + threshold)), (255, 255, 255), 1)
+
+    # 求出交点
+    point_img = np.bitwise_and(row_img, col_img)
+    # cv2.imwrite("get_points.jpg", row_img+col_img)
+    # cv2.imshow("get_points", row_img+col_img)
+    # cv2.waitKey(0)
+
+    # 识别黑白图中的白色交叉点,将横纵坐标取出
+    ys, xs = np.where(point_img > 0)
+    points = []
+    for i in range(len(xs)):
+        points.append((xs[i], ys[i]))
+    points.sort(key=lambda x: (x[0], x[1]))
+    return points
+
+
+def fix_outline(image, row_line_list, col_line_list, point_list, scale=25):
+    log("into fix_outline")
+    x_min_len = max(10, int(image.shape[0] / scale))
+    y_min_len = max(10, int(image.shape[1] / scale))
+
+    if len(row_line_list) <= 1 or len(col_line_list) <= 1:
+        return [], [], row_line_list, col_line_list
+
+    # 预测线取上下左右4个边(会有超出表格部分) [(), ()]
+    row_line_list.sort(key=lambda x: (x[1], x[0]))
+    up_line = row_line_list[0]
+    bottom_line = row_line_list[-1]
+    col_line_list.sort(key=lambda x: x[0])
+    left_line = col_line_list[0]
+    right_line = col_line_list[-1]
+
+    # 计算单格高度宽度
+    if len(row_line_list) > 1:
+        height_dict = {}
+        for j in range(len(row_line_list)):
+            if j + 1 > len(row_line_list) - 1:
+                break
+            height = abs(int(row_line_list[j][3] - row_line_list[j + 1][3]))
+            if height >= 10:
+                if height in height_dict.keys():
+                    height_dict[height] = height_dict[height] + 1
+                else:
+                    height_dict[height] = 1
+        height_list = [[x, height_dict[x]] for x in height_dict.keys()]
+        if height_list:
+            height_list.sort(key=lambda x: (x[1], -x[0]), reverse=True)
+            # print("box_height", height_list)
+            box_height = height_list[0][0]
+        else:
+            box_height = y_min_len
+    else:
+        box_height = y_min_len
+    if len(col_line_list) > 1:
+        box_width = abs(col_line_list[1][2] - col_line_list[0][2])
+    else:
+        box_width = x_min_len
+
+    # 设置轮廓线需超出阈值
+    if box_height >= 2 * y_min_len:
+        fix_h_len = y_min_len
+    else:
+        fix_h_len = box_height * 2 / 3
+    if box_width >= 2 * x_min_len:
+        fix_w_len = x_min_len
+    else:
+        fix_w_len = box_width * 2 / 3
+
+    # 判断超出部分的长度,超出一定长度就补线
+    new_row_lines = []
+    new_col_lines = []
+    all_longer_row_lines = []
+    all_longer_col_lines = []
+
+    # print('box_height, box_width, fix_h_len, fix_w_len', box_height, box_width, fix_h_len, fix_w_len)
+    # print('bottom_line, left_line, right_line', bottom_line, left_line, right_line)
+
+    # 补左右两条竖线超出来的线的row
+    if up_line[1] - left_line[1] >= fix_h_len and up_line[1] - right_line[1] >= fix_h_len:
+        if up_line[1] - left_line[1] >= up_line[1] - right_line[1]:
+            new_row_lines.append([left_line[0], left_line[1], right_line[0], left_line[1]])
+            new_col_y = left_line[1]
+            # 补了row,要将其他短的col连到row上
+            for j in range(len(col_line_list)):
+                col = col_line_list[j]
+                if abs(new_col_y - col[1]) <= box_height:
+                    col_line_list[j][1] = min([new_col_y, col[1]])
+        else:
+            new_row_lines.append([left_line[0], right_line[1], right_line[0], right_line[1]])
+            new_col_y = right_line[1]
+            # 补了row,要将其他短的col连到row上
+            for j in range(len(col_line_list)):
+                col = col_line_list[j]
+                # 且距离不能相差太大
+                if abs(new_col_y - col[1]) <= box_height:
+                    col_line_list[j][1] = min([new_col_y, col[1]])
+    if left_line[3] - bottom_line[3] >= fix_h_len and right_line[3] - bottom_line[3] >= fix_h_len:
+        if left_line[3] - bottom_line[3] >= right_line[3] - bottom_line[3]:
+            new_row_lines.append([left_line[2], left_line[3], right_line[2], left_line[3]])
+            new_col_y = left_line[3]
+            # 补了row,要将其他短的col连到row上
+            for j in range(len(col_line_list)):
+                col = col_line_list[j]
+                # 且距离不能相差太大
+                if abs(new_col_y - col[3]) <= box_height:
+                    col_line_list[j][3] = max([new_col_y, col[3]])
+        else:
+            new_row_lines.append([left_line[2], right_line[3], right_line[2], right_line[3]])
+            new_col_y = right_line[3]
+            # 补了row,要将其他短的col连到row上
+            for j in range(len(col_line_list)):
+                col = col_line_list[j]
+                # 且距离不能相差太大
+                if abs(new_col_y - col[3]) <= box_height:
+                    col_line_list[j][3] = max([new_col_y, col[3]])
+
+    # 补上下两条横线超出来的线的col
+    if left_line[0] - up_line[0] >= fix_w_len and left_line[0] - bottom_line[0] >= fix_w_len:
+        if left_line[0] - up_line[0] >= left_line[0] - bottom_line[0]:
+            new_col_lines.append([up_line[0], up_line[1], up_line[0], bottom_line[1]])
+            new_row_x = up_line[0]
+            # 补了col,要将其他短的row连到col上
+            for j in range(len(row_line_list)):
+                row = row_line_list[j]
+                # 且距离不能相差太大
+                if abs(new_row_x - row[0]) <= box_width:
+                    row_line_list[j][0] = min([new_row_x, row[0]])
+        else:
+            new_col_lines.append([bottom_line[0], up_line[1], bottom_line[0], bottom_line[1]])
+            new_row_x = bottom_line[0]
+            # 补了col,要将其他短的row连到col上
+            for j in range(len(row_line_list)):
+                row = row_line_list[j]
+                # 且距离不能相差太大
+                if abs(new_row_x - row[0]) <= box_width:
+                    row_line_list[j][0] = min([new_row_x, row[0]])
+    if up_line[2] - right_line[2] >= fix_w_len and bottom_line[2] - right_line[2] >= fix_w_len:
+        if up_line[2] - right_line[2] >= bottom_line[2] - right_line[2]:
+            new_col_lines.append([up_line[2], up_line[3], up_line[2], bottom_line[3]])
+            new_row_x = up_line[2]
+            # 补了col,要将其他短的row连到col上
+            for j in range(len(row_line_list)):
+                row = row_line_list[j]
+                # 且距离不能相差太大
+                if abs(new_row_x - row[2]) <= box_width:
+                    row_line_list[j][2] = max([new_row_x, row[2]])
+        else:
+            new_col_lines.append([bottom_line[2], up_line[3], bottom_line[2], bottom_line[3]])
+            new_row_x = bottom_line[2]
+            # 补了col,要将其他短的row连到col上
+            for j in range(len(row_line_list)):
+                row = row_line_list[j]
+                # 且距离不能相差太大
+                if abs(new_row_x - row[2]) <= box_width:
+                    row_line_list[j][2] = max([new_row_x, row[2]])
+
+    all_longer_row_lines += row_line_list
+    all_longer_col_lines += col_line_list
+
+    # print('new_row_lines, new_col_lines', new_row_lines, new_col_lines)
+    # print('all_longer_row_lines, all_longer_col_lines', all_longer_row_lines, all_longer_col_lines)
+    return new_row_lines, new_col_lines, all_longer_row_lines, all_longer_col_lines
+
+
+def fix_inner(row_line_list, col_line_list, point_list):
+    def fix(fix_lines, assist_lines, split_points, axis):
+        new_line_point_list = []
+        delete_line_point_list = []
+        for line1 in fix_lines:
+            min_assist_line = [[], []]
+            min_distance = [1000, 1000]
+            if_find = [0, 0]
+
+            # 获取fix_line中的所有col point,里面可能不包括两个顶点,col point是交点,顶点可能不是交点
+            fix_line_points = []
+            for point in split_points:
+                if abs(point[1 - axis] - line1[1 - axis]) <= 2:
+                    if line1[axis] <= point[axis] <= line1[axis + 2]:
+                        fix_line_points.append(point)
+
+            # 找出离两个顶点最近的assist_line, 并且assist_line与fix_line不相交
+            line1_point = [line1[:2], line1[2:]]
+            for i in range(2):
+                point = line1_point[i]
+                for line2 in assist_lines:
+                    if not if_find[i] and abs(point[axis] - line2[axis]) <= 2:
+                        if line1[1 - axis] <= point[1 - axis] <= line2[1 - axis + 2]:
+                            # print("line1, match line2", line1, line2)
+                            if_find[i] = 1
+                            break
+                    else:
+                        if abs(point[axis] - line2[axis]) < min_distance[i] and line2[1 - axis] <= point[1 - axis] <= \
+                                line2[1 - axis + 2]:
+                            if line1[axis] <= line2[axis] <= line1[axis + 2]:
+                                continue
+                            min_distance[i] = abs(line1[axis] - line2[axis])
+                            min_assist_line[i] = line2
+
+            if len(min_assist_line[0]) == 0 and len(min_assist_line[1]) == 0:
+                continue
+
+            # 找出离assist_line最近的交点
+            min_distance = [1000, 1000]
+            min_col_point = [[], []]
+            for i in range(2):
+                # print("顶点", i, line1_point[i])
+                if min_assist_line[i]:
+                    for point in fix_line_points:
+                        if abs(point[axis] - min_assist_line[i][axis]) < min_distance[i]:
+                            min_distance[i] = abs(point[axis] - min_assist_line[i][axis])
+                            min_col_point[i] = point
+
+            # print("min_col_point", min_col_point)
+            # print("min_assist_line", min_assist_line)
+
+            if len(min_col_point[0]) == 0 and len(min_col_point[1]) == 0:
+                continue
+
+            # 顶点到交点的距离(多出来的线)需大于assist_line到交点的距离(bbox的边)的1/3
+            # print("line1_point", line1_point)
+            if min_assist_line[0] and min_assist_line[0] == min_assist_line[1]:
+                if min_assist_line[0][axis] < line1_point[0][axis]:
+                    bbox_len = abs(min_col_point[0][axis] - min_assist_line[0][axis])
+                    line_distance = abs(min_col_point[0][axis] - line1_point[0][axis])
+                    if bbox_len / 3 <= line_distance <= bbox_len:
+                        if axis == 1:
+                            add_point = (line1_point[0][1 - axis], min_assist_line[0][axis])
+                        else:
+                            add_point = (min_assist_line[0][axis], line1_point[0][1 - axis])
+                        new_line_point_list.append([line1, add_point])
+                elif min_assist_line[1][axis] > line1_point[1][axis]:
+                    bbox_len = abs(min_col_point[1][axis] - min_assist_line[1][axis])
+                    line_distance = abs(min_col_point[1][axis] - line1_point[1][axis])
+                    if bbox_len / 3 <= line_distance <= bbox_len:
+                        if axis == 1:
+                            add_point = (line1_point[1][1 - axis], min_assist_line[1][axis])
+                        else:
+                            add_point = (min_assist_line[1][axis], line1_point[1][1 - axis])
+                        new_line_point_list.append([line1, add_point])
+            else:
+                for i in range(2):
+                    if min_col_point[i]:
+                        bbox_len = abs(min_col_point[i][axis] - min_assist_line[i][axis])
+                        line_distance = abs(min_col_point[i][axis] - line1_point[i][axis])
+                        # print("bbox_len, line_distance", bbox_len, line_distance)
+                        if bbox_len / 3 <= line_distance <= bbox_len:
+                            if axis == 1:
+                                add_point = (line1_point[i][1 - axis], min_assist_line[i][axis])
+                            else:
+                                add_point = (min_assist_line[i][axis], line1_point[i][1 - axis])
+                            new_line_point_list.append([line1, add_point])
+
+        return new_line_point_list
+
+    row_line_list_copy = copy.deepcopy(row_line_list)
+    col_line_list_copy = copy.deepcopy(col_line_list)
+    try:
+        new_point_list = fix(col_line_list, row_line_list, point_list, axis=1)
+        for line, new_point in new_point_list:
+            if line in col_line_list:
+                index = col_line_list.index(line)
+                point1 = line[:2]
+                point2 = line[2:]
+                if new_point[1] >= point2[1]:
+                    col_line_list[index] = [point1[0], point1[1], new_point[0], new_point[1]]
+                elif new_point[1] <= point1[1]:
+                    col_line_list[index] = [new_point[0], new_point[1], point2[0], point2[1]]
+
+        new_point_list = fix(row_line_list, col_line_list, point_list, axis=0)
+        for line, new_point in new_point_list:
+            if line in row_line_list:
+                index = row_line_list.index(line)
+                point1 = line[:2]
+                point2 = line[2:]
+                if new_point[0] >= point2[0]:
+                    row_line_list[index] = [point1[0], point1[1], new_point[0], new_point[1]]
+                elif new_point[0] <= point1[0]:
+                    row_line_list[index] = [new_point[0], new_point[1], point2[0], point2[1]]
+        return row_line_list, col_line_list
+    except:
+        traceback.print_exc()
+        return row_line_list_copy, col_line_list_copy
+
+
+def fix_4_points(cross_points, row_line_list, col_line_list):
+    if not (len(row_line_list) >= 2 and len(col_line_list) >= 2):
+        return row_line_list, col_line_list
+
+    cross_points.sort(key=lambda x: (x[0], x[1]))
+    left_up_p = cross_points[0]
+    right_down_p = cross_points[-1]
+    cross_points.sort(key=lambda x: (-x[0], x[1]))
+    right_up_p = cross_points[0]
+    left_down_p = cross_points[-1]
+    # print('left_up_p', left_up_p, 'left_down_p', left_down_p)
+    # print('right_up_p', right_up_p, 'right_down_p', right_down_p)
+
+    min_x = min(left_up_p[0], left_down_p[0], right_down_p[0], right_up_p[0])
+    max_x = max(left_up_p[0], left_down_p[0], right_down_p[0], right_up_p[0])
+    min_y = min(left_up_p[1], left_down_p[1], right_down_p[1], right_up_p[1])
+    max_y = max(left_up_p[1], left_down_p[1], right_down_p[1], right_up_p[1])
+
+    if left_up_p[0] != min_x or left_up_p[1] != min_y:
+        log('轮廓左上角交点有问题')
+        row_line_list.append([min_x, min_y, max_x, min_y])
+        col_line_list.append([min_x, min_y, min_x, max_y])
+    if left_down_p[0] != min_x or left_down_p[1] != max_y:
+        log('轮廓左下角交点有问题')
+        row_line_list.append([min_x, max_y, max_x, max_y])
+        col_line_list.append([min_x, min_y, min_x, max_y])
+    if right_up_p[0] != max_x or right_up_p[1] != min_y:
+        log('轮廓右上角交点有问题')
+        row_line_list.append([min_x, max_y, max_x, max_y])
+        col_line_list.append([max_x, min_y, max_x, max_y])
+    if right_down_p[0] != max_x or right_down_p[1] != max_y:
+        log('轮廓右下角交点有问题')
+        row_line_list.append([min_x, max_y, max_x, max_y])
+        col_line_list.append([max_x, min_y, max_x, max_y])
+
+    return row_line_list, col_line_list
+
+
+def get_split_line(points, col_lines, image_np, threshold=5):
+    # 线贴着边缘无法得到split_y,导致无法分区
+    for _col in col_lines:
+        if _col[3] >= image_np.shape[0] - 5:
+            _col[3] = image_np.shape[0] - 6
+
+        if _col[1] <= 0 + 5:
+            _col[1] = 6
+
+    # print("get_split_line", image_np.shape)
+    points.sort(key=lambda x: (x[1], x[0]))
+    # 遍历y坐标,并判断y坐标与上一个y坐标是否存在连接线
+    i = 0
+    split_line_y = []
+    for point in points:
+        # 从已分开的线下面开始判断
+        if split_line_y:
+            if point[1] <= split_line_y[-1] + threshold:
+                last_y = point[1]
+                continue
+            if last_y <= split_line_y[-1] + threshold:
+                last_y = point[1]
+                continue
+
+        if i == 0:
+            last_y = point[1]
+            i += 1
+            continue
+
+        current_line = (last_y, point[1])
+        split_flag = 1
+        for col in col_lines:
+            # 只要找到一条col包含就不是分割线
+            if current_line[0] >= col[1] - 3 and current_line[1] <= col[3] + 3:
+                split_flag = 0
+                break
+
+        if split_flag:
+            split_line_y.append(current_line[0] + 5)
+            split_line_y.append(current_line[1] - 5)
+
+        last_y = point[1]
+
+    # 加上收尾分割线
+    points.sort(key=lambda x: (x[1], x[0]))
+    y_min = points[0][1]
+    y_max = points[-1][1]
+    if y_min - threshold < 0:
+        split_line_y.append(0)
+    else:
+        split_line_y.append(y_min - threshold)
+    if y_max + threshold > image_np.shape[0]:
+        split_line_y.append(image_np.shape[0])
+    else:
+        split_line_y.append(y_max + threshold)
+    split_line_y = list(set(split_line_y))
+
+    # 剔除两条相隔太近分割线
+    temp_split_line_y = []
+    split_line_y.sort(key=lambda x: x)
+    last_y = -20
+    for y in split_line_y:
+        if y - last_y >= 20:
+            temp_split_line_y.append(y)
+            last_y = y
+    split_line_y = temp_split_line_y
+
+    # 生成分割线
+    split_line = []
+    for y in split_line_y:
+        split_line.append([(0, y), (image_np.shape[1], y)])
+    split_line.append([(0, 0), (image_np.shape[1], 0)])
+    split_line.append([(0, image_np.shape[0]), (image_np.shape[1], image_np.shape[0])])
+    split_line.sort(key=lambda x: x[0][1])
+    return split_line, split_line_y
+
+
+def get_split_area(split_y, row_line_list, col_line_list, cross_points):
+    # 分割线纵坐标
+    if len(split_y) < 2:
+        return [], [], []
+
+    split_y.sort(key=lambda x: x)
+    # new_split_y = []
+    # for i in range(1, len(split_y), 2):
+    #     new_split_y.append(int((split_y[i] + split_y[i - 1]) / 2))
+
+    area_row_line_list = []
+    area_col_line_list = []
+    area_point_list = []
+
+    for i in range(1, len(split_y)):
+        y = split_y[i]
+        last_y = split_y[i - 1]
+
+        split_row = []
+        for row in row_line_list:
+            if last_y <= row[3] <= y:
+                split_row.append(row)
+
+        split_col = []
+        for col in col_line_list:
+            if last_y <= col[1] <= y or last_y <= col[3] <= y or col[1] < last_y < y < col[3]:
+                split_col.append(col)
+
+        split_point = []
+        for point in cross_points:
+            if last_y <= point[1] <= y:
+                split_point.append(point)
+
+        # 满足条件才能形成表格区域
+        if len(split_row) >= 2 and len(split_col) >= 2 and len(split_point) >= 4:
+            # print('len(split_row), len(split_col), len(split_point)', len(split_row), len(split_col), len(split_point))
+            area_row_line_list.append(split_row)
+            area_col_line_list.append(split_col)
+            area_point_list.append(split_point)
+
+    return area_row_line_list, area_col_line_list, area_point_list
+
+
+def get_standard_lines(row_line_list, col_line_list):
+    new_row_line_list = []
+    for row in row_line_list:
+        w1 = row[0]
+        w2 = row[2]
+        # 横线的两个顶点分别找到最近的竖线
+        min_distance = [10000, 10000]
+        min_dis_w = [None, None]
+        for col in col_line_list:
+            if abs(col[0] - w1) < min_distance[0]:
+                min_distance[0] = abs(col[0] - w1)
+                min_dis_w[0] = col[0]
+            if abs(col[0] - w2) < min_distance[1]:
+                min_distance[1] = abs(col[0] - w2)
+                min_dis_w[1] = col[0]
+        if min_dis_w[0] is not None:
+            row[0] = min_dis_w[0]
+        if min_dis_w[1] is not None:
+            row[2] = min_dis_w[1]
+        new_row_line_list.append(row)
+
+    new_col_line_list = []
+    for col in col_line_list:
+        h1 = col[1]
+        h2 = col[3]
+        # 横线的两个顶点分别找到最近的竖线
+        min_distance = [10000, 10000]
+        min_dis_w = [None, None]
+        for row in row_line_list:
+            if abs(row[1] - h1) < min_distance[0]:
+                min_distance[0] = abs(row[1] - h1)
+                min_dis_w[0] = row[1]
+            if abs(row[1] - h2) < min_distance[1]:
+                min_distance[1] = abs(row[1] - h2)
+                min_dis_w[1] = row[1]
+        if min_dis_w[0] is not None:
+            col[1] = min_dis_w[0]
+        if min_dis_w[1] is not None:
+            col[3] = min_dis_w[1]
+        new_col_line_list.append(col)
+
+
+    # all_line_list = []
+    # # 横线竖线两个维度
+    # for i in range(2):
+    #     axis = i
+    #     cross_points.sort(key=lambda x: (x[axis], x[1-axis]))
+    #     current_axis = cross_points[0][axis]
+    #     points = []
+    #     line_list = []
+    #     for p in cross_points:
+    #         if p[axis] == current_axis:
+    #             points.append(p)
+    #         else:
+    #             if points:
+    #                 line_list.append([points[0][0], points[0][1], points[-1][0], points[-1][1]])
+    #             points = [p]
+    #             current_axis = p[axis]
+    #     if points:
+    #         line_list.append([points[0][0], points[0][1], points[-1][0], points[-1][1]])
+    #     all_line_list.append(line_list)
+    # new_col_line_list, new_row_line_list = all_line_list
+
+    return new_col_line_list, new_row_line_list
+
+
+def add_outline(cross_points, row_line_list, col_line_list):
+    cross_points.sort(key=lambda x: (x[0], x[1]))
+    left_up_p = cross_points[0]
+    right_down_p = cross_points[-1]
+
+    row_line_list.append([left_up_p[0], left_up_p[1], right_down_p[0], left_up_p[1]])
+    row_line_list.append([left_up_p[0], right_down_p[1], right_down_p[0], right_down_p[1]])
+    col_line_list.append([left_up_p[0], left_up_p[1], left_up_p[0], right_down_p[1]])
+    col_line_list.append([right_down_p[0], left_up_p[1], right_down_p[0], right_down_p[1]])
+    return row_line_list, col_line_list
+