predict_torch.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import copy
  2. import os
  3. import sys
  4. import torch
  5. from torch.utils.data import DataLoader
  6. sys.path.append(os.path.abspath(os.path.dirname(__file__) + "/../../../"))
  7. from BiddingKG.dl.table_head.models.model_torch import TableHeadModel
  8. from BiddingKG.dl.table_head.pre_process_torch import CustomDatasetTiny40, set_same_table_head, set_label
  9. device = torch.device("cpu")
  10. model_path = os.path.abspath(os.path.dirname(__file__)) + '/model_40_0.951.pth'
  11. batch_size = 1
  12. def predict(table_text_list):
  13. if globals().get("model") is None:
  14. print("="*15, "init table_head model", "="*15)
  15. # 实例化模型
  16. model = TableHeadModel()
  17. model.to(device)
  18. model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
  19. # 将模型设置为评估模式
  20. model.eval()
  21. globals()["model"] = model
  22. else:
  23. model = globals().get("model")
  24. if len(table_text_list) <= 0:
  25. return []
  26. data_x = copy.deepcopy(table_text_list)
  27. data_y = [[0 for col in row] for row in data_x]
  28. row_len = len(data_x)
  29. col_len = len(data_x[0])
  30. if col_len >= 50:
  31. return data_y
  32. if col_len >= 20:
  33. batch_row_len = 50
  34. else:
  35. batch_row_len = 100
  36. result_list = []
  37. for i in range(0, row_len, batch_row_len):
  38. batch_data_x = data_x[i:i+batch_row_len]
  39. dataset = CustomDatasetTiny40([batch_data_x], [data_y], mode=0)
  40. data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
  41. # 存储预测结果
  42. with torch.no_grad():
  43. for data, targets, _ in data_loader:
  44. data = data.to(device)
  45. outputs = model(data)
  46. outputs = set_same_table_head(data, outputs)
  47. result = torch.zeros_like(outputs)
  48. result[outputs >= 0.5] = 1
  49. result = result.numpy().tolist()
  50. result_list += result
  51. # 设置一些特定的表头
  52. for i in range(len(result_list)):
  53. row = table_text_list[i]
  54. row_label = result_list[i]
  55. result_list[i] = set_label(row, row_label)
  56. return result_list