post_process.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. def table_post_process(table_text_list, predict_result, threshold=0.5, is_list=False):
  2. if not is_list:
  3. predict_result = predict_result.tolist()
  4. predict_list = []
  5. for i in range(0, len(predict_result)):
  6. predict = predict_result[i][0]
  7. if predict >= threshold:
  8. predict = "1"
  9. else:
  10. predict = "0"
  11. predict_list.append(predict)
  12. table_label_list = []
  13. i = 0
  14. for row in table_text_list:
  15. row_label = []
  16. for col in row:
  17. row_label.append(predict_list[i])
  18. i += 1
  19. table_label_list.append(row_label)
  20. if len(table_label_list) != len(table_text_list):
  21. print("table_post_process 输出label维度与text不一致!")
  22. table_label_list = []
  23. return table_label_list
  24. def table_post_process_2(table_text_list, predict_result, threshold=0.5):
  25. predict_result = predict_result.tolist()[0]
  26. predict_list = []
  27. for row in predict_result:
  28. new_row = []
  29. for col in row:
  30. if col >= threshold:
  31. new_row.append("1")
  32. else:
  33. new_row.append("0")
  34. predict_list.append(new_row)
  35. if len(predict_list) != len(predict_result):
  36. print("table_post_process 输出label维度与text不一致!")
  37. predict_list = []
  38. return predict_list