luojiehua 3 года назад
Родитель
Коммит
57957f2350
1 изменённый файл: 143 добавлено и 16 удалено
  1. 143 16
      format_convert/utils.py

+ 143 - 16
format_convert/utils.py

@@ -603,11 +603,25 @@ class LineTable:
                 list_tables.append(_ta)
         return list_tables,in_objs,list_l_rect
 
-    def recognize_crosspoints(self, list_line):
-        from matplotlib import pyplot as plt
+    def recognize_crosspoints(self, list_line,fixLine=True):
+
         list_crosspoints = []
         # print("lines num",len(list_line))
 
+        def getMaxPoints(list_x,margin=5):  # greedy 1-D clustering; returns (first value of the largest cluster, size of that cluster)
+            clust_x = []  # list of clusters; each cluster is a list whose first element acts as its representative
+            for _x in list_x:
+                _find = False
+                for cx in clust_x:
+                    if abs(cx[0]-_x)<margin:  # distance is measured to the cluster's first member only, not to a mean
+                        _find = True
+                        cx.append(_x)
+                        break  # a value joins only the first cluster it matches
+                if not _find:
+                    clust_x.append([_x])  # nothing within margin: start a new cluster
+            clust_x.sort(key=lambda x:len(x),reverse=True)  # most populated cluster first
+            return clust_x[0][0],len(clust_x[0])  # NOTE(review): clust_x[0] raises IndexError when list_x is empty - confirm callers never pass []
+
         for _i in range(len(list_line)):
             for _j in range(len(list_line)):
                 line1 = list_line[_i].__dict__.get("bbox")
@@ -616,6 +630,85 @@ class LineTable:
                 if exists:
                     list_crosspoints.append(point)
 
+
+        if fixLine:
+            # Clustering: repeatedly merge crosspoint groups whose "lines" sets intersect
+            cluster_crosspoints = []
+            for _point in list_crosspoints:
+                cluster_crosspoints.append({"lines":_point.get("lines"),"points":[_point]})
+            while 1:
+                _find = False
+                new_cluster_crosspoints = []
+                for l_point in cluster_crosspoints:
+                    _flag = False
+                    for l_n_point in new_cluster_crosspoints:
+                        line1 = l_point.get("lines")
+                        line2 = l_n_point.get("lines")
+                        if len(line1&line2)>0:
+                            _find = True
+                            _flag = True
+                            l_n_point["lines"] = line1.union(line2)
+                            l_n_point["points"].extend(l_point["points"])
+
+                    if not _flag:
+                        new_cluster_crosspoints.append({"lines":l_point.get("lines"),"points":l_point.get("points")})
+                cluster_crosspoints = new_cluster_crosspoints
+                if not _find:
+                    break
+
+            for list_cp in cluster_crosspoints:
+                points = list_cp.get("points")
+                l_lines = []
+                for p in points:
+                    l_lines.extend(p.get("p_lines"))
+                l_lines = list(set(l_lines))
+                l_lines.sort(key=lambda x:x[0])
+
+                min_x,_count = getMaxPoints([l[0] for l in l_lines])
+                if _count<2:
+                    min_x = None
+
+
+                min_y,_count = getMaxPoints([l[1] for l in l_lines])
+                if _count<2:
+                    min_y = None
+
+
+                max_x,_count = getMaxPoints([l[2] for l in l_lines])
+                if _count<2:
+                    max_x = None
+
+
+                max_y,_count = getMaxPoints([l[3] for l in l_lines])
+                if _count<2:
+                    max_y = None
+                if min_x and min_y and max_x and max_y:
+
+                    points.sort(key=lambda x:x["point"][0])
+                    if abs(min_x-points[0]["point"][0])>10:
+                        list_line.append(LTLine(1,(min_x,min_y),(min_x,max_y)))
+
+
+                    if abs(max_x-points[-1]["point"][0])>10:
+                        list_line.append(LTLine(1,(max_x,min_y),(max_x,max_y)))
+
+                    points.sort(key=lambda x:x["point"][1])
+                    if abs(min_y-points[0]["point"][1])>10:
+                        list_line.append(LTLine(1,(min_x,min_y),(max_x,min_y)))
+
+                    if abs(max_y-points[-1]["point"][1])>10:
+                        list_line.append(LTLine(1,(min_x,max_y),(max_x,max_y)))
+
+
+            list_crosspoints = []
+            for _i in range(len(list_line)):
+                for _j in range(len(list_line)):
+                    line1 = list_line[_i].__dict__.get("bbox")
+                    line2 = list_line[_j].__dict__.get("bbox")
+                    exists,point = self.cross_point(line1,line2)
+                    if exists:
+                        list_crosspoints.append(point)
+
         # plt.figure()
         # for _line in list_line:
         #     x0,y0,x1,y1 = _line.__dict__.get("bbox")
