Browse Source

调整文本放进表格的逻辑,优先选择iou最高的格子

luojiehua 2 years ago
parent
commit
62b432c189
1 changed files with 62 additions and 59 deletions
  1. 62 59
      format_convert/utils.py

+ 62 - 59
format_convert/utils.py

@@ -32,6 +32,7 @@ import numpy as np
 from format_convert.judge_platform import get_platform
 from format_convert.judge_platform import get_platform
 if get_platform() == "Linux":
 if get_platform() == "Linux":
     import resource
     import resource
+import math
 
 
 
 
 def judge_error_code(_list, code=[0, -1, -2, -3, -4, -5, -6, -7, -8, -9]):
 def judge_error_code(_list, code=[0, -1, -2, -3, -4, -5, -6, -7, -8, -9]):
@@ -1297,69 +1298,71 @@ class LineTable:
         list_textbox.sort(key=lambda x:x.bbox[0])
         list_textbox.sort(key=lambda x:x.bbox[0])
         list_textbox.sort(key=lambda x:x.bbox[3],reverse=sourceP_LB)
         list_textbox.sort(key=lambda x:x.bbox[3],reverse=sourceP_LB)
         # print("list_textbox", list_textbox)
         # print("list_textbox", list_textbox)
-        # for textbox in list_textbox:
-        #     (x0,y0,x1,y1) = textbox.bbox
-        #     _text = textbox.get_text()
-        #     _find = False
-        #     for table_line in _table:
-        #         for _cell in table_line:
-        #             if self.inbox(textbox.bbox, _cell["bbox"], textbox.get_text()):
-        #                 _cell["text"] += _text
-        #                 in_objs.add(textbox)
-        #                 _find = True
-        #                 break
-        #         if _find:
-        #             break
-        has_matched_box_list = []
+
+
+        #find the suitable cell of the textbox
+        list_cells = []
         for table_line in _table:
         for table_line in _table:
             for _cell in table_line:
             for _cell in table_line:
-                inbox_textbox_list = []
-                for textbox in list_textbox:
-                    if textbox in in_objs:
-                        continue
-                    if self.inbox(textbox.bbox, _cell["bbox"], textbox.get_text())==1:
-                        inbox_textbox_list.append(textbox)
-                        in_objs.add(textbox)
-
-                # 分行,根据y重合
-                all_match_box_list = []
-                if sourceP_LB:
-                    inbox_textbox_list.sort(key=lambda x:x.bbox[1],reverse=True)
-                else:
-                    inbox_textbox_list.sort(key=lambda x:x.bbox[1])
-                for i in range(len(inbox_textbox_list)):
-                    match_box_list = []
-                    box1 = inbox_textbox_list[i]
-                    if box1 in has_matched_box_list:
+                list_cells.append({"cell":_cell,"inbox_textbox_list":[]})
+
+        for textbox in list_textbox:
+            list_iou = []
+            for _d in list_cells:
+                _cell = _d["cell"]
+                _iou = self.getIOU(textbox.bbox,_cell["bbox"])
+                list_iou.append(_iou)
+            max_iou_index = np.argmax(list_iou)
+            max_iou = list_iou[max_iou_index]
+            if max_iou>0.1:
+                list_cells[max_iou_index]["inbox_textbox_list"].append(textbox)
+                in_objs.add(textbox)
+
+
+        has_matched_box_list = []
+        for _d in list_cells:
+            _cell = _d["cell"]
+            inbox_textbox_list = _d["inbox_textbox_list"]
+
+            # 分行,根据y重合
+            all_match_box_list = []
+            if sourceP_LB:
+                inbox_textbox_list.sort(key=lambda x:x.bbox[1],reverse=True)
+            else:
+                inbox_textbox_list.sort(key=lambda x:x.bbox[1])
+            for i in range(len(inbox_textbox_list)):
+                match_box_list = []
+                box1 = inbox_textbox_list[i]
+                if box1 in has_matched_box_list:
+                    continue
+
+                min_y1 = box1.bbox[1] + 1/3 * abs(box1.bbox[3]-box1.bbox[1])
+                max_y1 = box1.bbox[3] - 1/3 * abs(box1.bbox[3]-box1.bbox[1])
+                match_box_list.append([box1.get_text(), box1.bbox[0], box1.bbox[1], box1.bbox[2], box1.bbox[3],min_y1,max_y1])
+                has_matched_box_list.append(box1)
+                for j in range(i+1, len(inbox_textbox_list)):
+                    box2 = inbox_textbox_list[j]
+                    if box2 in has_matched_box_list:
                         continue
                         continue
 
 
-                    min_y1 = box1.bbox[1] + 1/3 * abs(box1.bbox[3]-box1.bbox[1])
-                    max_y1 = box1.bbox[3] - 1/3 * abs(box1.bbox[3]-box1.bbox[1])
-                    match_box_list.append([box1.get_text(), box1.bbox[0], box1.bbox[1], box1.bbox[2], box1.bbox[3],min_y1,max_y1])
-                    has_matched_box_list.append(box1)
-                    for j in range(i+1, len(inbox_textbox_list)):
-                        box2 = inbox_textbox_list[j]
-                        if box2 in has_matched_box_list:
-                            continue
-
-                        # print(min_y1, box2.bbox[1], box2.bbox[3], max_y1)
-                        # print(min_y2, box1.bbox[3], max_y2)
-                        if min_y1 <= box2.bbox[1] <= max_y1 or \
-                                min_y1 <= box2.bbox[3] <= max_y1 or \
-                                box2.bbox[1] <= min_y1 <= max_y1 <= box2.bbox[3]:
-                            match_box_list.append([box2.get_text(), box2.bbox[0], box2.bbox[1], box2.bbox[2], box2.bbox[3],min_y1,max_y1])
-                            has_matched_box_list.append(box2)
-                    match_box_list.sort(key=lambda x: x[1])
-                    all_match_box_list.append(match_box_list)
-
-                # print("match_box_list", all_match_box_list)
-                if sourceP_LB:
-                    all_match_box_list.sort(key=lambda x:(x[0][2]+x[0][4])/2,reverse=True)
-                else:
-                    all_match_box_list.sort(key=lambda x:(x[0][2]+x[0][4])/2)
-                for box_list in all_match_box_list:
-                    for box in box_list:
-                        _cell["text"] += box[0]
+                    # print(min_y1, box2.bbox[1], box2.bbox[3], max_y1)
+                    # print(min_y2, box1.bbox[3], max_y2)
+                    if min_y1 <= box2.bbox[1] <= max_y1 or \
+                            min_y1 <= box2.bbox[3] <= max_y1 or \
+                            box2.bbox[1] <= min_y1 <= max_y1 <= box2.bbox[3]:
+                        match_box_list.append([box2.get_text(), box2.bbox[0], box2.bbox[1], box2.bbox[2], box2.bbox[3],min_y1,max_y1])
+                        has_matched_box_list.append(box2)
+                match_box_list.sort(key=lambda x: x[1])
+                all_match_box_list.append(match_box_list)
+
+            # print("match_box_list", all_match_box_list)
+            if sourceP_LB:
+                all_match_box_list.sort(key=lambda x:(x[0][2]+x[0][4])/2,reverse=True)
+            else:
+                all_match_box_list.sort(key=lambda x:(x[0][2]+x[0][4])/2)
+            for box_list in all_match_box_list:
+                for box in box_list:
+                    _cell["text"] += box[0]
 
 
         # print("------------")
         # print("------------")
         # for _line in _table:
         # for _line in _table: