Ver Fonte

优化点到线的识别;修复边框线;修复未识别的格子

luojiehua há 3 anos atrás
pai
commit
ce8d25fa3d
4 ficheiros alterados com 213 adições e 31 exclusões
  1. 1 1
      format_convert/convert.py
  2. 145 24
      format_convert/utils.py
  3. 4 5
      otr/table_line.py
  4. 63 1
      result.html

+ 1 - 1
format_convert/convert.py

@@ -2650,7 +2650,7 @@ if __name__ == '__main__':
     if get_platform() == "Windows":
         file_path = "C:/Users/Administrator/Desktop/error6.jpg"
         # file_path = "D:/BIDI_DOC/比地_文档/2022/Test_Interface/has-3.rar"
-        # file_path = "C:/Users/Administrator/Desktop/Test_ODPS/1624875783055.pdf"
+        file_path = "files/无边框1.jpg"
     else:
         file_path = "1.doc"
 

+ 145 - 24
format_convert/utils.py

@@ -603,11 +603,25 @@ class LineTable:
                 list_tables.append(_ta)
         return list_tables,in_objs,list_l_rect
 
-    def recognize_crosspoints(self, list_line):
-        from matplotlib import pyplot as plt
+    def recognize_crosspoints(self, list_line,fixLine=True):
+
         list_crosspoints = []
         # print("lines num",len(list_line))
 
+        def getMaxPoints(list_x,margin=5):
+            clust_x = []
+            for _x in list_x:
+                _find = False
+                for cx in clust_x:
+                    if abs(cx[0]-_x)<margin:
+                        _find = True
+                        cx.append(_x)
+                        break
+                if not _find:
+                    clust_x.append([_x])
+            clust_x.sort(key=lambda x:len(x),reverse=True)
+            return clust_x[0][0],len(clust_x[0])
+
         for _i in range(len(list_line)):
             for _j in range(len(list_line)):
                 line1 = list_line[_i].__dict__.get("bbox")
@@ -616,16 +630,96 @@ class LineTable:
                 if exists:
                     list_crosspoints.append(point)
 
-        # plt.figure()
-        # for _line in list_line:
-        #     x0,y0,x1,y1 = _line.__dict__.get("bbox")
-        #     plt.plot([x0,x1],[y0,y1])
-        # for _line in list_line:
-        #     x0,y0,x1,y1 = _line.bbox
-        #     plt.plot([x0,x1],[y0,y1])
-        # for point in list_crosspoints:
-        #     plt.scatter(point.get("point")[0],point.get("point")[1])
-        # plt.show()
+
+        if fixLine:
+            #聚类
+            cluster_crosspoints = []
+            for _point in list_crosspoints:
+                cluster_crosspoints.append({"lines":_point.get("lines"),"points":[_point]})
+            while 1:
+                _find = False
+                new_cluster_crosspoints = []
+                for l_point in cluster_crosspoints:
+                    _flag = False
+                    for l_n_point in new_cluster_crosspoints:
+                        line1 = l_point.get("lines")
+                        line2 = l_n_point.get("lines")
+                        if len(line1&line2)>0:
+                            _find = True
+                            _flag = True
+                            l_n_point["lines"] = line1.union(line2)
+                            l_n_point["points"].extend(l_point["points"])
+
+                    if not _flag:
+                        new_cluster_crosspoints.append({"lines":l_point.get("lines"),"points":l_point.get("points")})
+                cluster_crosspoints = new_cluster_crosspoints
+                if not _find:
+                    break
+
+            for list_cp in cluster_crosspoints:
+                points = list_cp.get("points")
+                l_lines = []
+                for p in points:
+                    l_lines.extend(p.get("p_lines"))
+                l_lines = list(set(l_lines))
+                l_lines.sort(key=lambda x:x[0])
+
+                min_x,_count = getMaxPoints([l[0] for l in l_lines])
+                if _count<2:
+                    min_x = None
+
+
+                min_y,_count = getMaxPoints([l[1] for l in l_lines])
+                if _count<2:
+                    min_y = None
+
+
+                max_x,_count = getMaxPoints([l[2] for l in l_lines])
+                if _count<2:
+                    max_x = None
+
+
+                max_y,_count = getMaxPoints([l[3] for l in l_lines])
+                if _count<2:
+                    max_y = None
+                if min_x and min_y and max_x and max_y:
+
+                    points.sort(key=lambda x:x["point"][0])
+                    if abs(min_x-points[0]["point"][0])>10:
+                        list_line.append(LTLine(1,(min_x,min_y),(min_x,max_y)))
+
+
+                    if abs(max_x-points[-1]["point"][0])>10:
+                        list_line.append(LTLine(1,(max_x,min_y),(max_x,max_y)))
+
+                    points.sort(key=lambda x:x["point"][1])
+                    if abs(min_y-points[0]["point"][1])>10:
+                        list_line.append(LTLine(1,(min_x,min_y),(max_x,min_y)))
+
+                    if abs(max_y-points[-1]["point"][1])>10:
+                        list_line.append(LTLine(1,(min_x,max_y),(max_x,max_y)))
+
+
+            list_crosspoints = []
+            for _i in range(len(list_line)):
+                for _j in range(len(list_line)):
+                    line1 = list_line[_i].__dict__.get("bbox")
+                    line2 = list_line[_j].__dict__.get("bbox")
+                    exists,point = self.cross_point(line1,line2)
+                    if exists:
+                        list_crosspoints.append(point)
+
+        from matplotlib import pyplot as plt
+        plt.figure()
+        for _line in list_line:
+            x0,y0,x1,y1 = _line.__dict__.get("bbox")
+            plt.plot([x0,x1],[y0,y1])
+        for _line in list_line:
+            x0,y0,x1,y1 = _line.bbox
+            plt.plot([x0,x1],[y0,y1])
+        for point in list_crosspoints:
+            plt.scatter(point.get("point")[0],point.get("point")[1])
+        plt.show()
 
         # print(list_crosspoints)
         # print("points num",len(list_crosspoints))
@@ -703,9 +797,10 @@ class LineTable:
                     _line = lines[1]
                 next_point = None
                 for p1 in  dict_line_points[_line]["points"]:
-                    if p1["buttom"]>=margin and p1["point"][0]>_point["point"][0]:
-                        next_point = p1
-                        break
+                    if p1["point"][0]>_point["point"][0]:
+                        if p1["buttom"]>=margin:
+                            next_point = p1
+                            break
                 if not next_point:
                     continue
                 lines = list(next_point.get("lines"))
@@ -714,14 +809,17 @@ class LineTable:
                     _line = lines[1]
                 final_point = None
                 for p1 in dict_line_points[_line]["points"]:
-                    if p1["left"]>=margin and p1["point"][1]>next_point["point"][1]:
-                        final_point = p1
-                        break
+                    if p1["point"][1]>next_point["point"][1]:
+                        if p1["left"]>=margin:
+                            final_point = p1
+                            break
                 if not final_point:
+                    next_point["buttom"] = 0
                     continue
                 _r = LTRect(1,(_point["point"][0],_point["point"][1],final_point["point"][0],final_point["point"][1]))
                 list_rect.append(_r)
 
+
         return list_rect
 
     def cross_point(self, line1, line2, segment=True, margin=2):
@@ -777,7 +875,7 @@ class LineTable:
         line1_key = "%.2f-%.2f-%.2f-%.2f"%(x1, y1, x2, y2)
         line2_key = "%.2f-%.2f-%.2f-%.2f"%(x3, y3, x4, y4)
         return point_is_exist, {"point": [x, y], "left": left, "right": right,
-                                "top": top, "buttom": buttom, "lines": set([line1_key,line2_key])}
+                                "top": top, "buttom": buttom, "lines": set([line1_key,line2_key]),"p_lines":[line1,line2]}
 
     def unionTable(self, list_table, fixspan=True, margin=2):
         set_x = set()
@@ -874,7 +972,7 @@ class LineTable:
         ta = {"bbox":table_bbox,"table":_table}
         return ta
 
-    def rect2table(self, list_textbox, list_rect, in_objs, margin=0.2, fixspan=True,sourceP_LB=True):
+    def rect2table(self, list_textbox, list_rect, in_objs, margin=0.2, fixspan=True,sourceP_LB=True,fixRect=True):
         _table = []
         set_x = set()
         set_y = set()
@@ -893,10 +991,7 @@ class LineTable:
             if not _find:
                 clusters_rects.append([_rect])
 
-        print("clusters_rects", len(clusters_rects))
-        clusters_rects.sort(key=lambda x:x[0].bbox[3],reverse=sourceP_LB)
-        for l_cr in clusters_rects:
-            l_cr.sort(key=lambda x:x.bbox[0])
+
 
         # cul spans
         for _line in clusters_rects:
