Bladeren bron

优化图片表格后处理

fangjiasheng 1 jaar geleden
bovenliggende
commit
7d86b18c35
4 gewijzigde bestanden met toevoegingen van 72 en 12 verwijderingen
  1. 4 4
      format_convert/convert_test.py
  2. 62 4
      format_convert/utils.py
  3. 1 0
      ocr/ocr_interface.py
  4. 5 4
      otr/table_line_new.py

+ 4 - 4
format_convert/convert_test.py

@@ -36,11 +36,11 @@ def test_one(p, page_no_range=None, from_remote=False, timeout=300, save_middle=
     data = {"file": file_base64, "type": p.split(".")[-1], "filemd5": _md5, 'page_no': page_no_range,
             'timeout': timeout, 'save_middle': save_middle}
     if from_remote:
-        _url = 'http://121.46.18.113:15010/convert'
+        # _url = 'http://121.46.18.113:15010/convert'
         # _url = 'http://192.168.2.103:15010/convert'
         # _url = 'http://192.168.2.102:15011/convert'
         # _url = 'http://172.16.160.65:15010/convert'
-        # _url = 'http://127.0.0.1:15010/convert'
+        _url = 'http://127.0.0.1:15010/convert'
         result = json.loads(request_post(_url, data, time_out=timeout+20))
         text_str = ""
         for t in result.get("result_html"):
@@ -119,9 +119,9 @@ if __name__ == '__main__':
         # file_path = "C:/Users/Administrator/Desktop/test_doc/error5.doc"
         # file_path = "D:/BIDI_DOC/比地_文档/1677829036789.pdf"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_ODPS/1624325845476.pdf"
-        # file_path = "C:/Users/Administrator/Downloads/W020230512399773694376.jpg"
+        file_path = "C:/Users/Administrator/Downloads/d871aa30916ab23c7d91d34ebd40002a.jpg"
         # file_path = "C:/Users/Administrator/Desktop/test_doc/error14.docx"
-        file_path = "C:/Users/Administrator/Desktop/test_image/error9-1.png"
+        # file_path = "C:/Users/Administrator/Desktop/test_image/error9-1.png"
         # file_path = "C:/Users/Administrator/Desktop/test_b_table/error1.png"
         # file_path = "C:/Users/Administrator/Desktop/test_pdf/直接读表格线error/error62.pdf"
         # file_path = "C:/save_b_table/0-0895e32470613dd7be1139eefd1342c4.png"

+ 62 - 4
format_convert/utils.py

@@ -5,6 +5,7 @@ import hashlib
 import inspect
 import json
 import os
+import pickle
 import socket
 import subprocess
 import sys
@@ -337,12 +338,14 @@ def slash_replace(_str, reverse=False):
 
 
 class LineTable:
-    def recognize_table(self, list_textbox, list_line, sourceP_LB=True, splited=False, from_pdf=False):
+    def recognize_table(self, list_textbox, list_line, sourceP_LB=True,
+                        splited=False, from_pdf=False, show=0):
         self.list_line = list_line
         self.list_crosspoints = self.recognize_crosspoints(list_line)
         self.from_pdf = from_pdf
         self.splited = splited
         self.connect_bbox_list = []
+        self.show = show
 
         # 聚类
         cluster_crosspoints = []
@@ -392,8 +395,24 @@ class LineTable:
                 return [], [], [], self.connect_bbox_list
             if _ta:
                 list_tables.append(_ta)
-        # 展示表格及文字
-        # self._plot(list_line, list_textbox)
+
+        if self.show:
+            # 展示原始表格及文字
+            self._plot(list_line, list_textbox, title='list_line,list_textbox')
+
+            # 打印单元格
+            for list_rect in list_l_rect:
+                for rect in list_rect:
+                    print('rect', rect)
+                self._plot([], [], list_rect, title='list_l_rect')
+
+            # 打印最终表格
+            for table in list_tables:
+                table = table.get('table')
+                for row in table:
+                    print('------ row ------')
+                    for col in row:
+                        print('col', col)
         return list_tables, in_objs, list_l_rect, []
 
     # def recognize_table_by_rect(self, list_textbox, list_rect, margin=2):
@@ -891,6 +910,8 @@ class LineTable:
         return list_location
 
     def fixSpan(self, _table, list_x, list_y, sourceP_LB):
+        # with open('table.pickle', 'wb') as f:
+        #     pickle.dump(_table, f)
 
         def checkPosition(_line, _position, bbox, margin=5):
             # check y
@@ -1021,6 +1042,8 @@ class LineTable:
 
             for _tmp in extend_line:
                 _line.insert(_tmp["index"], _tmp["cell"])
+        # 排序
+        _table.sort(key=lambda x: (x[0].get('bbox')[1], x[0].get('bbox')[3]))
 
     def feedText2table(self, _table, list_textbox, in_objs, sourceP_LB):
 
@@ -1096,6 +1119,10 @@ class LineTable:
                 for box in box_list:
                     _cell["text"] += re.sub("\s", '', box[0])
 
+        # 打印所有cell
+        # for _cell in list_cells:
+        #     print("cell", _cell)
+
     def makeTableByRect(self, list_rect, margin, sourceP_LB):
         _table = []
         set_x = set()
@@ -1216,6 +1243,15 @@ class LineTable:
 
         _table, list_x, list_y = self.makeTableByRect(list_rect, margin, sourceP_LB)
 
+        if self.show:
+            # 打印_table
+            temp_list = []
+            for t in _table:
+                for c in t:
+                    print(c)
+                    temp_list.append(c)
+            self._plot([], [], temp_list, title='makeTableByRect table')
+
         if _table is None:
             return
 
@@ -1235,8 +1271,18 @@ class LineTable:
         #     print("\n")
         # print("------------")
 
+        _table.sort(key=lambda x: (x[0].get('bbox')[1], x[0].get('bbox')[3]))
         self.fixRect(_table, list_x, list_y, sourceP_LB, margin)
 
+        if self.show:
+            # 打印_table
+            temp_list = []
+            for t in _table:
+                for c in t:
+                    print(c)
+                    temp_list.append(c)
+            self._plot([], [], temp_list, title='fixRect table')
+
         # print("table===========================>")
         # for _line in _table:
         #     for _cell in _line:
@@ -1320,7 +1366,7 @@ class LineTable:
                 _count += 1
         return _count - 1
 
-    def _plot(self, list_line, list_textbox):
+    def _plot(self, list_line, list_textbox, list_rect=[], title=''):
         from matplotlib import pyplot as plt
         plt.figure()
         for _line in list_line:
@@ -1334,6 +1380,18 @@ class LineTable:
         for textbox in list_textbox:
             x0, y0, x1, y1 = textbox.bbox
             plt.plot([x0, x1], [y0, y1])
+
+        for rect in list_rect:
+            try:
+                x0, y0, x1, y1 = rect.bbox
+            except:
+                x0, y0, x1, y1 = rect.get("bbox")
+            plt.plot([x0, x0], [y0, y1])
+            plt.plot([x0, x1], [y0, y0])
+            plt.plot([x1, x1], [y0, y1])
+            plt.plot([x0, x1], [y1, y1])
+
+        plt.title(str(title))
         plt.show()
 
 

+ 1 - 0
ocr/ocr_interface.py

@@ -75,6 +75,7 @@ def picture2text(img_data, ocr_model, only_rec=0):
     try:
         # 二进制数据流转np.ndarray [np.uint8: 8位像素]
         img = bytes2np(img_data)
+        # cv2.imwrite('ocr.jpg', img)
 
         # 预测
         if only_rec:

+ 5 - 4
otr/table_line_new.py

@@ -861,7 +861,7 @@ def fix_inner(row_line_list, col_line_list, point_list):
                                 line2[1 - axis + 2]:
                             if line1[axis] <= line2[axis] <= line1[axis + 2]:
                                 continue
-                            min_distance[i] = abs(line1[axis] - line2[axis])
+                            min_distance[i] = abs(point[axis] - line2[axis])
                             min_assist_line[i] = line2
 
             if len(min_assist_line[0]) == 0 and len(min_assist_line[1]) == 0:
@@ -878,12 +878,13 @@ def fix_inner(row_line_list, col_line_list, point_list):
                             min_distance[i] = abs(point[axis] - min_assist_line[i][axis])
                             min_col_point[i] = point
 
-            # print("min_col_point", min_col_point)
-            # print("min_assist_line", min_assist_line)
-
             if len(min_col_point[0]) == 0 and len(min_col_point[1]) == 0:
                 continue
 
+            # print('line1', line1)
+            # print("min_col_point", min_col_point)
+            # print("min_assist_line", min_assist_line)
+
             # 顶点到交点的距离(多出来的线)需大于assist_line到交点的距离(bbox的边)的1/3
             # print("line1_point", line1_point)
             if min_assist_line[0] and min_assist_line[0] == min_assist_line[1]: