|
@@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
|
|
|
|
|
|
sys.path.append(os.path.abspath(os.path.dirname(__file__) + "/../../../"))
|
|
|
from BiddingKG.dl.table_head.models.model_torch import TableHeadModel
|
|
|
-from BiddingKG.dl.table_head.pre_process_torch import CustomDatasetTiny40, set_same_table_head
|
|
|
+from BiddingKG.dl.table_head.pre_process_torch import CustomDatasetTiny40, set_same_table_head, set_label
|
|
|
|
|
|
device = torch.device("cpu")
|
|
|
model_path = os.path.abspath(os.path.dirname(__file__)) + '/model_40_0.959.pth'
|
|
@@ -59,4 +59,10 @@ def predict(table_text_list):
|
|
|
result = result.numpy().tolist()
|
|
|
result_list += result
|
|
|
|
|
|
+ # 设置一些特定的表头
|
|
|
+ for i in range(len(result_list)):
|
|
|
+ row = table_text_list[i]
|
|
|
+ row_label = result_list[i]
|
|
|
+ result_list[i] = set_label(row, row_label)
|
|
|
+
|
|
|
return result_list
|