Ver Fonte

修复点到线的识别

luojiehua há 3 anos atrás
pai
commit
95f38835e0
2 ficheiros alterados com 20 adições e 11 exclusões
  1. 12 10
      otr/table_line.py
  2. 8 1
      result.html

+ 12 - 10
otr/table_line.py

@@ -230,7 +230,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
     clust_vertical = []
     clust_vertical = []
     h_index = -1
     h_index = -1
     _step = line_width
     _step = line_width
-    _sum = list(np.sum(np.array((pred[...,1]>prob)).astype(int),axis=1))
+    _sum = list(np.sum(np.array((pred[...,1]>prob)).astype(int),axis=0))
     _last = False
     _last = False
     _current = False
     _current = False
     while 1:
     while 1:
@@ -242,15 +242,15 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
             h_i = height-1-h_index
             h_i = height-1-h_index
         else:
         else:
             h_i = h_index
             h_i = h_index
-        if _sum[h_index]<min_len:
-            continue
         while 1:
         while 1:
             w_index += 2
             w_index += 2
             if w_index>=width:
             if w_index>=width:
                 break
                 break
-            _v,_h = pred[h_index][w_index]
+            if _sum[w_index]<min_len:
+                continue
+            _h,_v = pred[h_index][w_index]
 
 
-            if _h>prob:
+            if _v>prob:
 
 
                 _find  = False
                 _find  = False
                 _point = (w_index,h_i)
                 _point = (w_index,h_i)
@@ -265,28 +265,29 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
                         break
                         break
                 if not _find:
                 if not _find:
                     clust_vertical.append({"points":[_point],"bbox":[w_index,h_i,w_index,h_i]})
                     clust_vertical.append({"points":[_point],"bbox":[w_index,h_i,w_index,h_i]})
+
     w_index = -1
     w_index = -1
-    _sum = list(np.sum(np.array((pred[...,0]>prob)).astype(int),axis=0))
+    _sum = list(np.sum(np.array((pred[...,0]>prob)).astype(int),axis=1))
     while 1:
     while 1:
         w_index += 1
         w_index += 1
         if w_index>=width:
         if w_index>=width:
             break
             break
         h_index = -1
         h_index = -1
-        if _sum[w_index]<min_len:
-            continue
         while 1:
         while 1:
             h_index += 2
             h_index += 2
             if h_index>=height:
             if h_index>=height:
                 break
                 break
+            if _sum[h_index]<min_len:
+                continue
 
 
             if sourceP_LB:
             if sourceP_LB:
                 h_i = height-1-h_index
                 h_i = height-1-h_index
             else:
             else:
                 h_i = h_index
                 h_i = h_index
 
 
-            _v,_h = pred[h_index][w_index]
+            _h,_v = pred[h_index][w_index]
 
 
-            if _v>prob:
+            if _h>prob:
                 _find  = False
                 _find  = False
                 _point = (w_index,h_i)
                 _point = (w_index,h_i)
                 for l_h_i in range(len(clust_horizontal)):
                 for l_h_i in range(len(clust_horizontal)):
@@ -303,6 +304,7 @@ def points2lines(pred, sourceP_LB=True, prob=0.2, line_width=7, padding=3, min_l
 
 
 
 
 
 
+
     tmp_vertical = []
     tmp_vertical = []
     for _dict in clust_vertical:
     for _dict in clust_vertical:
         _bbox = _dict.get("bbox")
         _bbox = _dict.get("bbox")

+ 8 - 1
result.html

@@ -18,9 +18,16 @@
 </tr>
 </tr>
 </table>
 </table>
 <div>四、主要标的信息</div>
 <div>四、主要标的信息</div>
-<div>货物类</div>
 <table border="1">
 <table border="1">
 <tr>
 <tr>
+<td colspan=1 rowspan=1>货物类</td>
+<td colspan=1 rowspan=1>货物类</td>
+<td colspan=1 rowspan=1>货物类</td>
+<td colspan=1 rowspan=1>货物类</td>
+<td colspan=1 rowspan=1>货物类</td>
+<td colspan=1 rowspan=1>货物类</td>
+</tr>
+<tr>
 <td colspan=1 rowspan=1>供应商名称</td>
 <td colspan=1 rowspan=1>供应商名称</td>
 <td colspan=1 rowspan=1>名称</td>
 <td colspan=1 rowspan=1>名称</td>
 <td colspan=1 rowspan=1>品牌</td>
 <td colspan=1 rowspan=1>品牌</td>