Explorar el Código

完整表格修复的逻辑,增加位置校验,解决修复之后一行超出范围的问题。对pdf中图片和文字都存在的情况的逻辑进行修复

luojiehua hace 2 años
padre
commit
7787cd254d
Se han modificado 3 ficheros con 123 adiciones y 86 borrados
  1. 62 61
      format_convert/convert_pdf.py
  2. 60 24
      format_convert/utils.py
  3. 1 1
      otr/table_line.py

+ 62 - 61
format_convert/convert_pdf.py

@@ -772,7 +772,54 @@ class PDFConvert:
         return re.sub("\s","",_text)
 
 
-
+    def get_text_lines(self,page,page_no):
+        lt_line_list = []
+        page_plumber = pdfPage(self.doc_pdfplumber, page, page_number=page_no, initial_doctop=self.doc_top)
+        self.doc_top += page_plumber.height
+
+        table_finder = TableFinder(page_plumber)
+        for _edge in table_finder.get_edges():
+            lt_line_list.append(LTLine(1, (float(_edge["x0"]), float(_edge["y0"])),
+                                       (float(_edge["x1"]), float(_edge["y1"]))))
+        log("pdf page %s has %s lines"%(str(page_no),str(len(lt_line_list))))
+        return lt_line_list
+
+    def recognize_text(self,layout,page_no,lt_text_list,lt_line_list):
+        list_tables, filter_objs, _ = self.lt.recognize_table(lt_text_list, lt_line_list)
+        self._page.in_table_objs = filter_objs
+
+        print("=======text_len:%d:filter_len:%d"%(len(lt_text_list),len(filter_objs)))
+
+        for table in list_tables:
+            _table = _Table(table["table"], table["bbox"])
+            # self._page.children.append(_table)
+            self._page.add_child(_table)
+
+        list_sentences = ParseUtils.recognize_sentences(lt_text_list, filter_objs,
+                                                        layout.bbox, page_no)
+
+        for sentence in list_sentences:
+            _sen = _Sentence(sentence.text, sentence.bbox)
+            self._page.add_child(_sen)
+        # pdf对象需反向排序
+        self._page.is_reverse = True
+
+    def is_text_legal(self,lt_text_list,page_no):
+        # 无法识别pdf字符编码,整页用ocr
+        text_temp = ""
+        for _t in lt_text_list:
+            text_temp += _t.get_text()
+
+        if re.search('[(]cid:[0-9]+[)]', text_temp):
+            log("text has cid! try pymupdf...")
+            page_image = self.get_page_image(page_no)
+            if judge_error_code(page_image):
+                self._page.error_code = page_image
+            else:
+                _image = _Image(page_image[1], page_image[0])
+                self._page.add_child(_image)
+            return False
+        return True
 
     def convert_page(self, page, page_no):
         # pdf page.annots为None,不经过get_layout,直接ocr
@@ -784,6 +831,8 @@ class PDFConvert:
         #     image_count = 1
         # else:
         layout = self.get_layout(page, page_no)
+        self.width = layout.width
+        self.height = layout.height
         if self._doc.error_code is not None:
             return
         if judge_error_code(layout):
@@ -826,51 +875,14 @@ class PDFConvert:
                     self._page.add_child(_image)
                 return
 
-            # 无法识别pdf字符编码,整页用ocr
-            text_temp = ""
-            for _t in lt_text_list:
-                text_temp += _t.get_text()
 
-            if re.search('[(]cid:[0-9]+[)]', text_temp):
-                log("text has cid! try pymupdf...")
-                page_image = self.get_page_image(page_no)
-                if judge_error_code(page_image):
-                    self._page.error_code = page_image
-                else:
-                    _image = _Image(page_image[1], page_image[0])
-                    self._page.add_child(_image)
+            if not self.is_text_legal(lt_text_list,page_no):
                 return
 
             try:
-                lt_line_list = []
-                page_plumber = pdfPage(self.doc_pdfplumber, page, page_number=page_no, initial_doctop=self.doc_top)
-                self.doc_top += page_plumber.height
-
-                table_finder = TableFinder(page_plumber)
-                for _edge in table_finder.get_edges():
-                    lt_line_list.append(LTLine(1, (float(_edge["x0"]), float(_edge["y0"])),
-                                                  (float(_edge["x1"]), float(_edge["y1"]))))
-
-
-                #draw lines to check
-                # draw_lines_plt([l.bbox for l in lt_line_list])
-
-                list_tables, filter_objs, _ = self.lt.recognize_table(lt_text_list, lt_line_list)
-                self._page.in_table_objs = filter_objs
-
-                for table in list_tables:
-                    _table = _Table(table["table"], table["bbox"])
-                    # self._page.children.append(_table)
-                    self._page.add_child(_table)
-
-                list_sentences = ParseUtils.recognize_sentences(lt_text_list, filter_objs,
-                                                                layout.bbox, page_no)
+                lt_line_list = self.get_text_lines(page,page_no)
+                self.recognize_text(layout,page_no,lt_text_list,lt_line_list)
 
