Kaynağa Gözat

优化表格文字\n的问题以及嵌套表格无法准确填入文字的问题

luojiehua 2 yıl önce
ebeveyn
işleme
ecb315f926

+ 5 - 0
format_convert/convert_image.py

@@ -61,11 +61,14 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
     def merge_textbox(textbox_list, in_objs):
         delete_obj = []
         threshold = 5
+        textbox_list.sort(key=lambda x:x.bbox[0])
         for k in range(len(textbox_list)):
             tb1 = textbox_list[k]
             if tb1 not in in_objs and tb1 not in delete_obj:
                 for m in range(k+1, len(textbox_list)):
                     tb2 = textbox_list[m]
+                    if tb2 in in_objs:
+                        continue
                     if abs(tb1.bbox[1]-tb2.bbox[1]) <= threshold \
                             and abs(tb1.bbox[3]-tb2.bbox[3]) <= threshold:
                         if tb1.bbox[0] <= tb2.bbox[0]:
@@ -222,6 +225,8 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
                 list_text_boxes.append(TextBox([bbox[0][0], bbox[0][1],
                                                 bbox[2][0], bbox[2][1]], b_text))
 
+            # for _textbox in list_text_boxes:
+            #     print("==",_textbox.get_text())
             lt = LineTable()
             tables, obj_in_table, _ = lt.recognize_table(list_text_boxes, list_lines, False)
 

+ 14 - 1
format_convert/convert_pdf.py

@@ -28,7 +28,7 @@ from pdfminer.converter import PDFPageAggregator
 from pdfminer.layout import LTTextBoxHorizontal, LAParams, LTFigure, LTImage, LTCurve, LTText, LTChar, LTRect, \
     LTTextBoxVertical, LTLine
 from format_convert.utils import judge_error_code, add_div, get_platform, get_html_p, string_similarity, LineTable, \
-    get_logger, log, memory_decorator
+    get_logger, log, memory_decorator,draw_lines_plt
 import fitz
 from format_convert.wrapt_timeout_decorator import timeout
 
@@ -766,6 +766,14 @@ class PDFConvert:
             self._doc.add_child(self._page)
             page_no += 1
 
+
+    def clean_text(self,_text):
+
+        return re.sub("\s","",_text)
+
+
+
+
     def convert_page(self, page, page_no):
         # pdf page.annots为None,不经过get_layout,直接ocr
         # if page.annots is None:
@@ -842,6 +850,11 @@ class PDFConvert:
                 for _edge in table_finder.get_edges():
                     lt_line_list.append(LTLine(1, (float(_edge["x0"]), float(_edge["y0"])),
                                                   (float(_edge["x1"]), float(_edge["y1"]))))
+
+
+                #draw lines to check
+                # draw_lines_plt([l.bbox for l in lt_line_list])
+
                 list_tables, filter_objs, _ = self.lt.recognize_table(lt_text_list, lt_line_list)
                 self._page.in_table_objs = filter_objs
 

+ 97 - 96
format_convert/utils.py

@@ -514,6 +514,14 @@ def judge_format(path):
     # 猜不到,返回None
     return None
 
+def draw_lines_plt(bboxes):
+    import matplotlib.pyplot as plt
+    plt.figure()
+    for bbox in bboxes:
+        x = [bbox[0],bbox[2]]
+        y = [bbox[1],bbox[3]]
+        plt.plot(x,y)
+    plt.show()
 
 def slash_replace(_str, reverse=False):
     if reverse:
@@ -551,6 +559,19 @@ class LineTable:
             if not _find:
                 break
 
+        #need to sort to deal with the inner tables
+        for clu_cp in cluster_crosspoints:
+            points = clu_cp["points"]
+            list_p = np.array([p["point"] for p in points])
+            max_x = max(list_p[...,0])
+            min_x = min(list_p[...,0])
+            max_y = max(list_p[...,1])
+            min_y = min(list_p[...,1])
+            _area = (max_y-min_y)*(max_x-min_x)
+            clu_cp["area"] = _area
+        cluster_crosspoints.sort(key=lambda x:x["area"])
+
+
         list_l_rect = []
         for table_crosspoint in cluster_crosspoints:
             list_rect = self.crosspoint2rect(table_crosspoint.get("points"))
@@ -1104,7 +1125,7 @@ class LineTable:
                 c_i += 1
 
 
-    def fixRect(self,_table,list_x,list_y,list_textbox,in_objs,sourceP_LB,margin):
+    def fixRect(self,_table,list_x,list_y,sourceP_LB,margin):
         self.fixSpan(_table,list_x,list_y)
         # for line_i in range(len(_table)):
         #     for cell_i in range(len(_table[line_i])):
@@ -1155,34 +1176,68 @@ class LineTable:
             for _tmp in extend_line:
                 _line.insert(_tmp["index"],_tmp["cell"])
 
