Просмотр исходного кода

解决点到表格的识别补充span和rect的问题以及文字排序错乱问题

luojiehua 2 лет назад
Родитель
Сommit
a428518d6d
3 измененных файлов с 170 добавлено и 97 удалено
  1. 3 3
      format_convert/convert_image.py
  2. 166 93
      format_convert/utils.py
  3. 1 1
      otr/table_line.py

+ 3 - 3
format_convert/convert_image.py

@@ -204,9 +204,9 @@ def image_process(image_np, image_path, is_from_pdf=False, is_from_docx=False, u
                             [int(point[2][0]*ratio[1]), int(point[2][1]*ratio[0])],
                             [int(point[3][0]*ratio[1]), int(point[3][1]*ratio[0])]]
 
-        # 对文字框的y进行聚类
-        text_list, bbox_list = get_cluster(text_list, bbox_list, 0)
-        text_list, bbox_list = get_cluster(text_list, bbox_list, 2)
+        # for _a,_b in zip(text_list,bbox_list):
+        #     print("bbox1",_a,_b)
+
 
         # 调用现成方法形成表格
         try:

+ 166 - 93
format_convert/utils.py

@@ -795,7 +795,7 @@ class LineTable:
 
         return list_l_rect
 
-    def crosspoint2rect(self, list_crosspoint, margin=5):
+    def crosspoint2rect(self, list_crosspoint, margin=10):
 
         dict_line_points = {}
         for _point in list_crosspoint:
@@ -993,12 +993,19 @@ class LineTable:
         list_x.sort(key=lambda x:x)
         list_y.sort(key=lambda x:x,reverse=True)
         _table = []
+        line_i = 0
         for _line in clusters_rects:
+
             table_line = []
+            cell_i = 0
             for _rect in _line:
                 (x0,y0,x1,y1) = _rect.get("bbox")
                 _cell = {"bbox":(x0,y0,x1,y1),"rect":_rect.get("rect"),"rowspan":self.getspan(list_y,y0,y1,margin),"columnspan":self.getspan(list_x,x0,x1,margin),"text":_rect.get("text","")}
                 table_line.append(_cell)
+
+
+                cell_i += 1
+            line_i += 1
             _table.append(table_line)
 
         # print("=====================>>")
@@ -1037,7 +1044,135 @@ class LineTable:
         ta = {"bbox":table_bbox,"table":_table}
         return ta
 