-                for sentence in list_sentences:
-                    _sen = _Sentence(sentence.text, sentence.bbox)
-                    self._page.add_child(_sen)
-                # pdf对象需反向排序
-                self._page.is_reverse = True
             except:
                 traceback.print_exc()
                 self._page.error_code = [-8]
@@ -888,25 +900,6 @@ class PDFConvert:
 
         # 正常读取该页对象
         else:
-            # 文本对象
-            for x in lt_text_list:
-                # 获取对象文本
-                object_text = x.get_text()
-
-                # 无法识别pdf字符编码,整页用ocr
-                if re.search('[(]cid:[0-9]+[)]', object_text):
-                    page_image = self.get_page_image(page_no)
-                    if judge_error_code(page_image):
-                        self._page.error_code = page_image
-                    else:
-                        _image = _Image(page_image[1], page_image[0])
-                        self._page.add_child(_image)
-                    return
-                else:
-                    _sen = _Sentence(object_text, x.bbox)
-                    # _sen.x = x.bbox[0]
-                    # _sen.y = x.bbox[1]
-                    self._page.add_child(_sen)
 
             # 图表对象
             for image in lt_image_list:
@@ -918,7 +911,8 @@ class PDFConvert:
                         continue
                     # 查看提取的图片高宽,太大则用pdf输出图进行ocr识别
                     img_test = Image.open(io.BytesIO(image_stream))
-                    if img_test.size[1] > 2000 or img_test.size[0] > 1500:
+                    # img_test.show()
+                    if image.height >= self.height-100 and image.width >= self.width-100:
                         print("pdf2text LTImage stream output size", img_test.size)
                         page_image = self.get_page_image(page_no)
                         if judge_error_code(page_image):
@@ -943,6 +937,13 @@ class PDFConvert:
                     print(traceback.print_exc())
             # pdf对象需反向排序
             self._page.is_reverse = True
+            self.init_package("pdfplumber")
+
+            if not self.is_text_legal(lt_text_list,page_no):
+                return
+
+            lt_line_list = self.get_text_lines(page,page_no)
+            self.recognize_text(layout,page_no,lt_text_list,lt_line_list)
 
     def get_layout(self, page, page_no):
         log("")

+ 60 - 24
format_convert/utils.py

@@ -585,7 +585,7 @@ class LineTable:
             _ta = self.rect2table(list_textbox,l_rect,in_objs,sourceP_LB=sourceP_LB)
             if _ta:
                 list_tables.append(_ta)
-        # self._plot(list_line, list_textbox)
+        self._plot(list_line, list_textbox)
         return list_tables, in_objs, list_l_rect
 
     def recognize_table_by_rect(self, list_textbox, list_rect, margin=2):
@@ -1077,7 +1077,35 @@ class LineTable:
                 list_location.append(_x)
         return list_location
 
-    def fixSpan(self,_table,list_x,list_y):
+    def fixSpan(self,_table,list_x,list_y,sourceP_LB):
+
+        def checkPosition(_line,_position,bbox,margin=5):
+            #check y
+            if len(_line)>0:
+                _bbox = _line[0].get("bbox")
+                print(1,_bbox)
+                print(2,bbox)
+                if abs(min(_bbox[1],_bbox[3])-min(bbox[1],bbox[3]))>margin or abs(max(_bbox[1],_bbox[3])-max(bbox[1],bbox[3]))>margin:
+                    print("check position y false")
+                    return False
+            #check x
+            if _position<=len(_line)-1:
+                after_bbox = _line[_position].get("bbox")
+                # the insert bbox.x1 should not less then the after bbox.x0
+                if not (after_bbox[0]>=bbox[2]):
+                    print("check position x after false")
+                    return False
+            if _position-1>0 and _position-1<len(_line):
+                before_bbox = _line[_position-1].get("bbox")
+                # the insert bbox.x1 should less equal than the first bbox.x0
+                if not (bbox[0]>=before_bbox[2]):
+                    print("check position x before false")
+                    return False
+            return True
+
+
+
+        #拓展columnspan的数据
         for _line in _table:
             c_i = 0
             while c_i<len(_line):
@@ -1094,14 +1122,18 @@ class LineTable:
                         _cell["columnspan"] = 1
 
                         #len(locations)==_colspan+1
