Procházet zdrojové kódy

Merge remote-tracking branch 'origin/master'

fangjiasheng před 3 roky
rodič
revize
4d13aa7c7e
2 změnil soubory, kde provedl 176 přidání a 1 odebrání
  1. 2 1
      format_convert/convert.py
  2. 174 0
      otr/table_line.py

+ 2 - 1
format_convert/convert.py

@@ -2275,7 +2275,7 @@ def getText(_type, path_or_stream):
 
 
 
 
 def to_html(path, text):
 def to_html(path, text):
-    with open(path, 'w') as f:
+    with open(path, 'w',encoding="utf8") as f:
         f.write("<!DOCTYPE HTML>")
         f.write("<!DOCTYPE HTML>")
         f.write('<head><meta charset="UTF-8"></head>')
         f.write('<head><meta charset="UTF-8"></head>')
         f.write("<body>")
         f.write("<body>")
@@ -2673,6 +2673,7 @@ if __name__ == '__main__':
         # file_path = "C:/Users/Administrator/Desktop/Test_ODPS/1624875783055.pdf"
         # file_path = "C:/Users/Administrator/Desktop/Test_ODPS/1624875783055.pdf"
     else:
     else:
         file_path = "1.doc"
         file_path = "1.doc"
+    file_path = "files/1629873875150.png"
 
 
     with open(file_path, "rb") as f:
     with open(file_path, "rb") as f:
         file_bytes = f.read()
         file_bytes = f.read()

+ 174 - 0
otr/table_line.py

@@ -197,6 +197,175 @@ def table_net(input_shape=(1152, 896, 3), num_classes=1):
 model = table_net((None, None, 3), 2)
 model = table_net((None, None, 3), 2)
 
 
 
 