@@ -626,6 +719,17 @@ class LineTable:
         # for point in list_crosspoints:
         #     plt.scatter(point.get("point")[0],point.get("point")[1])
         # plt.show()
+        from matplotlib import pyplot as plt
+        plt.figure()
+        for _line in list_line:
+            x0,y0,x1,y1 = _line.__dict__.get("bbox")
+            plt.plot([x0,x1],[y0,y1])
+        for _line in list_line:
+            x0,y0,x1,y1 = _line.bbox
+            plt.plot([x0,x1],[y0,y1])
+        for point in list_crosspoints:
+            plt.scatter(point.get("point")[0],point.get("point")[1])
+        plt.show()
 
         # print(list_crosspoints)
         # print("points num",len(list_crosspoints))
@@ -722,6 +826,15 @@ class LineTable:
                 _r = LTRect(1,(_point["point"][0],_point["point"][1],final_point["point"][0],final_point["point"][1]))
                 list_rect.append(_r)
 
+        tmp_rect = []
+        set_bbox = set()
+        for _r in list_rect:
+            _bbox = "%.2f-%.2f-%.2f-%.2f"%_r.bbox
+            if _bbox not in set_bbox:
+                tmp_rect.append(_r)
+                set_bbox.add(_bbox)
+        list_rect = tmp_rect
+
         return list_rect
 
     def cross_point(self, line1, line2, segment=True, margin=2):
@@ -777,7 +890,7 @@ class LineTable:
         line1_key = "%.2f-%.2f-%.2f-%.2f"%(x1, y1, x2, y2)
         line2_key = "%.2f-%.2f-%.2f-%.2f"%(x3, y3, x4, y4)
         return point_is_exist, {"point": [x, y], "left": left, "right": right,
-                                "top": top, "buttom": buttom, "lines": set([line1_key,line2_key])}
+                                "top": top, "buttom": buttom, "lines": set([line1_key,line2_key]),"p_lines":[line1,line2]}
 
     def unionTable(self, list_table, fixspan=True, margin=2):
         set_x = set()
@@ -874,7 +987,14 @@ class LineTable:
         ta = {"bbox":table_bbox,"table":_table}
         return ta
 
-    def rect2table(self, list_textbox, list_rect, in_objs, margin=0.2, fixspan=True,sourceP_LB=True,fixRect=True):
+    def rect2table(self, list_textbox, list_rect, in_objs, margin=0.2, fixspan=False,sourceP_LB=True,fixRect=True):
+
+        def getIOU(bbox0,bbox1):  # overlap ratio of two boxes indexed as [0],[1],[2],[3]; presumably (x0,y0,x1,y1) with x0<=x1, y0<=y1 - TODO confirm
+            width = max(bbox0[2],bbox1[2])-min(bbox0[0],bbox1[0])-(bbox0[2]-bbox0[0]+bbox1[2]-bbox1[0])  # combined x-span minus the sum of both widths; negative when the x-ranges overlap
+            height = max(bbox0[3],bbox1[3])-min(bbox0[1],bbox1[1])-(bbox0[3]-bbox0[1]+bbox1[3]-bbox1[1])  # same computation on the y axis
+            if width<0 and height<0:  # overlap on both axes
+                return abs(width*height/min(abs((bbox0[2]-bbox0[0])*(bbox0[3]-bbox0[1])),abs((bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]))))  # NOTE(review): divides by the SMALLER box's area, not the union area, so this is not the standard IoU despite the name
+            return 0  # no overlap on at least one axis
         _table = []
         set_x = set()
         set_y = set()
