Prechádzať zdrojové kódy

表头识别模型torch版

fangjiasheng 9 mesiacov pred
rodič
commit
c55b589784

+ 1 - 1
BiddingKG/dl/table_head/pre_process_torch.py

@@ -15,7 +15,7 @@ def set_label(row, row_label):
     elif len(set(row)) == 1:
         row_label = [0 for x in row]
     else:
-        row_label = [0 if x in ["", " ", "/", '无', '-'] else row_label[i] for i, x in enumerate(row)]
+        row_label = [0 if x in ["", " ", "/", '无', '-', '~~'] else row_label[i] for i, x in enumerate(row)]
     return row_label
 
 

+ 7 - 1
BiddingKG/dl/table_head/predict_torch.py

@@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
 
 sys.path.append(os.path.abspath(os.path.dirname(__file__) + "/../../../"))
 from BiddingKG.dl.table_head.models.model_torch import TableHeadModel
-from BiddingKG.dl.table_head.pre_process_torch import CustomDatasetTiny40, set_same_table_head
+from BiddingKG.dl.table_head.pre_process_torch import CustomDatasetTiny40, set_same_table_head, set_label
 
 device = torch.device("cpu")
 model_path = os.path.abspath(os.path.dirname(__file__)) + '/model_40_0.959.pth'
@@ -59,4 +59,10 @@ def predict(table_text_list):
                 result = result.numpy().tolist()
         result_list += result
 
+    # 设置一些特定的表头
+    for i in range(len(result_list)):
+        row = table_text_list[i]
+        row_label = result_list[i]
+        result_list[i] = set_label(row, row_label)
+
     return result_list