+def drawpixel(pred):
+    import matplotlib.pyplot as plt
+    _array = []
+    for _h in range(len(pred)):
+        _line = []
+        for _w in range(len(pred[_h])):
+            _prob = pred[_h][_w]
+            if _prob[0]>0.5:
+                _line.append((0,255,255))
+            elif _prob[1]>0.5:
+                _line.append((255,255,0))
+            else:
+                _line.append((255,255,255))
+        _array.append(_line)
+    plt.imshow(np.array(_array))
+    plt.show()
+
+def points2lines(pred,sourceP_LB=True,prob=0.2,line_width=7,padding=3,min_len=10,cell_width=13):
+
+    def inBbox(bbox,point,line_width):
+        x,y = point
+        if x>=bbox[0]-line_width and x<=bbox[2]+line_width and y>=bbox[1]-line_width and y<=bbox[3]+line_width:
+            return True,[min(x,bbox[0]),min(y,bbox[1]),max(x,bbox[2]),max(y,bbox[3])]
+        return False,None
+    height = len(pred)
+    width = len(pred[0])
+
+    clust_horizontal = []
+    clust_vertical = []
+    h_index = -1
+    _step = line_width
+    _sum = list(np.sum(np.array((pred[...,1]>prob)).astype(int),axis=1))
+    _last = False
+    _current = False
+    while 1:
+        h_index += 1
+        if h_index>=height:
+            break
+        w_index = -1
+        if sourceP_LB:
+            h_i = height-1-h_index
+        else:
+            h_i = h_index
+        if _sum[h_index]<min_len:
+            continue
+        while 1:
+            w_index += 2
+            if w_index>=width:
+                break
+            _v,_h = pred[h_index][w_index]
+
+            if _h>prob:
+
+                _find  = False
+                _point = (w_index,h_i)
+                for l_h_i in range(len(clust_vertical)):
+                    l_h = clust_vertical[len(clust_vertical)-l_h_i-1]
+                    bbox = l_h.get("bbox")
+                    b_in,_bbox = inBbox(bbox,_point,line_width)
+                    if b_in:
+                        _find = True
+                        l_h.get("points").append(_point)
+                        l_h["bbox"] = _bbox
+                        break
+                if not _find:
+                    clust_vertical.append({"points":[_point],"bbox":[w_index,h_i,w_index,h_i]})
+    w_index = -1
+    _sum = list(np.sum(np.array((pred[...,0]>prob)).astype(int),axis=0))
+    while 1:
+        w_index += 1
+        if w_index>=width:
+            break
+        h_index = -1
+        if _sum[w_index]<min_len:
+            continue
+        while 1:
+            h_index += 2
+            if h_index>=height:
+                break
+
+            if sourceP_LB:
+                h_i = height-1-h_index
+            else:
+                h_i = h_index
+
+            _v,_h = pred[h_index][w_index]
+
+            if _v>prob:
+                _find  = False
+                _point = (w_index,h_i)
+                for l_h_i in range(len(clust_horizontal)):
+                    l_h = clust_horizontal[len(clust_horizontal)-l_h_i-1]
+                    bbox = l_h.get("bbox")
+                    b_in,_bbox = inBbox(bbox,_point,line_width)
+                    if b_in:
+                        _find = True
+                        l_h.get("points").append(_point)
+                        l_h["bbox"] = _bbox
+                        break
+                if not _find:
+                    clust_horizontal.append({"points":[_point],"bbox":[w_index,h_i,w_index,h_i]})
+
+
+
+    tmp_vertical = []
+    for _dict in clust_vertical:
+        _bbox = _dict.get("bbox")
+        if _bbox[2]-_bbox[0]>=min_len or _bbox[3]-_bbox[1]>=min_len:
+            tmp_vertical.append([(_bbox[0]+_bbox[2])/2,_bbox[1]-padding,(_bbox[0]+_bbox[2])/2,_bbox[3]+padding])
+    tmp_horizontal = []
+    for _dict in clust_horizontal:
+        _bbox = _dict.get("bbox")
+        if _bbox[2]-_bbox[0]>=min_len or _bbox[3]-_bbox[1]>=min_len:
+            tmp_horizontal.append([_bbox[0]-padding,(_bbox[1]+_bbox[3])/2,_bbox[2]+padding,(_bbox[1]+_bbox[3])/2])
+
+    #merge lines
+    tmp_vertical.sort(key=lambda x:x[3],reverse=True)
+    tmp_horizontal.sort(key=lambda x:x[0])
+
+    pop_index = []
+    final_vertical = []
+    for _line in tmp_vertical:
+        _find = False
+        x0,y0,x1,y1 = _line
+        for _line2 in final_vertical:
+            x2,y2,x3,y3 = _line2
+            if abs(x0-x2)<line_width and abs(y0-y3)<cell_width or abs(y1-y2)<cell_width:
+                _find = True
+                final_vertical.append([x0,min(y0,y2),x1,max(y1,y3)])
+                break
+        if not _find:
+            final_vertical.append(_line)
+
+    final_horizontal = []
+    for _line in tmp_horizontal:
+        _find = False
+        x0,y0,x1,y1 = _line
+        for _line2 in final_horizontal:
+            x2,y2,x3,y3 = _line2
+            if abs(y0-y2)<line_width and abs(x0-x3)<cell_width or abs(x1-x2)<cell_width:
+                _find = True
+                final_horizontal.append([min(x0,x2),y0,max(x1,x3),y1])
+                break
+        if not _find:
+            final_horizontal.append(_line)
+
+    list_line = []
+    for _line in final_vertical:
+        list_line.append(_line)
+    for _line in final_horizontal:
+        list_line.append(_line)
+
+
+    import matplotlib.pyplot as plt
+    plt.figure()
+    for _line in list_line:
+        x0,y0,x1,y1 = _line
+        plt.plot([x0,x1],[y0,y1])
+    # for _line in list_line:
+    #     x0,y0,x1,y1 = _line.bbox
+    #     plt.plot([x0,x1],[y0,y1])
+    # for point in list_crosspoints:
+    #     plt.scatter(point.get("point")[0],point.get("point")[1])
+    plt.show()
+
+
+
+
+
 def table_line(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=30, alph=15):
 def table_line(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=30, alph=15):
     sizew, sizeh = size
     sizew, sizeh = size
     # [..., ::-1] 最后一维内部反向输出
     # [..., ::-1] 最后一维内部反向输出
@@ -210,6 +379,11 @@ def table_line(img, model, size=(512, 1024), hprob=0.5, vprob=0.5, row=50, col=3
     # logging.info("into table_line 2")
     # logging.info("into table_line 2")
     pred = pred[0]
     pred = pred[0]
 
 
+    drawpixel(pred)
+    _time = time.time()
+    points2lines(pred)
+    logging.info("points2lines takes %ds"%(time.time()-_time))
+
     vpred = pred[..., 1] > vprob   # 横线
     vpred = pred[..., 1] > vprob   # 横线
     hpred = pred[..., 0] > hprob   # 竖线
     hpred = pred[..., 0] > hprob   # 竖线
     vpred = vpred.astype(int)
     vpred = vpred.astype(int)