Explorar el Código

优化无边框表格分行

fangjiasheng hace 1 año
padre
commit
8e58fe398f
Se han modificado 3 ficheros con 236 adiciones y 229 borrados
  1. 3 1
      botr/extract_table.py
  2. 7 2
      botr/rules/table_utils.py
  3. 226 226
      otr/otr_interface.py

+ 3 - 1
botr/extract_table.py

@@ -53,7 +53,9 @@ def b_table_process(list_line, list_text_boxes, list_cell, table_location):
             # 对TextBox进行分行,否则同样一行有些框偏上有些偏下,影响文本顺序
             area_list_text_boxes.sort(key=lambda x: (x.bbox[1], x.bbox[0], x.bbox[3], x.bbox[2]))
             current_y = area_list_text_boxes[0].bbox[1]
-            threshold = 2.
+            current_y2 = area_list_text_boxes[0].bbox[3]
+            # threshold = 2.
+            threshold = max(2., 1/3 * abs(current_y2 - current_y))
             for t_b in area_list_text_boxes:
                 bbox = t_b.bbox
                 if current_y - threshold <= bbox[1] <= current_y + threshold:

+ 7 - 2
botr/rules/table_utils.py

@@ -1,5 +1,10 @@
+import os
+import sys
+
 import numpy as np
 import cv2
+sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
+from format_convert.utils import log
 
 
 def shrink_bbox(img, bbox_list):
@@ -149,9 +154,9 @@ def get_points_by_line(img, row_lines, col_lines):
     row_img = np.zeros_like(img[:, :, 0], dtype=np.uint8)
     col_img = np.zeros_like(img[:, :, 0], dtype=np.uint8)
     for r in row_lines:
-        cv2.line(row_img, [r[0], r[1]], [r[2], r[3]], (255, 255, 255), 1)
+        cv2.line(row_img, (r[0], r[1]), (r[2], r[3]), (255, 255, 255), 1)
     for c in col_lines:
-        cv2.line(col_img, [c[0], c[1]], [c[2], c[3]], (255, 255, 255), 1)
+        cv2.line(col_img, (c[0], c[1]), (c[2], c[3]), (255, 255, 255), 1)
 
     point_img = np.bitwise_and(row_img, col_img)
 

+ 226 - 226
otr/otr_interface.py

@@ -38,10 +38,7 @@ import cv2
 import numpy as np
 from flask import Flask, request
 from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes, get_platform
-from otr.table_line import get_points, get_split_line, get_points_row, \
-    get_points_col, \
-    delete_close_points, fix_outline, get_bbox, get_outline_point, delete_contain_bbox, points_to_line, \
-    fix_inner, merge_line, fix_corner, delete_outline, table_net
+from otr.table_line import table_net
 from otr.table_line_new import table_line
 from format_convert import _global
 
@@ -88,240 +85,243 @@ def otr(data, otr_model, is_from_pdf):
     log('otr is_from_pdf ' + str(is_from_pdf))
     try:
         img_data = base64.b64decode(data)
-        if is_from_pdf:
+        if str(is_from_pdf) == 'True':
+            # log('otr is_from_pdf ' + str(type(is_from_pdf)))
             list_lines = line_detect(img_data, otr_model, prob=0.2)
         else:
+            # log('otr is_from_pdf ' + str(type(is_from_pdf)))
             list_lines = line_detect(img_data, otr_model, prob=0.5)
         return list_lines
     except TimeoutError:
         raise TimeoutError
 
 