-    def rect2table(self, list_textbox, list_rect, in_objs, margin=5, fixspan=True,sourceP_LB=True,fixRect=True):
+    #获取点阵
+    def getSpanLocation(self,_list, x0, x1, margin):
+        list_location = []
+        (x0,x1) = (min(x0,x1),max(x0,x1))
+        for _x in _list:
+            if _x>=(x0-margin) and _x<=(x1+margin):
+                list_location.append(_x)
+        return list_location
+
+    def fixSpan(self,_table,list_x,list_y):
+        for _line in _table:
+            c_i = 0
+            while c_i<len(_line):
+                _cell = _line[c_i]
+
+                if _cell.get("columnspan")>1:
+                    x0,y0,x1,y1 = _cell.get("bbox")
+                    _cospan = _cell.get("columnspan")
+
+                    locations = self.getSpanLocation(list_x,x0,x1,10)
+                    if len(locations)==_cospan+1:
+                        _cell["bbox"] = (x0,y0,locations[1],y1)
+
+                        _cell["columnspan"] = 1
+
+                        #len(locations)==_colspan+1
+                        for i in range(1,_cospan):
+                            n_cell = {}
+                            n_cell.update(_cell)
+                            n_cell["bbox"] = (locations[i],y0,locations[i+1],y1)
+                            c_i += 1
+                            _line.insert(c_i,n_cell)
+
+                c_i += 1
+        for l_i in range(len(_table)):
+            _line = _table[l_i]
+            c_i = 0
+            while c_i<len(_line):
+                _cell = _line[c_i]
+                if _cell.get("rowspan")>1:
+                    x0,y0,x1,y1 = _cell.get("bbox")
+                    _rospan = _cell.get("rowspan")
+                    locations = self.getSpanLocation(list_y,y0,y1,10)
+
+                    if len(locations)==_rospan+1:
+                        _cell["bbox"] = (x0,y0,x1,locations[1])
+                        _cell["rowspan"] = 1
+
+                        for i in range(1,_rospan):
+                            n_cell = {}
+                            n_cell.update(_cell)
+                            if l_i+i<=len(_table)-1:
+                                # print(len(_table),l_i+i)
+                                n_cell["bbox"] = (x0,locations[i],x1,locations[i+1])
+                                _table[l_i+i].insert(c_i,n_cell)
+
+                c_i += 1
+
+
+    def fixRect(self,_table,list_x,list_y,list_textbox,in_objs,sourceP_LB,margin):
+        self.fixSpan(_table,list_x,list_y)
+        # for line_i in range(len(_table)):
+        #     for cell_i in range(len(_table[line_i])):
+        #         _cell = _table[line_i][cell_i]
+        #         print(line_i,cell_i,_cell["bbox"],_cell["text"])
+        for _line in _table:
+            extend_line = []
+            for c_i in range(len(_line)):
+                c_cell = _line[c_i]
+
+                #first cell missing
+                if c_i==0 and c_cell["bbox"][0]!=list_x[0]:
+                    _bbox = (list_x[0],c_cell["bbox"][1], c_cell["bbox"][0],c_cell["bbox"][3])
+                    _cell = {"bbox": _bbox,
+                             "rect": LTRect(1,_bbox),
+                             "rowspan": self.getspan(list_y,_bbox[1], _bbox[3], margin),
+                             "columnspan": self.getspan(list_x, _bbox[0], _bbox[2], margin),
+                             "text": ""}
+                    extend_line.append({"index":c_i,"cell":_cell})
+                #cell in the median missing
+                if c_i<len(_line)-1:
+                    n_cell = _line[c_i+1]
+                    _bbox = c_cell["bbox"]
+                    n_bbox = n_cell["bbox"]
+                    if _bbox[0]==n_bbox[0] and _bbox[2]==n_bbox[2]:
+                        continue
+                    else:
+                        if abs(_bbox[2]-n_bbox[0])>margin:
+                            _bbox = (_bbox[2],_bbox[1], n_bbox[0],_bbox[3])
+                            _cell = {"bbox": _bbox,
+                                     "rect": LTRect(1,_bbox),
+                                     "rowspan": self.getspan(list_y,_bbox[1], _bbox[3], margin),
+                                     "columnspan": self.getspan(list_x, _bbox[0], _bbox[2], margin),
+                                     "text": ""}
+                            extend_line.append({"index":c_i+1,"cell":_cell})
+                #last cell missing
+                if c_i==len(_line)-1:
+                    if abs(c_cell["bbox"][2]-list_x[-1])>margin:
+                        _bbox = (c_cell["bbox"][2],c_cell["bbox"][1], list_x[-1],c_cell["bbox"][3])
+                        _cell = {"bbox": _bbox,
+                                 "rect": LTRect(1,_bbox),
+                                 "rowspan": self.getspan(list_y,_bbox[1], _bbox[3], margin),
+                                 "columnspan": self.getspan(list_x, _bbox[0], _bbox[2], margin),
+                                 "text": ""}
+                        extend_line.append({"index":c_i+1,"cell":_cell})
+            extend_line.sort(key=lambda x: x["index"],reverse=True)
+
+            for _tmp in extend_line:
+                _line.insert(_tmp["index"],_tmp["cell"])
+
+            list_textbox.sort(key=lambda x: x.bbox[0])
+            list_textbox.sort(key=lambda x: x.bbox[3], reverse=sourceP_LB)
+            for textbox in list_textbox:
+                if textbox in in_objs:
+                    continue
+                x0, y0, x1, y1 = textbox.bbox
+                _text = textbox.get_text()
+                _find = False
+                for table_line in _table:
+                    for _cell in table_line:
+                        if self.inbox(textbox.bbox, _cell["bbox"], textbox.get_text()):
+                            _cell["text"] += _text
+                            in_objs.add(textbox)
+                            _find = True
+                            break
+                    if _find:
+                        break
+
+    def rect2table(self, list_textbox, list_rect, in_objs, margin=5, fixspan=False,sourceP_LB=True,fixRect=True):
 
         def getIOU(bbox0,bbox1):
             width = max(bbox0[2],bbox1[2])-min(bbox0[0],bbox1[0])-(bbox0[2]-bbox0[0]+bbox1[2]-bbox1[0])