@@ -914,6 +1009,32 @@ class LineTable:
         list_x.sort(key=lambda x:x)
         list_y.sort(key=lambda x:x,reverse=sourceP_LB)
 
+        print("clusters_rects", len(clusters_rects))
+        clusters_rects.sort(key=lambda x:x[0].bbox[3],reverse=sourceP_LB)
+        for l_cr in clusters_rects:
+            l_cr.sort(key=lambda x:x.bbox[0])
+            if fixRect:
+                extend_cr = []
+                for cr_i in range(len(l_cr)):
+                    if cr_i==0:
+                        if abs(l_cr[cr_i].bbox[0]-list_x[0])>margin:
+                            extend_cr.append(LTRect(1,[list_x[0],l_cr[cr_i].bbox[1],l_cr[cr_i].bbox[0],l_cr[cr_i].bbox[3]]))
+
+                    if cr_i>=0 and cr_i<len(l_cr)-1:
+                        if abs(l_cr[cr_i].bbox[2]-l_cr[cr_i+1].bbox[0])>margin:
+                            extend_cr.append(LTRect(1,[l_cr[cr_i].bbox[2],l_cr[cr_i].bbox[1],l_cr[cr_i+1].bbox[2],l_cr[cr_i].bbox[3]]))
+
+                    if cr_i==len(l_cr)-1:
+                        if abs(l_cr[cr_i].bbox[2]-list_x[-1])>margin:
+                            print("=====extend",[l_cr[cr_i].bbox[2],l_cr[cr_i].bbox[1],list_x[-1],l_cr[cr_i].bbox[3]])
+                            extend_cr.append(LTRect(1,[l_cr[cr_i].bbox[2],l_cr[cr_i].bbox[1],list_x[-1],l_cr[cr_i].bbox[3]]))
+
+                if extend_cr:
+                    l_cr.extend(extend_cr)
+                l_cr.sort(key=lambda x:x.bbox[0])
+
+
+
         pop_x = []
         for i in range(len(list_x)-1):
             _i = len(list_x)-i-1

+ 4 - 5
otr/table_line.py

@@ -223,6 +223,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
         if x>=bbox[0]-line_width and x<=bbox[2]+line_width and y>=bbox[1]-line_width and y<=bbox[3]+line_width:
             return True,[min(x,bbox[0]),min(y,bbox[1]),max(x,bbox[2]),max(y,bbox[3])]
         return False,None
+    _time = time.time()
     height = len(pred)
     width = len(pred[0])
 
@@ -234,7 +235,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
     _last = False
     _current = False
     while 1:
-        h_index += 1
+        h_index += 2
         if h_index>=height:
             break
         w_index = -1
@@ -269,7 +270,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
     w_index = -1
     _sum = list(np.sum(np.array((pred[...,0]>prob)).astype(int),axis=1))
     while 1:
-        w_index += 1
+        w_index += 2
         if w_index>=width:
             break
         h_index = -1
@@ -302,9 +303,6 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
                 if not _find:
                     clust_horizontal.append({"points":[_point],"bbox":[w_index,h_i,w_index,h_i]})
 
-
-
-
     tmp_vertical = []
     for _dict in clust_vertical:
         _bbox = _dict.get("bbox")
@@ -353,6 +351,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
     for _line in final_horizontal:
         list_line.append(_line)
 
+    logging.info("points2lines cost %.2fs"%(time.time()-_time))
 
     import matplotlib.pyplot as plt
     plt.figure()

+ 63 - 1
result.html

@@ -1 +1,63 @@
-<!DOCTYPE HTML><head><meta charset="UTF-8"></head><body></body>
+<!DOCTYPE HTML><head><meta charset="UTF-8"></head><body><table border="1">
+<tr>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1>ghfgh</td>
+<td colspan=1 rowspan=1></td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1>Ssd</td>
+<td colspan=1 rowspan=1>2sdsd</td>
+<td colspan=1 rowspan=1>0000</td>
+<td colspan=1 rowspan=1>sadasf</td>
+<td colspan=1 rowspan=1>Sad</td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1>Asdsa</td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1>Dasdghf</td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1>khjfgj</td>
+<td colspan=1 rowspan=1></td>
+</tr>
+<tr>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1>asdasd</td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1>erer</td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+<td colspan=1 rowspan=1></td>
+</tr>
+</table>
+</body>