-def table_detect2(img_data, otr_model):
-    log("into otr_interface table_detect")
-    start_time = time.time()
-    try:
-        start_time1 = time.time()
-        # 二进制数据流转np.ndarray [np.uint8: 8位像素]
-        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
-        # log("into otr_interface table_detect 1")
-        # cv2.imwrite("111111.jpg", img)
-
-        # 将bgr转为rbg
-        image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-        # log("into otr_interface table_detect 2")
-
-        # 选择与图片最接近分辨率,以防失真
-        # best_h, best_w = get_best_predict_size(img)
-        print("image_np.shape", image_np.shape)
-        best_h, best_w, _ = image_np.shape
-        log("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
-
-        # 调用模型
-        # rows, cols = table_line(image_np, otr_model)
-        rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
-
-        start_time1 = time.time()
-        if not rows or not cols:
-            print("points", 0, "split_lines", 0, "bboxes", 0)
-            return {"points": str([]), "split_lines": str([]),
-                    "bboxes": str([]), "outline_points": str([]),
-                    "lines": str([])}
-
-        # 查看是否正确输出rows,cols
-        # for line in rows+cols:
-        #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
-        #              (255, 0, 0), 2)
-        #     cv2.imshow("rows-cols1", img)
-        #     cv2.waitKey(0)
-
-        # 处理结果
-        # 合并错开线
-        rows = merge_line(rows, axis=0)
-        cols = merge_line(cols, axis=1)
-
-        # 计算交点、分割线
-        points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
-        # log("into otr_interface table_detect 5")
-        if not points:
-            print("points", 0, "split_lines", 0, "bboxes", 0)
-            return {"points": str([]), "split_lines": str([]),
-                    "bboxes": str([]), "outline_points": str([]),
-                    "lines": str([])}
-
-        # 清掉外围的没用的线
-        rows, cols = delete_outline(rows, cols, points)
-
-        split_lines, split_y = get_split_line(points, cols, image_np)
-        # log("into otr_interface table_detect 6")
-
-        # 计算交点所在行列,剔除相近交点
-        row_point_list = get_points_row(points, split_y, 5)
-        col_point_list = get_points_col(points, split_y, 5)
-        # log("into otr_interface table_detect 7")
-        points = delete_close_points(points, row_point_list, col_point_list)
-        # log("into otr_interface table_detect 8")
-
-        # 查看是否正确输出点
-        # for p in points:
-        #     cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
-        #     cv2.imshow("points", img)
-        #     cv2.waitKey(0)
-
-        # 查看是否正确输出rows,cols
-        # for line in rows+cols:
-        #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
-        #              (0, 255, 0), 2)
-        #     cv2.imshow("rows-cols0", img)
-        #     cv2.waitKey(0)
-
-        # 修复边框
-        new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
-                                                               split_y)
-        # print(new_cols, new_rows)
-        if new_rows or new_cols:
-            # 连接至补线的延长线
-            if long_rows:
-                rows = long_rows
-            if long_cols:
-                cols = long_cols
-            # 新的补线
-            if new_rows:
-                rows += new_rows
-            if new_cols:
-                cols += new_cols
-
-            # 修复边框后重新计算交点、分割线
-            points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
-
-            # log("into otr_interface table_detect 10")
-            split_lines, split_y = get_split_line(points, cols, image_np)
-
-            # 计算交点所在行列,剔除相近交点
-            row_point_list = get_points_row(points, split_y, 0)
-            col_point_list = get_points_col(points, split_y, 0)
-            # log("into otr_interface table_detect 11")
-            points = delete_close_points(points, row_point_list, col_point_list)
-            # row_point_list = get_points_row(points, split_y)
-            # col_point_list = get_points_col(points, split_y)
-            # log("into otr_interface table_detect 12")
-
-        # 查看是否正确输出rows,cols
-        # for line in rows+cols:
-        #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
-        #              (255, 0, 0), 2)
-        #     cv2.imshow("rows-cols1", img)
-        #     cv2.waitKey(0)
-
-        # 修复表格4个角
-        rows, cols = fix_corner(rows, cols, split_y)
-        points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
-        # row_point_list = get_points_row(points, split_y, 5)
-        # col_point_list = get_points_col(points, split_y, 5)
-
-        # print("row_point_list", row_point_list)
-        # print("col_point_list", col_point_list)
-
-        # 修复内部缺线
-        points = fix_inner(rows, cols, points, split_y)
-        if not points:
-            print("points", 0, "split_lines", 0, "bboxes", 0)
-            return {"points": str([]), "split_lines": str([]),
-                    "bboxes": str([]), "outline_points": str([]),
-                    "lines": str([])}
-        row_point_list = get_points_row(points, split_y, 5)
-        col_point_list = get_points_col(points, split_y, 5)
-
-        # 查看是否正确输出点
-        # for p in points:
-        #     cv2.circle(img, (p[0], p[1]), 1, (0, 255, 0), 3)
-        #     cv2.imshow("points fix", img)
-        #     cv2.waitKey(0)
-
-        # 查看是否正确输出rows,cols
-        # for line in rows+cols:
-        #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
-        #              (255, 0, 0), 2)
-        #     cv2.imshow("rows-cols2", img)
-        #     cv2.waitKey(0)
-
-        # 根据分行分列重新得到rows、cols,避免线延长导致后续bbox生成失败
-        # rows = points_to_line(row_point_list, axis=0)
-        # cols = points_to_line(col_point_list, axis=1)
-        # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
-        # row_point_list = get_points_row(points, split_y, 0)
-        # col_point_list = get_points_col(points, split_y, 0)
-
-        # 获取bbox 单元格
-        bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)
-        # log("into otr_interface table_detect 13")
-
-        # 删除包含bbox
-        if bboxes:
-            bboxes = delete_contain_bbox(bboxes)
-
-        # 查看是否能输出正确框
-        # for box in bboxes:
-        #     cv2.rectangle(img, box[0], box[1], (0, 0, 255), 3)
-        #     cv2.imshow("bbox", img)
-        #     cv2.waitKey(0)
-
-        # 补充连续框
-        # if bboxes:
-        #     bboxes = add_continue_bbox(bboxes)
-        #
-        #     # 删除包含bbox
-        #     bboxes = delete_contain_bbox(bboxes)
-
-        # 查看是否能输出正确框
-        # cv2.namedWindow('bbox', 0)
-        # for box in bboxes:
-        #     cv2.rectangle(img, box[0], box[1], (0, 255, 0), 3)
-        #     cv2.imshow("bbox", img)
-        #     cv2.waitKey(0)
-
-        # 查看是否正确输出点
-        # cv2.namedWindow('points', 0)
-        # for p in points:
-        #     cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
-        #     cv2.imshow("points", img)
-        #     cv2.waitKey(0)
-
-        # 查看是否正确输出区域分割线
-        # cv2.namedWindow('split_lines', 0)
-        # for line in split_lines:
-        #     cv2.line(img, line[0], line[1], (0, 0, 255), 2)
-        #     cv2.imshow("split_lines", img)
-        #     cv2.waitKey(0)
-
-        # 获取每个表格的左上右下两个点
-        outline_points = get_outline_point(points, split_y)
-        # log("into otr_interface table_detect 14")
-
-        if bboxes:
-            print("bboxes number", len(bboxes))
-            # print("bboxes", bboxes)
-        else:
-            print("bboxes number", "None")
-        log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
-        log("otr finish: " + str(round(float(time.time()-start_time1), 4)) + "s")
-        return {"points": str(points), "split_lines": str(split_lines),
-                "bboxes": str(bboxes), "outline_points": str(outline_points),
-                "lines": str(rows+cols)}
-
-    except TimeoutError:
-        raise TimeoutError
-    except Exception as e:
-        log("otr_interface cannot detected table!")
-        print("otr_interface cannot detected table!", traceback.print_exc())
-        print("points", 0, "split_lines", 0, "bboxes", 0)
-        log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
-        return {"points": str([]), "split_lines": str([]), "bboxes": str([]),
-                "outline_points": str([]), "lines": str([])}
+# def table_detect2(img_data, otr_model):
+#     log("into otr_interface table_detect")
+#     start_time = time.time()
+#     try:
+#         start_time1 = time.time()
+#         # 二进制数据流转np.ndarray [np.uint8: 8位像素]
+#         img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
+#         # log("into otr_interface table_detect 1")
+#         # cv2.imwrite("111111.jpg", img)
+#
+#         # 将bgr转为rbg
+#         image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+#         # log("into otr_interface table_detect 2")
+#
+#         # 选择与图片最接近分辨率,以防失真
+#         # best_h, best_w = get_best_predict_size(img)
+#         print("image_np.shape", image_np.shape)
+#         best_h, best_w, _ = image_np.shape
+#         log("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
+#
+#         # 调用模型
+#         # rows, cols = table_line(image_np, otr_model)
+#         rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
+#
+#         start_time1 = time.time()
+#         if not rows or not cols:
+#             print("points", 0, "split_lines", 0, "bboxes", 0)
+#             return {"points": str([]), "split_lines": str([]),
+#                     "bboxes": str([]), "outline_points": str([]),
+#                     "lines": str([])}
+#
+#         # 查看是否正确输出rows,cols
+#         # for line in rows+cols:
+#         #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
+#         #              (255, 0, 0), 2)
+#         #     cv2.imshow("rows-cols1", img)
+#         #     cv2.waitKey(0)
+#
+#         # 处理结果
+#         # 合并错开线
+#         rows = merge_line(rows, axis=0)
+#         cols = merge_line(cols, axis=1)
+#
+#         # 计算交点、分割线
+#         points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
+#         # log("into otr_interface table_detect 5")
+#         if not points:
+#             print("points", 0, "split_lines", 0, "bboxes", 0)
+#             return {"points": str([]), "split_lines": str([]),
+#                     "bboxes": str([]), "outline_points": str([]),
+#                     "lines": str([])}
+#
+#         # 清掉外围的没用的线
+#         rows, cols = delete_outline(rows, cols, points)
+#
+#         split_lines, split_y = get_split_line(points, cols, image_np)
+#         # log("into otr_interface table_detect 6")
+#
+#         # 计算交点所在行列,剔除相近交点
+#         row_point_list = get_points_row(points, split_y, 5)
+#         col_point_list = get_points_col(points, split_y, 5)
+#         # log("into otr_interface table_detect 7")
+#         points = delete_close_points(points, row_point_list, col_point_list)
+#         # log("into otr_interface table_detect 8")
+#
+#         # 查看是否正确输出点
+#         # for p in points:
+#         #     cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
+#         #     cv2.imshow("points", img)
+#         #     cv2.waitKey(0)
+#
+#         # 查看是否正确输出rows,cols
+#         # for line in rows+cols:
+#         #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
+#         #              (0, 255, 0), 2)
+#         #     cv2.imshow("rows-cols0", img)
+#         #     cv2.waitKey(0)
+#
+#         # 修复边框
+#         new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
+#                                                                split_y)
+#         # print(new_cols, new_rows)
+#         if new_rows or new_cols:
+#             # 连接至补线的延长线
+#             if long_rows:
+#                 rows = long_rows
+#             if long_cols:
+#                 cols = long_cols
+#             # 新的补线
+#             if new_rows:
+#                 rows += new_rows
+#             if new_cols:
+#                 cols += new_cols
+#
+#             # 修复边框后重新计算交点、分割线
+#             points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
+#
+#             # log("into otr_interface table_detect 10")
+#             split_lines, split_y = get_split_line(points, cols, image_np)
+#
+#             # 计算交点所在行列,剔除相近交点
+#             row_point_list = get_points_row(points, split_y, 0)
+#             col_point_list = get_points_col(points, split_y, 0)
+#             # log("into otr_interface table_detect 11")
+#             points = delete_close_points(points, row_point_list, col_point_list)
+#             # row_point_list = get_points_row(points, split_y)
+#             # col_point_list = get_points_col(points, split_y)
+#             # log("into otr_interface table_detect 12")
+#
+#         # 查看是否正确输出rows,cols
+#         # for line in rows+cols:
+#         #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
+#         #              (255, 0, 0), 2)
+#         #     cv2.imshow("rows-cols1", img)
+#         #     cv2.waitKey(0)
+#
+#         # 修复表格4个角
+#         rows, cols = fix_corner(rows, cols, split_y)
+#         points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
+#         # row_point_list = get_points_row(points, split_y, 5)
+#         # col_point_list = get_points_col(points, split_y, 5)
+#
+#         # print("row_point_list", row_point_list)
+#         # print("col_point_list", col_point_list)
+#
+#         # 修复内部缺线
+#         points = fix_inner(rows, cols, points, split_y)
+#         if not points:
+#             print("points", 0, "split_lines", 0, "bboxes", 0)
+#             return {"points": str([]), "split_lines": str([]),
+#                     "bboxes": str([]), "outline_points": str([]),
+#                     "lines": str([])}
+#         row_point_list = get_points_row(points, split_y, 5)
+#         col_point_list = get_points_col(points, split_y, 5)
+#
+#         # 查看是否正确输出点
+#         # for p in points:
+#         #     cv2.circle(img, (p[0], p[1]), 1, (0, 255, 0), 3)
+#         #     cv2.imshow("points fix", img)
+#         #     cv2.waitKey(0)
+#
+#         # 查看是否正确输出rows,cols
+#         # for line in rows+cols:
+#         #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
+#         #              (255, 0, 0), 2)
+#         #     cv2.imshow("rows-cols2", img)
+#         #     cv2.waitKey(0)
+#
+#         # 根据分行分列重新得到rows、cols,避免线延长导致后续bbox生成失败
+#         # rows = points_to_line(row_point_list, axis=0)
+#         # cols = points_to_line(col_point_list, axis=1)
+#         # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
+#         # row_point_list = get_points_row(points, split_y, 0)
+#         # col_point_list = get_points_col(points, split_y, 0)
+#
+#         # 获取bbox 单元格
+#         bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)
+#         # log("into otr_interface table_detect 13")
+#
+#         # 删除包含bbox
+#         if bboxes:
+#             bboxes = delete_contain_bbox(bboxes)
+#
+#         # 查看是否能输出正确框
+#         # for box in bboxes:
+#         #     cv2.rectangle(img, box[0], box[1], (0, 0, 255), 3)
+#         #     cv2.imshow("bbox", img)
+#         #     cv2.waitKey(0)
+#
+#         # 补充连续框
+#         # if bboxes:
+#         #     bboxes = add_continue_bbox(bboxes)
+#         #
+#         #     # 删除包含bbox
+#         #     bboxes = delete_contain_bbox(bboxes)
+#
+#         # 查看是否能输出正确框
+#         # cv2.namedWindow('bbox', 0)
+#         # for box in bboxes:
+#         #     cv2.rectangle(img, box[0], box[1], (0, 255, 0), 3)
+#         #     cv2.imshow("bbox", img)
+#         #     cv2.waitKey(0)
+#
+#         # 查看是否正确输出点
+#         # cv2.namedWindow('points', 0)
+#         # for p in points:
+#         #     cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
+#         #     cv2.imshow("points", img)
+#         #     cv2.waitKey(0)
+#
+#         # 查看是否正确输出区域分割线
+#         # cv2.namedWindow('split_lines', 0)
+#         # for line in split_lines:
+#         #     cv2.line(img, line[0], line[1], (0, 0, 255), 2)
+#         #     cv2.imshow("split_lines", img)
+#         #     cv2.waitKey(0)
+#
+#         # 获取每个表格的左上右下两个点
+#         outline_points = get_outline_point(points, split_y)
+#         # log("into otr_interface table_detect 14")
+#
+#         if bboxes:
+#             print("bboxes number", len(bboxes))
+#             # print("bboxes", bboxes)
+#         else:
+#             print("bboxes number", "None")
+#         log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
+#         log("otr finish: " + str(round(float(time.time()-start_time1), 4)) + "s")
+#         return {"points": str(points), "split_lines": str(split_lines),
+#                 "bboxes": str(bboxes), "outline_points": str(outline_points),
+#                 "lines": str(rows+cols)}
+#
+#     except TimeoutError:
+#         raise TimeoutError
+#     except Exception as e:
+#         log("otr_interface cannot detected table!")
+#         print("otr_interface cannot detected table!", traceback.print_exc())
+#         print("points", 0, "split_lines", 0, "bboxes", 0)
+#         log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
+#         return {"points": str([]), "split_lines": str([]), "bboxes": str([]),
+#                 "outline_points": str([]), "lines": str([])}
 
 
 def line_detect(img_data, otr_model, prob=0.2):
     log("into otr_interface table_detect")
+    log('line_detect prob ' + str(prob))
     start_time = time.time()
     try:
         start_time1 = time.time()