@@ -1045,6 +1180,8 @@ class LineTable:
             if width<0 and height<0:
                 return abs(width*height/min(abs((bbox0[2]-bbox0[0])*(bbox0[3]-bbox0[1])),abs((bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]))))
             return 0
+
+
         _table = []
         set_x = set()
         set_y = set()
@@ -1133,8 +1270,10 @@ class LineTable:
 
         # print(list_x)
         # print(list_y)
+        line_i = 0
         for _line in clusters_rects:
             table_line = []
+            cell_i = 0
             for _rect in _line:
                 (x0, y0, x1, y1) = _rect.bbox
                 _cell = {"bbox": (x0, y0, x1, y1),
@@ -1142,9 +1281,19 @@ class LineTable:
                          "rowspan": self.getspan(list_y, y0, y1, margin),
                          "columnspan": self.getspan(list_x, x0, x1, margin),
                          "text": ""}
+
+                cell_i += 1
                 table_line.append(_cell)
+            line_i += 1
             _table.append(table_line)
 
+        # print("table===========================>")
+        # for _line in _table:
+        #     for _cell in _line:
+        #         print("||%d%d"%(_cell["rowspan"],_cell["columnspan"]),end="\t")
+        #     print()
+        # print("table===========================>")
+
         list_textbox.sort(key=lambda x:x.bbox[0])
         list_textbox.sort(key=lambda x:x.bbox[3],reverse=sourceP_LB)
         # print("list_textbox", list_textbox)
@@ -1168,12 +1317,16 @@ class LineTable:
                 for textbox in list_textbox:
                     if textbox in in_objs:
                         continue
-                    if self.inbox(textbox.bbox, _cell["bbox"], textbox.get_text()):
+                    if self.inbox(textbox.bbox, _cell["bbox"], textbox.get_text())==1:
                         inbox_textbox_list.append(textbox)
                         in_objs.add(textbox)
 
                 # 分行,根据y重合
                 all_match_box_list = []
+                if sourceP_LB:
+                    inbox_textbox_list.sort(key=lambda x:x.bbox[1],reverse=True)
+                else:
+                    inbox_textbox_list.sort(key=lambda x:x.bbox[1])
                 for i in range(len(inbox_textbox_list)):
                     match_box_list = []
                     box1 = inbox_textbox_list[i]
@@ -1182,7 +1335,7 @@ class LineTable:
 
                     min_y1 = box1.bbox[1] + 1/3 * abs(box1.bbox[3]-box1.bbox[1])
                     max_y1 = box1.bbox[3] - 1/3 * abs(box1.bbox[3]-box1.bbox[1])
-                    match_box_list.append([box1.get_text(), box1.bbox[0], box1.bbox[1], box1.bbox[2], box1.bbox[3]])
+                    match_box_list.append([box1.get_text(), box1.bbox[0], box1.bbox[1], box1.bbox[2], box1.bbox[3],min_y1,max_y1])
                     has_matched_box_list.append(box1)
                     for j in range(i+1, len(inbox_textbox_list)):
                         box2 = inbox_textbox_list[j]
