post_process.py 842 B

123456789101112131415161718192021222324252627
  1. def table_post_process(table_text_list, predict_result, threshold=0.5):
  2. predict_result = predict_result.tolist()
  3. predict_list = []
  4. for i in range(0, len(predict_result), 2):
  5. row_predict = predict_result[i][1]
  6. col_predict = predict_result[i+1][1]
  7. if row_predict + col_predict >= 2 * threshold:
  8. predict = "1"
  9. else:
  10. predict = "0"
  11. predict_list.append(predict)
  12. table_label_list = []
  13. i = 0
  14. for row in table_text_list:
  15. row_label = []
  16. for col in row:
  17. row_label.append(predict_list[i])
  18. i += 1
  19. table_label_list.append(row_label)
  20. if len(table_label_list) != len(table_text_list):
  21. print("table_post_process 输出label维度与text不一致!")
  22. table_label_list = []
  23. return table_label_list