import base64
import multiprocessing as mp
import os
# os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
import sys
import time
import traceback
from multiprocessing.context import Process
import multiprocessing
import logging
# from table_line import *
import cv2
import numpy as np
import tensorflow as tf
from otr.table_line import get_best_predict_size, table_line, get_points, get_split_line, get_points_row, \
    get_points_col, \
    delete_close_points, fix_outline, get_bbox, get_outline_point, table_net, delete_contain_bbox, points_to_line, \
    fix_inner, merge_line, fix_corner, add_continue_bbox, delete_outline

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def log(msg):
    """Log *msg* at INFO level through this module's logger."""
    logger.info(msg)


def otr(data, otr_model):
    """Decode a base64-encoded image and run table detection on it.

    Args:
        data: base64-encoded image data (decoded with ``base64.b64decode``).
        otr_model: model object passed through to ``table_detect``.

    Returns:
        The result dict produced by ``table_detect``.

    Raises:
        binascii.Error: if ``data`` is not valid base64.
        TimeoutError: propagated unchanged (original message and traceback
            preserved) from ``table_detect``.
    """
    img_data = base64.b64decode(data)
    return table_detect(img_data, otr_model)


flag = 0


def _elapsed(start):
    """Return seconds elapsed since *start* (a ``time.time()`` value), rounded to 4 places."""
    return round(time.time() - start, 4)


def _no_table_result():
    """Print the zero-count summary and return the result used when no table is found."""
    print("points", 0, "split_lines", 0, "bboxes", 0)
    return {"points": str([]), "split_lines": str([]), "bboxes": str([]), "outline_points": str([])}


def table_detect(img_data, otr_model):
    """Detect table lines, intersection points and cells in an encoded image.

    Pipeline: decode -> line prediction (``table_line``) -> merge/clean lines
    -> intersections -> border repair (``fix_outline``) -> corner and inner
    repair -> cell boxes (``get_bbox``) -> per-table outline points.

    Args:
        img_data: encoded image bytes, decoded with ``cv2.imdecode``.
        otr_model: model passed to ``table_line``.

    Returns:
        dict with keys "points", "split_lines", "bboxes", "outline_points";
        each value is the ``str()`` of the corresponding object. All four are
        "[]" when no table is found or processing fails.

    Raises:
        TimeoutError: re-raised unchanged. Every other exception is logged
            (with traceback) and converted into the empty result.
    """
    logger.info("into otr_interface table_detect")
    start_time = time.time()
    try:
        start_time1 = time.time()
        # Decode the byte stream into an np.ndarray of uint8 pixels.
        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # imdecode returns None for undecodable data; report it directly
            # instead of letting cvtColor fail with an opaque error.
            logger.warning("otr_interface could not decode image data")
            return _no_table_result()

        # Convert BGR to RGB. COLOR_BGR2RGB has the same value as the
        # COLOR_RGB2BGR used previously, so the result is unchanged.
        image_np = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Predict at the image's own resolution to avoid distortion
        # (get_best_predict_size is intentionally not used).
        print("image_np.shape", image_np.shape)
        best_h, best_w, _ = image_np.shape
        logger.info("otr preprocess time: %ss", _elapsed(start_time1))

        # Run the model.
        start_time1 = time.time()
        rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
        logger.info("otr model predict time: %ss", _elapsed(start_time1))

        start_time1 = time.time()
        if not rows or not cols:
            return _no_table_result()

        # Merge staggered line segments.
        rows = merge_line(rows, axis=0)
        cols = merge_line(cols, axis=1)

        # Compute intersection points.
        image_size = (image_np.shape[0], image_np.shape[1])
        points = get_points(rows, cols, image_size)
        if not points:
            return _no_table_result()

        # Remove useless lines on the outside.
        rows, cols = delete_outline(rows, cols, points)

        split_lines, split_y = get_split_line(points, cols, image_np)
        # Group points by row/column and drop points that are too close.
        row_point_list = get_points_row(points, split_y, 5)
        col_point_list = get_points_col(points, split_y, 5)
        points = delete_close_points(points, row_point_list, col_point_list)

        # Repair the table border.
        new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points, split_y)
        if new_rows or new_cols:
            # Use the lines extended up to the added border lines.
            if long_rows:
                rows = long_rows
            if long_cols:
                cols = long_cols
            # Append the newly added border lines.
            if new_rows:
                rows += new_rows
            if new_cols:
                cols += new_cols

            # Recompute intersections and split lines after the border repair.
            points = get_points(rows, cols, image_size)
            split_lines, split_y = get_split_line(points, cols, image_np)
            # Group points by row/column and drop points that are too close.
            row_point_list = get_points_row(points, split_y, 0)
            col_point_list = get_points_col(points, split_y, 0)
            points = delete_close_points(points, row_point_list, col_point_list)

        # Repair the table's four corners.
        rows, cols = fix_corner(rows, cols, split_y)
        points = get_points(rows, cols, image_size)

        # Repair missing inner lines.
        points = fix_inner(rows, cols, points, split_y)
        if not points:
            return _no_table_result()
        row_point_list = get_points_row(points, split_y, 5)
        col_point_list = get_points_col(points, split_y, 5)

        # Cell bounding boxes.
        bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)
        # Remove nested ("containing") boxes via delete_contain_bbox.
        if bboxes:
            bboxes = delete_contain_bbox(bboxes)

        # Top-left and bottom-right points of each table.
        outline_points = get_outline_point(points, split_y)

        if bboxes:
            print("bboxes number", len(bboxes))
        else:
            print("bboxes number", "None")
        logger.info("otr postprocess time: %ss", _elapsed(start_time1))
        logger.info("use time: %s", time.time() - start_time)
        return {"points": str(points), "split_lines": str(split_lines),
                "bboxes": str(bboxes), "outline_points": str(outline_points)}
    except TimeoutError:
        # Bare raise keeps the original exception's message and traceback.
        raise
    except Exception:
        # logger.exception records the traceback; the previous
        # print(..., traceback.print_exc()) printed a stray "None".
        logger.exception("otr_interface cannot detected table!")
        logger.info("otr postprocess time: %ss", _elapsed(start_time1))
        return _no_table_result()


class OtrModels:
    """Build the table-line network and load weights from models/table-line.h5 beside this file."""

    def __init__(self):
        # Directory containing this Python file.
        _dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(_dir, "models", "table-line.h5")
        self.otr_model = table_net((None, None, 3), 2)
        self.otr_model.load_weights(model_path)

    def get_model(self):
        """Return the loaded table-line model."""
        return self.otr_model


if __name__ == '__main__':
    otr_model = OtrModels().get_model()
    # NOTE(review): "11" is not valid base64 (length is not a multiple of 4),
    # so this smoke call raises binascii.Error before the model is used.
    otr("11", otr_model)