@@ -1194,12 +1347,16 @@ class LineTable:
                         if min_y1 <= box2.bbox[1] <= max_y1 or \
                                 min_y1 <= box2.bbox[3] <= max_y1 or \
                                 box2.bbox[1] <= min_y1 <= max_y1 <= box2.bbox[3]:
-                            match_box_list.append([box2.get_text(), box2.bbox[0], box2.bbox[1], box2.bbox[2], box2.bbox[3]])
+                            match_box_list.append([box2.get_text(), box2.bbox[0], box2.bbox[1], box2.bbox[2], box2.bbox[3],min_y1,max_y1])
                             has_matched_box_list.append(box2)
                     match_box_list.sort(key=lambda x: x[1])
                     all_match_box_list.append(match_box_list)
 
                 # print("match_box_list", all_match_box_list)
+                if sourceP_LB:
+                    all_match_box_list.sort(key=lambda x:(x[0][2]+x[0][4])/2,reverse=True)
+                else:
+                    all_match_box_list.sort(key=lambda x:(x[0][2]+x[0][4])/2)
                 for box_list in all_match_box_list:
                     for box in box_list:
                         _cell["text"] += box[0]
@@ -1211,92 +1368,8 @@ class LineTable:
         #         print("\n")
         # print("------------")
 
-        if fixspan:
-            for _line in _table:
-                for c_i in range(len(_line)):
-                    _cell = _line[c_i]
+        self.fixRect(_table,list_x,list_y,list_textbox,in_objs,sourceP_LB,margin)
 
