# predict_torch.py
import copy
import os
import sys

import torch
from torch.utils.data import DataLoader

# Append the directory three levels above this file to the import path,
# presumably so the top-level BiddingKG package resolves — TODO confirm layout.
sys.path.append(os.path.abspath(os.path.dirname(__file__) + "/../../../"))
from BiddingKG.dl.table_head.models.model_torch import TableHeadModel
# from BiddingKG.dl.table_head.models.model_torch import TableHeadModel2
from BiddingKG.dl.table_head.pre_process_torch import CustomDatasetTiny40, set_same_table_head, set_label
  10. device = torch.device("cpu")
  11. model_path = os.path.abspath(os.path.dirname(__file__)) + '/model_40_0.959.pth'
  12. # model_path = os.path.abspath(os.path.dirname(__file__)) + '/model_40_2_0.959.pth'
  13. batch_size = 1
  14. def predict(table_text_list):
  15. if globals().get("model") is None:
  16. print("="*15, "init table_head model", "="*15)
  17. # 实例化模型
  18. model = TableHeadModel()
  19. # model = TableHeadModel2()
  20. model.to(device)
  21. model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
  22. # 将模型设置为评估模式
  23. model.eval()
  24. globals()["model"] = model
  25. else:
  26. model = globals().get("model")
  27. if len(table_text_list) <= 0:
  28. return []
  29. data_x = copy.deepcopy(table_text_list)
  30. data_y = [[0 for col in row] for row in data_x]
  31. row_len = len(data_x)
  32. col_len = len(data_x[0])
  33. if col_len >= 50:
  34. return data_y
  35. if col_len >= 20:
  36. batch_row_len = 50
  37. else:
  38. batch_row_len = 100
  39. result_list = []
  40. for i in range(0, row_len, batch_row_len):
  41. batch_data_x = data_x[i:i+batch_row_len]
  42. dataset = CustomDatasetTiny40([batch_data_x], [data_y], mode=0)
  43. data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
  44. # 存储预测结果
  45. with torch.no_grad():
  46. for data, targets, _ in data_loader:
  47. data = data.to(device)
  48. outputs = model(data)
  49. outputs = set_same_table_head(data, outputs)
  50. result = torch.zeros_like(outputs)
  51. result[outputs >= 0.5] = 1
  52. result = result.numpy().tolist()
  53. result_list += result
  54. # 设置一些特定的表头
  55. for i in range(len(result_list)):
  56. row = table_text_list[i]
  57. row_label = result_list[i]
  58. result_list[i] = set_label(row, row_label)
  59. return result_list