-            list_textbox.sort(key=lambda x: x.bbox[0])
-            list_textbox.sort(key=lambda x: x.bbox[3], reverse=sourceP_LB)
-            for textbox in list_textbox:
-                if textbox in in_objs:
+    def feedText2table(self,_table,list_textbox,in_objs,sourceP_LB):
+
+        #find the suitable cell of the textbox
+        list_cells = []
+        for table_line in _table:
+            for _cell in table_line:
+                list_cells.append({"cell":_cell,"inbox_textbox_list":[]})
+
+        for textbox in list_textbox:
+            list_iou = []
+            for _d in list_cells:
+                _cell = _d["cell"]
+                _iou = self.getIOU(textbox.bbox,_cell["bbox"])
+                list_iou.append(_iou)
+            max_iou_index = np.argmax(list_iou)
+            max_iou = list_iou[max_iou_index]
+            if max_iou>0.1 and textbox not in in_objs:
+                list_cells[max_iou_index]["inbox_textbox_list"].append(textbox)
+                in_objs.add(textbox)
+
+
+        has_matched_box_list = []
+        for _d in list_cells:
+            _cell = _d["cell"]
+            inbox_textbox_list = _d["inbox_textbox_list"]
+
+            # 分行,根据y重合
+            all_match_box_list = []
+
+            inbox_textbox_list.sort(key=lambda x:x.bbox[1],reverse=sourceP_LB)
+            for i in range(len(inbox_textbox_list)):
+                match_box_list = []
+                box1 = inbox_textbox_list[i]
+                if box1 in has_matched_box_list:
                     continue
-                x0, y0, x1, y1 = textbox.bbox
-                _text = textbox.get_text()
-                _find = False
-                for table_line in _table:
-                    for _cell in table_line:
-                        if self.inbox(textbox.bbox, _cell["bbox"], textbox.get_text()):
-                            _cell["text"] += _text
-                            in_objs.add(textbox)
-                            _find = True
-                            break
-                    if _find:
-                        break
 
-    def rect2table(self, list_textbox, list_rect, in_objs, margin=5, fixspan=False,sourceP_LB=True,fixRect=True):
+                min_y1 = box1.bbox[1] + 1/3 * abs(box1.bbox[3]-box1.bbox[1])
+                max_y1 = box1.bbox[3] - 1/3 * abs(box1.bbox[3]-box1.bbox[1])
+                match_box_list.append([box1.get_text(), box1.bbox[0], box1.bbox[1], box1.bbox[2], box1.bbox[3],min_y1,max_y1])
+                has_matched_box_list.append(box1)
+                for j in range(i+1, len(inbox_textbox_list)):
+                    box2 = inbox_textbox_list[j]
+                    if box2 in has_matched_box_list:
+                        continue
 
-        def getIOU(bbox0,bbox1):
-            width = max(bbox0[2],bbox1[2])-min(bbox0[0],bbox1[0])-(bbox0[2]-bbox0[0]+bbox1[2]-bbox1[0])
-            height = max(bbox0[3],bbox1[3])-min(bbox0[1],bbox1[1])-(bbox0[3]-bbox0[1]+bbox1[3]-bbox1[1])
-            if width<0 and height<0:
-                return abs(width*height/min(abs((bbox0[2]-bbox0[0])*(bbox0[3]-bbox0[1])),abs((bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]))))
-            return 0
+                    # print(min_y1, box2.bbox[1], box2.bbox[3], max_y1)
+                    # print(min_y2, box1.bbox[3], max_y2)
+                    if min_y1 <= box2.bbox[1] <= max_y1 or \
+                            min_y1 <= box2.bbox[3] <= max_y1 or \
+                            box2.bbox[1] <= min_y1 <= max_y1 <= box2.bbox[3]:
+                        match_box_list.append([box2.get_text(), box2.bbox[0], box2.bbox[1], box2.bbox[2], box2.bbox[3],min_y1,max_y1])
+                        has_matched_box_list.append(box2)
+                match_box_list.sort(key=lambda x: x[1])
+                all_match_box_list.append(match_box_list)
 
+            # print("match_box_list", all_match_box_list)
+            all_match_box_list.sort(key=lambda x:(x[0][2]+x[0][4])/2,reverse=sourceP_LB)
+            for box_list in all_match_box_list:
+                for box in box_list:
+                    _cell["text"] += re.sub("\s",'',box[0])
 
+    def makeTableByRect(self,list_rect,margin,sourceP_LB):
         _table = []
         set_x = set()
         set_y = set()
@@ -1287,82 +1342,31 @@ class LineTable:
                 table_line.append(_cell)
             line_i += 1
             _table.append(table_line)
+        return _table,list_x,list_y
 
-        # print("table===========================>")
-        # for _line in _table:
-        #     for _cell in _line:
-        #         print("||%d%d"%(_cell["rowspan"],_cell["columnspan"]),end="\t")
-        #     print()
-        # print("table===========================>")
-
-        list_textbox.sort(key=lambda x:x.bbox[0])
-        list_textbox.sort(key=lambda x:x.bbox[3],reverse=sourceP_LB)
-        # print("list_textbox", list_textbox)
+    def rect2table(self, list_textbox, list_rect, in_objs, margin=5, sourceP_LB=True):
 