-                    if _cell.get("columnspan")>1:
-                        _cospan = _cell.get("columnspan")
-                        _cell["columnspan"] = 1
-                        n_cell = {}
-                        n_cell.update(_cell)
-                        for i in range(1,_cospan):
-                            _line.insert(c_i,n_cell)
-            for l_i in range(len(_table)):
-                _line = _table[l_i]
-                for c_i in range(len(_line)):
-                    _cell = _line[c_i]
-                    if _cell.get("rowspan")>1:
-                        _rospan = _cell.get("rowspan")
-                        _cell["rowspan"] = 1
-                        n_cell = {}
-                        n_cell.update(_cell)
-                        for i in range(1,_rospan):
-                            if l_i+i<=len(_table)-1:
-                                # print(len(_table),l_i+i)
-                                _table[l_i+i].insert(c_i,n_cell)
-
-        if fixRect:
-            for _line in _table:
-                extend_line = []
-                for c_i in range(len(_line)):
-                    c_cell = _line[c_i]
-
-                    if c_i==0 and c_cell["bbox"][0]!=list_x[0]:
-                        _bbox = (list_x[0],c_cell["bbox"][1], c_cell["bbox"][0],c_cell["bbox"][3])
-                        _cell = {"bbox": _bbox,
-                                 "rect": LTRect(1,_bbox),
-                                 "rowspan": self.getspan(list_y,_bbox[1], _bbox[3], margin),
-                                 "columnspan": self.getspan(list_x, _bbox[0], _bbox[2], margin),
-                                 "text": ""}
-                        extend_line.append({"index":c_i,"cell":_cell})
-                    if c_i<len(_line)-1:
-                        n_cell = _line[c_i+1]
-                        _bbox = c_cell["bbox"]
-                        n_bbox = n_cell["bbox"]
-                        if _bbox[0]==n_bbox[0] and _bbox[2]==n_bbox[2]:
-                            continue
-                        else:
-                            if abs(_bbox[2]-n_bbox[0])>margin:
-                                _bbox = (_bbox[2],_bbox[1], n_bbox[0],_bbox[3])
-                                _cell = {"bbox": _bbox,
-                                         "rect": LTRect(1,_bbox),
-                                         "rowspan": self.getspan(list_y,_bbox[1], _bbox[3], margin),
-                                         "columnspan": self.getspan(list_x, _bbox[0], _bbox[2], margin),
-                                         "text": ""}
-                                extend_line.append({"index":c_i+1,"cell":_cell})
-                    if c_i==len(_line)-1:
-                        if abs(c_cell["bbox"][2]-list_x[-1])>margin:
-                            _bbox = (c_cell["bbox"][2],c_cell["bbox"][1], list_x[-1],c_cell["bbox"][3])
-                            _cell = {"bbox": _bbox,
-                                     "rect": LTRect(1,_bbox),
-                                     "rowspan": self.getspan(list_y,_bbox[1], _bbox[3], margin),
-                                     "columnspan": self.getspan(list_x, _bbox[0], _bbox[2], margin),
-                                     "text": ""}
-                            extend_line.append({"index":c_i+1,"cell":_cell})
-                extend_line.sort(key=lambda x: x["index"],reverse=True)
-
-                for _tmp in extend_line:
-                    _line.insert(_tmp["index"],_tmp["cell"])
-
-                list_textbox.sort(key=lambda x: x.bbox[0])
-                list_textbox.sort(key=lambda x: x.bbox[3], reverse=sourceP_LB)
-                for textbox in list_textbox:
-                    if textbox in in_objs:
-                        continue
-                    x0, y0, x1, y1 = textbox.bbox
-                    _text = textbox.get_text()
-                    _find = False
-                    for table_line in _table:
-                        for _cell in table_line:
-                            if self.inbox(textbox.bbox, _cell["bbox"], textbox.get_text()):
-                                _cell["text"] += _text
-                                in_objs.add(textbox)
-                                _find = True
-                                break
-                        if _find:
-                            break
 
         table_bbox = (_table[0][0].get("bbox")[0],
                       _table[0][0].get("bbox")[1],
@@ -1317,13 +1390,13 @@ class LineTable:
         # if bbox_g[0]<=bbox0[0] and bbox_g[1]<=bbox0[1] and bbox_g[2]>=bbox0[2] and bbox_g[3]>=bbox0[3]:
         #     return 1
         # print("utils inbox", text, self.getIOU(bbox0,bbox_g), bbox0, bbox_g)
-        if self.getIOU(bbox0,bbox_g)>0.5:
+        if self.getIOU(bbox0,bbox_g)>0.2:
             return 1
         return 0
 
     def getIOU(self, bbox0, bbox1):
-        width = max(bbox0[2],bbox1[2])-min(bbox0[0],bbox1[0])-(bbox0[2]-bbox0[0]+bbox1[2]-bbox1[0])
-        height = max(bbox0[3],bbox1[3])-min(bbox0[1],bbox1[1])-(bbox0[3]-bbox0[1]+bbox1[3]-bbox1[1])
+        width = abs(max(bbox0[2],bbox1[2])-min(bbox0[0],bbox1[0]))-(abs(bbox0[2]-bbox0[0])+abs(bbox1[2]-bbox1[0]))
+        height = abs(max(bbox0[3],bbox1[3])-min(bbox0[1],bbox1[1]))-(abs(bbox0[3]-bbox0[1])+abs(bbox1[3]-bbox1[1]))
         if width < 0 and height < 0:
             iou = abs(width*height/min(abs((bbox0[2]-bbox0[0])*(bbox0[3]-bbox0[1])),
                                        abs((bbox1[2]-bbox1[0])*(bbox1[3]-bbox1[1]))))

+ 1 - 1
otr/table_line.py

@@ -697,7 +697,7 @@ def table_postprocess(img_new, pred, prob=0.2, is_test=0):
         return [-1]
 
 
-def table_line(img, model, size=(512, 1024), prob=0.2, is_test=0):
+def table_line(img, model, size=(512, 1024), prob=0.2, is_test=1):
     log("into table_line, prob is " + str(prob))
     sizew, sizeh = size
     img_new = cv2.resize(img, (sizew, sizeh), interpolation=cv2.INTER_AREA)