1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- import copy
- import os
- import sys
- import torch
- from torch.utils.data import DataLoader
- sys.path.append(os.path.abspath(os.path.dirname(__file__) + "/../../../"))
- from BiddingKG.dl.table_head.models.model_torch import TableHeadModel
- from BiddingKG.dl.table_head.pre_process_torch import CustomDatasetTiny40, set_same_table_head, set_label
# DataLoader batch size used at inference time.
batch_size = 1
# All inference runs on the CPU.
device = torch.device("cpu")
# Trained weights file, expected to sit in the same directory as this module.
model_path = f"{os.path.abspath(os.path.dirname(__file__))}/model_40_0.959.pth"
def _get_model():
    """Return the cached TableHeadModel, loading it on first use.

    The loaded instance is stored in the module-level global ``model`` so
    later calls reuse it instead of reloading the weights from
    ``model_path``.
    """
    model = globals().get("model")
    if model is None:
        print("=" * 15, "init table_head model", "=" * 15)
        # Instantiate the network and load the trained weights onto `device`.
        model = TableHeadModel()
        model.to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        # Put the model into evaluation mode for inference.
        model.eval()
        globals()["model"] = model
    return model


def predict(table_text_list):
    """Predict table-header labels for every cell of a table.

    Args:
        table_text_list: The table as a list of rows, each row a sequence of
            cell values. The column count is taken from the first row.

    Returns:
        A list with one entry per row.
        - ``[]`` for an empty table.
        - For tables with 50 or more columns, the model is not run; every
          row gets all-zero labels of the same length as that row.
        - Otherwise each row's model labels (1 where the score is >= 0.5,
          else 0) are passed, together with the original row, through
          ``set_label``, and its return value is stored for that row.
    """
    # Cheap early exits first, so the model is only loaded when it is needed.
    if len(table_text_list) == 0:
        return []
    # Work on a copy so the caller's table is never touched by the dataset.
    data_x = copy.deepcopy(table_text_list)
    # Placeholder all-zero labels; the targets yielded by the loader are unused below.
    data_y = [[0 for _ in row] for row in data_x]
    row_len = len(data_x)
    col_len = len(data_x[0])
    # Very wide tables are skipped entirely.
    if col_len >= 50:
        return data_y
    # Wider tables are fed in smaller row chunks (presumably to bound input size).
    batch_row_len = 50 if col_len >= 20 else 100

    model = _get_model()
    result_list = []
    with torch.no_grad():
        for start in range(0, row_len, batch_row_len):
            end = start + batch_row_len
            batch_data_x = data_x[start:end]
            # Keep the placeholder targets aligned with this chunk's rows.
            batch_data_y = data_y[start:end]
            dataset = CustomDatasetTiny40([batch_data_x], [batch_data_y], mode=0)
            data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
            for data, _, _ in data_loader:
                data = data.to(device)
                outputs = model(data)
                outputs = set_same_table_head(data, outputs)
                # Binarize the scores at 0.5.
                result = torch.zeros_like(outputs)
                result[outputs >= 0.5] = 1
                # .cpu() is a no-op on CPU but required before .numpy() on other devices.
                result_list += result.cpu().numpy().tolist()

    # Apply rule-based adjustments for specific header cells, row by row.
    return [
        set_label(table_text_list[i], row_label)
        for i, row_label in enumerate(result_list)
    ]
|