import numpy as np
import cv2

from table_postprocess import get_line_from_binary_image


def _show_masks(row_pred, col_pred):
    """Show the binary row/column masks in OpenCV windows (debug aid).

    Blocks twice on ``cv2.waitKey(0)``: once after the two individual
    masks are shown, and once after the combined mask is shown.
    """
    print("row_pred.shape", row_pred.shape)
    cv2.imshow("row predict", row_pred * 255)
    cv2.imshow("col predict", col_pred * 255)
    cv2.waitKey(0)
    # Combine with OR instead of "+": both masks are uint8 holding 0/1, so at
    # a row/column crossing the sum is 2 and 2 * 255 wraps around to 254.
    cv2.imshow("predict", (col_pred | row_pred) * 255)
    cv2.waitKey(0)


def predict(img, model, size, row_prob=0.5, col_prob=0.5, *, show=True):
    """Run the table-line model on an image and extract row/column lines.

    :param img: input image as a numpy array
    :param model: table-line extraction model; only ``model.predict`` is
        called, with a batch containing the single resized image
    :param size: ``(width, height)`` the image is resized to before prediction
    :param row_prob: threshold applied to channel 0 of the prediction
        (horizontal lines); pixels strictly above it are kept
    :param col_prob: threshold applied to channel 1 of the prediction
        (vertical lines); pixels strictly above it are kept
    :param show: when true (the default, matching the historical behaviour)
        print the mask shape and display the masks in blocking OpenCV
        windows; pass ``show=False`` for batch or headless use
    :return: tuple ``(rows, cols, img_new)`` where ``rows`` and ``cols`` are
        whatever ``get_line_from_binary_image`` returns for the row and
        column masks, and ``img_new`` is the resized input image
    """
    # Resize to the resolution the model is fed with.
    width, height = size
    img_new = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    # Predict on a batch of one and take that single result.
    pred = model.predict(np.array([img_new]))[0]

    # Threshold each channel into a 0/1 uint8 mask.
    row_pred = (pred[..., 0] > row_prob).astype(np.uint8)  # horizontal lines
    col_pred = (pred[..., 1] > col_prob).astype(np.uint8)  # vertical lines

    if show:
        _show_masks(row_pred, col_pred)

    # Turn the binary masks into line descriptions.
    rows = get_line_from_binary_image(row_pred, point_value=1, axis=0)
    cols = get_line_from_binary_image(col_pred, point_value=1, axis=1)

    return rows, cols, img_new