post_process.py 768 B

1234567891011121314151617181920212223242526
  1. def table_post_process(table_text_list, predict_result, threshold=0.5):
  2. predict_result = predict_result.tolist()
  3. predict_list = []
  4. for i in range(0, len(predict_result)):
  5. predict = predict_result[i][0]
  6. if predict >= threshold:
  7. predict = "1"
  8. else:
  9. predict = "0"
  10. predict_list.append(predict)
  11. table_label_list = []
  12. i = 0
  13. for row in table_text_list:
  14. row_label = []
  15. for col in row:
  16. row_label.append(predict_list[i])
  17. i += 1
  18. table_label_list.append(row_label)
  19. if len(table_label_list) != len(table_text_list):
  20. print("table_post_process 输出label维度与text不一致!")
  21. table_label_list = []
  22. return table_label_list