fangjiasheng 5 месяцев назад
Родитель
Сommit
dd137a4c5b
3 измененных файлов с 197 добавлено и 8 удалено
  1. 3 0
      format_convert/convert_pdf.py
  2. 39 3
      ocr/ocr_interface.py
  3. 155 5
      otr/table_line_pdf.py

+ 3 - 0
format_convert/convert_pdf.py

@@ -117,6 +117,9 @@ class PDFConvert:
         # 是否提取特殊页
         self.convert_specific_page = 1
 
+        # 初始化_page
+        self._page = _Page(None, 0)
+
     @memory_decorator
     def init_package(self, package_name):
         # 各个包初始化

+ 39 - 3
ocr/ocr_interface.py

@@ -1,9 +1,13 @@
 import base64
+import io
 import json
 import multiprocessing as mp
 import socket
 import sys
 import os
+
+from PIL import Image
+
 sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
 import time
 import traceback
@@ -142,16 +146,25 @@ def test_ocr_model(from_remote=True):
     file_path = "C:/Users/Administrator/Downloads/dbf46fe38862ac03209f1b2c12b1adc1.jpg"
     with open(file_path, "rb") as f:
         file_bytes = f.read()
+
+    src = """
+        
+    """
+    image_data = src.split('data:image/png;base64,')[1]
+    # 解码 base64 字符串
+    file_bytes = base64.b64decode(image_data)
+
     file_base64 = base64.b64encode(file_bytes)
     _md5 = get_md5_from_bytes(file_bytes)[0]
-    only_rec = False
+    only_rec = 0
     _global._init()
     _global.update({"port": 15010, "md5": _md5})
 
     if from_remote:
         file_json = {"data": file_base64, "md5": _md5, 'only_rec': only_rec}
         # _url = "http://192.168.2.102:17000/ocr"
-        _url = "http://127.0.0.1:17000/ocr"
+        # _url = "http://127.0.0.1:17000/ocr"
+        _url = "http://120.132.118.205:17000/ocr"
         print(json.loads(request_post(_url, file_json)))
     else:
         ocr_model = OcrModels().get_model()
@@ -163,4 +176,27 @@ def test_ocr_model(from_remote=True):
 
 
 if __name__ == '__main__':
-    test_ocr_model(False)
+    test_ocr_model()
+
+#     src = """
+# 
+# """
+#
+#     image_data = src.split('data:image/png;base64,')[1]
+#
+#     # 解码 base64 字符串
+#     image_bytes = base64.b64decode(image_data)
+#
+#     # 将字节转换为图像
+#     # image = Image.open(io.BytesIO(image_bytes))
+#
+#     # image.show('img')
+#
+#     # with open(r'C:\Users\Administrator\Desktop\test_image\error16.jpg', 'rb') as f:
+#     #     image_bytes = f.read()
+#
+#     image = bytes2np(image_bytes)
+#
+#     cv2.imshow('img', image)
+#     cv2.imwrite('./1.png', image)
+#     cv2.waitKey(0)

+ 155 - 5
otr/table_line_pdf.py

@@ -26,6 +26,9 @@ def _plot(_line_list, title, mode=1, show=1):
             x0, y0, x1, y1 = _line.__dict__.get("bbox")
         elif mode == 2:
             x0, y0, x1, y1 = _line
+        if max(x0, y0, x1, y1) >= 10000:
+            print('not show line', _line)
+            continue
         plt.plot([x0, x1], [y0, y1])
     plt.title(title)
     plt.show()
@@ -268,6 +271,121 @@ def merge_line(_line_list, threshold=2):
     return new_line_list
 
 
