import copy
import os
import sys

import torch
from torch.utils.data import DataLoader

# Make the repository root importable before pulling in the BiddingKG package.
sys.path.append(os.path.abspath(os.path.dirname(__file__) + "/../../../"))
from BiddingKG.dl.table_head.models.model_torch import TableHeadModel
# from BiddingKG.dl.table_head.models.model_torch import TableHeadModel2
from BiddingKG.dl.table_head.pre_process_torch import CustomDatasetTiny40, set_same_table_head, set_label

device = torch.device("cpu")
model_path = os.path.abspath(os.path.dirname(__file__)) + '/model_40_0.959.pth'
# model_path = os.path.abspath(os.path.dirname(__file__)) + '/model_40_2_0.959.pth'
batch_size = 1

# Cached model instance; populated by the first call to predict().
model = None


def _get_model():
    """Return the cached table-head model, loading it from ``model_path`` on first use."""
    global model
    if model is None:
        print("="*15, "init table_head model", "="*15)
        # Instantiate the model.
        loaded = TableHeadModel()
        # loaded = TableHeadModel2()
        loaded.to(device)
        # NOTE(review): torch.load unpickles the checkpoint file; only load
        # checkpoints from a trusted source.
        loaded.load_state_dict(torch.load(model_path, map_location=device))
        # Switch the model to evaluation mode.
        loaded.eval()
        # Publish the instance only once it is fully initialised.
        model = loaded
    return model


def predict(table_text_list):
    """Predict a 0/1 label for every cell of a table.

    The model is loaded lazily on the first call and cached in the
    module-level ``model`` variable. Rows are fed to the model in batches of
    100 rows (50 rows when the first row has 20 or more columns).

    Args:
        table_text_list: The table as a list of rows, each row a list of
            cell texts. The input is deep-copied before being handed to the
            dataset, so the caller's list is not passed to it directly.

    Returns:
        ``[]`` for an empty table. If the first row has 50 or more columns,
        an all-zero label matrix with the same shape as the input is
        returned without running the model. Otherwise a list with one entry
        per predicted row, each being the result of ``set_label`` applied to
        the row's texts and its thresholded (>= 0.5) model output.
        NOTE(review): presumably 1 marks a header cell -- confirm against
        ``set_label`` / the training labels.
    """
    # Load the model first (kept ahead of the empty check so an empty call
    # still initialises the model, as before).
    net = _get_model()

    if len(table_text_list) <= 0:
        return []

    data_x = copy.deepcopy(table_text_list)
    # Placeholder labels: all zeros, same shape as the table.
    data_y = [[0 for _ in row] for row in data_x]
    row_len = len(data_x)
    # Table width is taken from the first row only.
    col_len = len(data_x[0])

    # Very wide tables are not run through the model.
    if col_len >= 50:
        return data_y

    # Wider tables get fewer rows per batch.
    batch_row_len = 50 if col_len >= 20 else 100

    result_list = []
    for start in range(0, row_len, batch_row_len):
        end = start + batch_row_len
        batch_data_x = data_x[start:end]
        # Slice the placeholder labels to the same rows as the data so both
        # inputs to the dataset describe the same batch.
        batch_data_y = data_y[start:end]
        dataset = CustomDatasetTiny40([batch_data_x], [batch_data_y], mode=0)
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        # Collect the prediction results.
        with torch.no_grad():
            for data, _targets, _ in data_loader:
                data = data.to(device)
                outputs = net(data)
                outputs = set_same_table_head(data, outputs)
                # Threshold the outputs at 0.5 into a 0/1 tensor.
                result = torch.zeros_like(outputs)
                result[outputs >= 0.5] = 1
                result_list += result.numpy().tolist()

    # Apply the rule-based overrides for certain specific headers.
    for i, row_label in enumerate(result_list):
        result_list[i] = set_label(table_text_list[i], row_label)
    return result_list