table_predict.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import numpy as np
  2. import cv2
  3. from table_postprocess import get_line_from_binary_image
  4. def predict(img, model, size, row_prob=0.5, col_prob=0.5):
  5. """
  6. 图片经过模型提取表格线,h_line横线,v_line竖线
  7. :param img: numpy格式image
  8. :param model: OTR表格线提取模型
  9. :param size: resize后的image size
  10. :param h_line_prob: h_line置信度
  11. :param v_line_prob: v_line置信度
  12. :return: 表格线list
  13. """
  14. # resize
  15. width, height = size
  16. img_new = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
  17. # 模型预测
  18. pred = model.predict(np.array([img_new]))
  19. pred = pred[0]
  20. # 横线预测结果
  21. row_pred = pred[..., 0] > row_prob
  22. row_pred = row_pred.astype(np.uint8)
  23. # 竖线预测结果
  24. col_pred = pred[..., 1] > col_prob
  25. col_pred = col_pred.astype(np.uint8)
  26. print("row_pred.shape", row_pred.shape)
  27. cv2.imshow("row predict", row_pred*255)
  28. cv2.imshow("col predict", col_pred*255)
  29. cv2.waitKey(0)
  30. cv2.imshow("predict", (col_pred+row_pred)*255)
  31. cv2.waitKey(0)
  32. # 根据预测结果获取线条
  33. rows = get_line_from_binary_image(row_pred, point_value=1, axis=0)
  34. cols = get_line_from_binary_image(col_pred, point_value=1, axis=1)
  35. return rows, cols, img_new