+def merge_extend_line(_line_list, threshold=2):
+    """
+    暂时不用
+    :param _line_list:
+    :param threshold:
+    :return:
+    """
+    start_time = time.time()
+
+    new_line_list = []
+    row_line_list = []
+    col_line_list = []
+
+    _line_list.sort(key=lambda x: (x[1], x[0]))
+    for line in _line_list:
+        if abs(line[0] - line[2]) < abs(line[1] - line[3]):
+            continue
+        row_line_list.append(line)
+    _line_list.sort(key=lambda x: (x[0], x[1]))
+    for line in _line_list:
+        if abs(line[0] - line[2]) > abs(line[1] - line[3]):
+            continue
+        col_line_list.append(line)
+
+    # 合并竖线
+    cols = []
+    col = []
+    current_w = None
+    for line in _line_list:
+        if abs(line[0] - line[2]) > abs(line[1] - line[3]):
+            continue
+        if not col:
+            col.append(line)
+            current_w = line[0]
+
+        _iou = line_iou([[0, line[1]], [0, line[3]]], [[0, col[0][1]], [0, col[0][3]]], axis=1)
+        if min(line[0], line[2]) - threshold <= current_w <= max(line[0], line[2]) + threshold \
+                and is_cross(line[0:2], line[2:4], col[-1][0:2], col[-1][2:4]):
+            col.append(line)
+        elif min(line[0], line[2]) - 2*threshold <= current_w <= max(line[0], line[2]) + 2*threshold \
+                and _iou >= 0.1:
+            col.append(line)
+        else:
+            if col:
+                cols.append(col)
+            col = [line]
+            current_w = line[0]
+    if col:
+        cols.append(col)
+
+    for col in cols:
+        temp_c = col[0]
+        col_w = col[0][0]
+        for i in range(len(col) - 1):
+            c = col[i]
+            next_c = col[i + 1]
+            if is_cross(c[0:2], c[2:4], next_c[0:2], next_c[2:4]) \
+                    or line_iou([[0, c[1]], [0, c[3]]], [[0, next_c[1]], [0, next_c[3]]], axis=1) >= 0.1:
+                temp_c = [col_w, min(temp_c[1], c[1], c[3], next_c[1], next_c[3]), col_w,
+                          max(temp_c[3], c[1], c[3], next_c[1], next_c[3])]
+            else:
+                new_line_list.append(temp_c)
+                temp_c = next_c
+        if not new_line_list or (new_line_list and new_line_list[-1] != temp_c):
+            new_line_list.append(temp_c)
+
+    # 横线 分行
+    _line_list.sort(key=lambda x: (x[1], x[0]))
+    rows = []
+    row = []
+    current_h = None
+    for line in _line_list:
+        if abs(line[0] - line[2]) < abs(line[1] - line[3]):
+            continue
+
+        if not row:
+            row = [line]
+            current_h = line[1]
+
+        if min(line[1], line[3]) - threshold <= current_h <= max(line[1], line[3]) + threshold:
+            row.append(line)
+        else:
+            if row:
+                rows.append(row)
+            row = [line]
+            current_h = line[1]
+    if row:
+        rows.append(row)
+
+    # 横线 合并
+    for row in rows:
+        temp_r = row[0]
+        row_h = row[0][1]
+        for i in range(len(row) - 1):
+            r = row[i]
+            next_r = row[i + 1]
+            # if is_cross(r[0:2], r[2:4], next_r[0:2], next_r[2:4]):
+            if line_iou([r[0:2], r[2:4]], [next_r[0:2], next_r[2:4]], axis=0) >= 0.1:
+                temp_r = [min(temp_r[0], r[0], r[2], next_r[0], next_r[2]), row_h,
+                          max(temp_r[2], r[0], r[2], next_r[0], next_r[2]), row_h]
+            else:
+                new_line_list.append(temp_r)
+                # 合并横线后,与原来横线相交的竖线的坐标需更新
+                for index, col in enumerate(col_line_list):
+                    if temp_r[0] <= col[0] <= temp_r[2] and col[1] <= temp_r[1] <= col[3]:
+                        col_line_list[index] = [col[0], temp_r[1], col[2], temp_r[3]]
+                        new_line_list.append(col_line_list[index])
+                temp_r = next_r
+        if not new_line_list or (new_line_list and new_line_list[-1] != temp_r):
+            new_line_list.append(temp_r)
+
+    log('merge_line1 cost: ' + str(time.time()-start_time))
+    return new_line_list
+
+
 def remove_outline_no_cross(_line_list):
     start_time = time.time()
     row_list = []
@@ -417,7 +535,8 @@ def table_line_pdf(line_obj_list, layout, page_no, show=0):
                 new_y1 = 10 if lt_rect.y1 < 0 else lt_rect.y1
                 lt_rect.set_bbox((lt_rect.x0, new_y0, lt_rect.x1, new_y1))
 
-    _plot([x.bbox for x in lt_rect_list + lt_line_list], 'get_page_lines start', mode=2, show=show)
+    _plot([x.bbox for x in lt_rect_list], 'get_page_lines start lt_rect_list', mode=2, show=show)
+    _plot([x.bbox for x in lt_line_list], 'get_page_lines start lt_line_list', mode=2, show=show)
 
     # 合并矩形框
     # for i in range(len(non_line_rect_list)):
@@ -505,7 +624,7 @@ def table_line_pdf(line_obj_list, layout, page_no, show=0):
     # 从线对象提取线
     for lt_line in lt_line_list+lt_curve_list:
         _b = lt_line.bbox
-        if lt_line.height > 10 or lt_line.width > 10:
+        if lt_line.height > 5 or lt_line.width > 5:
             if lt_line.height >= lt_line.width:
                 line_list += [[_b[0], _b[1], _b[0], _b[3]], [_b[2], _b[1], _b[2], _b[3]]]
             else:
@@ -588,12 +707,17 @@ def table_line_pdf(line_obj_list, layout, page_no, show=0):
     line_list = list(set(line_list))
     line_list = [eval(x) for x in line_list]
 
+    if show:
+        print('line_list len(line_list)', len(line_list))
+    _plot(line_list, 'line_list', mode=2, show=show)
+
     # 合并线
     line_list = merge_line(line_list)
 
     if show:
-        print('get_page_lines len(line_list)', len(line_list))
-    _plot(line_list, 'line_list+bias_line_list', mode=2, show=show)
+        print('merge_line len(line_list)', len(line_list))
+        print(line_list)
+    _plot(line_list, 'merge_line', mode=2, show=show)
 
     # 根据是否有交点判断表格线
     cross_line_list = get_cross_line(line_list, threshold=2, cross_times=1)
@@ -640,4 +764,30 @@ def table_line_pdf(line_obj_list, layout, page_no, show=0):
         lt_line_list.append(LTLine(1, (float(line[0]), float(line[1])),
                                    (float(line[2]), float(line[3]))))
     log("pdf page %s has %s lines cost: %s" % (str(page_no), str(len(lt_line_list)), str(time.time()-start_time)))
-    return lt_line_list
+    return lt_line_list
+
+
+def two_line_cross(x1, y1, x2, y2):
+    """
+    暂时不用
+
+    :param x1:
+    :param y1:
+    :param x2:
+    :param y2:
+    :return:
+    """
+    def cross_product(p1, p2, p3):
+        return (p2[0] - p1[0]) * (p3[1] - p1[1]) - (p2[1] - p1[1]) * (p3[0] - p1[0])
+
+    # 计算叉乘
+    cross1 = cross_product(x1, y1, x2)
+    cross2 = cross_product(x1, y1, y2)
+    cross3 = cross_product(x2, y2, x1)
+    cross4 = cross_product(x2, y2, y1)
+
+    # 判断叉乘的符号
+    if (cross1 * cross2 < 0) and (cross3 * cross4 < 0):
+        return True
+    else:
+        return False