+
                         for i in range(1,_cospan):
                             n_cell = {}
                             n_cell.update(_cell)
                             n_cell["bbox"] = (locations[i],y0,locations[i+1],y1)
                             c_i += 1
-                            _line.insert(c_i,n_cell)
+                            #check the position
+                            if checkPosition(_line,c_i,n_cell["bbox"]):
+                                _line.insert(c_i,n_cell)
 
                 c_i += 1
+        #拓展rowspan的数据
         for l_i in range(len(_table)):
             _line = _table[l_i]
             c_i = 0
@@ -1122,13 +1154,13 @@ class LineTable:
                             if l_i+i<=len(_table)-1:
                                 # print(len(_table),l_i+i)
                                 n_cell["bbox"] = (x0,locations[i],x1,locations[i+1])
-                                _table[l_i+i].insert(c_i,n_cell)
-
+                                if checkPosition(_table[l_i+i],c_i,n_cell["bbox"]):
+                                    _table[l_i+i].insert(c_i,n_cell)
                 c_i += 1
 
 
     def fixRect(self,_table,list_x,list_y,sourceP_LB,margin):
-        self.fixSpan(_table,list_x,list_y)
+        self.fixSpan(_table,list_x,list_y,sourceP_LB)
         # for line_i in range(len(_table)):
         #     for cell_i in range(len(_table[line_i])):
         #         _cell = _table[line_i][cell_i]
@@ -1250,6 +1282,7 @@ class LineTable:
             list_rect.sort(key=lambda x:x.bbox[3])
             for _rect in list_rect:
                 _y0 = _rect.bbox[3]
+                _y1 = _rect.bbox[1]
                 _find = False
                 for l_cr in clusters_rects:
                     if abs(l_cr[0].bbox[3]-_y0)<margin:
@@ -1262,6 +1295,7 @@ class LineTable:
             list_rect.sort(key=lambda x:x.bbox[1])
             for _rect in list_rect:
                 _y0 = _rect.bbox[1]
+                _y1 = _rect.bbox[3]
                 _find = False
                 for l_cr in clusters_rects:
                     if abs(l_cr[0].bbox[1]-_y0)<margin:
@@ -1299,9 +1333,8 @@ class LineTable:
 
         # print("clusters_rects", len(clusters_rects))
         if sourceP_LB:
-            clusters_rects.sort(key=lambda x:x[0].bbox[3],reverse=sourceP_LB)
-        else:
-            clusters_rects.sort(key=lambda x:x[0].bbox[1],reverse=sourceP_LB)
+            clusters_rects.sort(key=lambda x:(x[0].bbox[1]+x[0].bbox[3])/2,reverse=sourceP_LB)
+        clusters_rects.sort(key=lambda x:(x[0].bbox[1]+x[0].bbox[3])/2,reverse=sourceP_LB)
 
         for l_cr in clusters_rects:
             l_cr.sort(key=lambda x:x.bbox[0])
@@ -1326,8 +1359,8 @@ class LineTable:
         for _x in pop_x:
             list_y.pop(_x)
 
-        # print(list_x)
-        # print(list_y)
+        print("list_x",list_x)
+        print("list_y",list_y)
         line_i = 0
         for _line in clusters_rects:
             table_line = []
@@ -1361,24 +1394,27 @@ class LineTable:
         if _table is None:
             return
 
-
         self.feedText2table(_table,list_textbox,in_objs,sourceP_LB)
+
+
+        print("table===========================>")
+        for _line in _table:
+            for _cell in _line:
+                print("||%d%d"%(_cell["rowspan"],_cell["columnspan"]),end="\t")
+            print()
+        print("table===========================>")
+
+        print("------------")
+        for _line in _table:
+            for _cell in _line:
+                print(_cell["text"],end="\t")
+            print("\n")
+        print("------------")
+
         self.fixRect(_table,list_x,list_y,sourceP_LB,margin)
         self.feedText2table(_table,list_textbox,in_objs,sourceP_LB)
 
-        # print("table===========================>")
-        # for _line in _table:
-        #     for _cell in _line:
-        #         print("||%d%d"%(_cell["rowspan"],_cell["columnspan"]),end="\t")
-        #     print()
-        # print("table===========================>")
 
-        # print("------------")
-        # for _line in _table:
-        #     for _cell in _line:
-        #         print(_cell["text"])
-        #         print("\n")
-        # print("------------")
 
         table_bbox = (_table[0][0].get("bbox")[0],
                       _table[0][0].get("bbox")[1],

+ 1 - 1
otr/table_line.py

@@ -920,7 +920,7 @@ def table_postprocess(img_new, pred, prob=0.2, is_test=0):
         return [-1]
 
 
-def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
+def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0:
     log("into table_line, prob is " + str(prob))
     sizew, sizeh = size
     img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)