+        def getIOU(bbox0,bbox1):
+            width = max(bbox0[2],bbox1[2])-min(bbox0[0],bbox1[0])-(bbox0[2]-bbox0[0]+bbox1[2]-bbox1[0])
+            height = max(bbox0[3],bbox1[3])-min(bbox0[1],bbox1[1])-(bbox0[3]-bbox0[1]+bbox1[3]-bbox1[1])
+            if width<0 and height<0:
+                return abs(width*height/min(abs((bbox0[2]-bbox0[0])*(bbox0[3]-bbox0[1])),abs((bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]))))
+            return 0
 
-        #find the suitable cell of the textbox
-        list_cells = []
-        for table_line in _table:
-            for _cell in table_line:
-                list_cells.append({"cell":_cell,"inbox_textbox_list":[]})
 
-        for textbox in list_textbox:
-            list_iou = []
-            for _d in list_cells:
-                _cell = _d["cell"]
-                _iou = self.getIOU(textbox.bbox,_cell["bbox"])
-                list_iou.append(_iou)
-            max_iou_index = np.argmax(list_iou)
-            max_iou = list_iou[max_iou_index]
-            if max_iou>0.1:
-                list_cells[max_iou_index]["inbox_textbox_list"].append(textbox)
-                in_objs.add(textbox)
+        _table,list_x,list_y = self.makeTableByRect(list_rect,margin,sourceP_LB)
 
 
-        has_matched_box_list = []
-        for _d in list_cells:
-            _cell = _d["cell"]
-            inbox_textbox_list = _d["inbox_textbox_list"]
+        self.feedText2table(_table,list_textbox,in_objs,sourceP_LB)
+        self.fixRect(_table,list_x,list_y,sourceP_LB,margin)
+        self.feedText2table(_table,list_textbox,in_objs,sourceP_LB)
 
-            # 分行,根据y重合
-            all_match_box_list = []
-            if sourceP_LB:
-                inbox_textbox_list.sort(key=lambda x:x.bbox[1],reverse=True)
-            else:
-                inbox_textbox_list.sort(key=lambda x:x.bbox[1])
-            for i in range(len(inbox_textbox_list)):
-                match_box_list = []
-                box1 = inbox_textbox_list[i]
-                if box1 in has_matched_box_list:
-                    continue
-
-                min_y1 = box1.bbox[1] + 1/3 * abs(box1.bbox[3]-box1.bbox[1])
-                max_y1 = box1.bbox[3] - 1/3 * abs(box1.bbox[3]-box1.bbox[1])
-                match_box_list.append([box1.get_text(), box1.bbox[0], box1.bbox[1], box1.bbox[2], box1.bbox[3],min_y1,max_y1])
-                has_matched_box_list.append(box1)
-                for j in range(i+1, len(inbox_textbox_list)):
-                    box2 = inbox_textbox_list[j]
-                    if box2 in has_matched_box_list:
-                        continue
-
-                    # print(min_y1, box2.bbox[1], box2.bbox[3], max_y1)
-                    # print(min_y2, box1.bbox[3], max_y2)
-                    if min_y1 <= box2.bbox[1] <= max_y1 or \
-                            min_y1 <= box2.bbox[3] <= max_y1 or \
-                            box2.bbox[1] <= min_y1 <= max_y1 <= box2.bbox[3]:
-                        match_box_list.append([box2.get_text(), box2.bbox[0], box2.bbox[1], box2.bbox[2], box2.bbox[3],min_y1,max_y1])
-                        has_matched_box_list.append(box2)
-                match_box_list.sort(key=lambda x: x[1])
-                all_match_box_list.append(match_box_list)
-
-            # print("match_box_list", all_match_box_list)
-            if sourceP_LB:
-                all_match_box_list.sort(key=lambda x:(x[0][2]+x[0][4])/2,reverse=True)
-            else:
-                all_match_box_list.sort(key=lambda x:(x[0][2]+x[0][4])/2)
-            for box_list in all_match_box_list:
-                for box in box_list:
-                    _cell["text"] += box[0]
+        # print("table===========================>")
+        # for _line in _table:
+        #     for _cell in _line:
+        #         print("||%d%d"%(_cell["rowspan"],_cell["columnspan"]),end="\t")
+        #     print()
+        # print("table===========================>")
 
         # print("------------")
         # for _line in _table:
@@ -1371,9 +1375,6 @@ class LineTable:
         #         print("\n")
         # print("------------")
 
-        self.fixRect(_table,list_x,list_y,list_textbox,in_objs,sourceP_LB,margin)
-
-
         table_bbox = (_table[0][0].get("bbox")[0],
                       _table[0][0].get("bbox")[1],
                       _table[-1][-1].get("bbox")[2],