# post_process.py (1.3 KB)
  1. def table_post_process(table_text_list, predict_result, threshold=0.5):
  2. predict_result = predict_result.tolist()
  3. predict_list = []
  4. for i in range(0, len(predict_result)):
  5. predict = predict_result[i][0]
  6. if predict >= threshold:
  7. predict = "1"
  8. else:
  9. predict = "0"
  10. predict_list.append(predict)
  11. table_label_list = []
  12. i = 0
  13. for row in table_text_list:
  14. row_label = []
  15. for col in row:
  16. row_label.append(predict_list[i])
  17. i += 1
  18. table_label_list.append(row_label)
  19. if len(table_label_list) != len(table_text_list):
  20. print("table_post_process 输出label维度与text不一致!")
  21. table_label_list = []
  22. return table_label_list
  23. def table_post_process_2(table_text_list, predict_result, threshold=0.5):
  24. predict_result = predict_result.tolist()[0]
  25. predict_list = []
  26. for row in predict_result:
  27. new_row = []
  28. for col in row:
  29. if col >= threshold:
  30. new_row.append("1")
  31. else:
  32. new_row.append("0")
  33. predict_list.append(new_row)
  34. if len(predict_list) != len(predict_result):
  35. print("table_post_process 输出label维度与text不一致!")
  36. predict_list = []
  37. return predict_list