@@ -917,23 +1037,30 @@ class LineTable:
             l_cr.sort(key=lambda x:x.bbox[0])
             if fixRect:
 
+                pop_x = []
+                for _i in range(len(l_cr)-1):
+                    cr_i = len(l_cr)-_i-1
+                    if getIOU(l_cr[cr_i].bbox,l_cr[cr_i-1].bbox)>0.5:
+                        x0,y0,x1,y1 = l_cr[cr_i].bbox
+                        x2,y2,x3,y3 = l_cr[cr_i-1].bbox
+                        l_cr[cr_i-1].bbox = [min(x0,x2),min(y0,y2),max(x1,x3),max(y1,y3)]
+                        pop_x.append(cr_i)
+                for _x in pop_x:
+                    l_cr.pop(_x)
+                l_cr.sort(key=lambda x:x.bbox[0])
+
                 extend_cr = []
                 for cr_i in range(len(l_cr)):
                     if cr_i==0:
                         if abs(l_cr[cr_i].bbox[0]-list_x[0])>5:
                             extend_cr.append(LTRect(1,[list_x[0],l_cr[cr_i].bbox[1],l_cr[cr_i].bbox[0],l_cr[cr_i].bbox[3]]))
-                            print("=====extend0",[l_cr[cr_i].bbox[2],l_cr[cr_i].bbox[1],list_x[-1],l_cr[cr_i].bbox[3]])
 
                     if cr_i>=0 and cr_i<len(l_cr)-1:
                         if abs(l_cr[cr_i].bbox[2]-l_cr[cr_i+1].bbox[0])>5:
-                            extend_cr.append(LTRect(1,[l_cr[cr_i].bbox[2],l_cr[cr_i].bbox[1],l_cr[cr_i+1].bbox[2],l_cr[cr_i].bbox[3]]))
-                            print("=====extend1",[l_cr[cr_i].bbox[2],l_cr[cr_i].bbox[1],list_x[-1],l_cr[cr_i].bbox[3]])
-                            print(l_cr[cr_i].bbox)
-                            print(l_cr[cr_i+1].bbox)
+                            extend_cr.append(LTRect(1,[l_cr[cr_i].bbox[2],l_cr[cr_i].bbox[1],l_cr[cr_i+1].bbox[0],l_cr[cr_i].bbox[3]]))
 
                     if cr_i==len(l_cr)-1:
                         if abs(l_cr[cr_i].bbox[2]-list_x[-1])>5:
-                            print("=====extend",[l_cr[cr_i].bbox[2],l_cr[cr_i].bbox[1],list_x[-1],l_cr[cr_i].bbox[3]])
                             extend_cr.append(LTRect(1,[l_cr[cr_i].bbox[2],l_cr[cr_i].bbox[1],list_x[-1],l_cr[cr_i].bbox[3]]))
 
                 if extend_cr:
@@ -1012,12 +1139,12 @@ class LineTable:
                                 print(len(_table),l_i+i)
                                 _table[l_i+i].insert(c_i,_cell)
 
-        # print("=======")
-        # for _line in _table:
-        #     for _cell in _line:
-        #         print("[%s]"%_cell.get("text")[:10].replace("\n",''),end="\t\t")
-        #     print("\n")
-        # print("===========")
+        print("=======")
+        for _line in _table:
+            for _cell in _line:
+                print(_cell,end="\t\t")
+            print("\n")
+        print("===========")
 
         table_bbox = (_table[0][0].get("bbox")[0],
                       _table[0][0].get("bbox")[1],