1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
- import numpy as np
- import cv2
- from table_postprocess import get_line_from_binary_image
def predict(img, model, size, row_prob=0.5, col_prob=0.5, debug=False):
    """
    Run the OTR table-line model on an image and extract table lines.

    The image is resized to ``size`` and passed to ``model.predict`` as a
    batch of one. Channels 0 and 1 of the single prediction are thresholded
    into 0/1 uint8 masks. Channel 0 is the horizontal-line (row) mask and
    channel 1 is the vertical-line (column) mask. Lines are then extracted
    from each mask with ``get_line_from_binary_image``.

    :param img: input image as a numpy array (anything ``cv2.resize`` accepts)
    :param model: OTR table-line extraction model; must expose ``predict``
        taking a batched numpy array. Presumably a Keras-style model whose
        output is (batch, H, W, >=2) -- TODO confirm against the model.
    :param size: target ``(width, height)`` for the resize (cv2 dsize order)
    :param row_prob: threshold for the horizontal-line channel; scores
        strictly greater than this become line pixels
    :param col_prob: threshold for the vertical-line channel, same semantics
    :param debug: if True, print the mask shape and show the row, column and
        combined masks in OpenCV windows, blocking on ``cv2.waitKey(0)``.
        Defaults to False so inference runs non-interactively and headless.
    :return: tuple ``(rows, cols, img_new)``:
        - ``rows``: lines extracted from the row mask
        - ``cols``: lines extracted from the column mask
        - ``img_new``: the resized image the model was run on
    """
    width, height = size
    # cv2.resize takes dsize as (width, height), not numpy's (rows, cols).
    img_new = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    # Wrap the image in a batch of one, then take the single result back out.
    pred = model.predict(np.array([img_new]))[0]

    # Threshold each channel into a 0/1 uint8 mask.
    row_pred = (pred[..., 0] > row_prob).astype(np.uint8)  # horizontal lines
    col_pred = (pred[..., 1] > col_prob).astype(np.uint8)  # vertical lines

    if debug:
        _show_debug_masks(row_pred, col_pred)

    # Extract line geometry from each binary mask (1 marks a line pixel).
    rows = get_line_from_binary_image(row_pred, point_value=1, axis=0)
    cols = get_line_from_binary_image(col_pred, point_value=1, axis=1)
    return rows, cols, img_new


def _show_debug_masks(row_pred, col_pred):
    """Print the mask shape and display row/column masks in blocking windows."""
    print("row_pred.shape", row_pred.shape)
    cv2.imshow("row predict", row_pred * 255)
    cv2.imshow("col predict", col_pred * 255)
    cv2.waitKey(0)
    # Use a bitwise OR: summing two uint8 0/1 masks gives 2 where they
    # intersect, and 2 * 255 wraps around to 254 in uint8.
    cv2.imshow("predict", np.bitwise_or(row_pred, col_pred) * 255)
    cv2.waitKey(0)
|