import base64
import multiprocessing as mp
import os
# os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
import sys
import time
import traceback
from multiprocessing.context import Process

from flask import Flask, jsonify
from flask import request
import logging

from table_line import *

app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def log(msg):
    """
    @summary: log ``msg`` at INFO level through the module logger.
    """
    logger.info(msg)


@app.route('/otr', methods=['POST'])
def otr():
    """Detect table lines in a base64-encoded image sent as the raw request body.

    The route only accepts POST, so Flask rejects other methods before this
    function runs.

    Returns:
        'no data' when the body is empty; otherwise the dict produced by
        ``table_detect`` (computed in the worker pool), which Flask serializes.
    """
    # Reject empty bodies early.
    if not request.data:
        return 'no data'
    img_data = base64.b64decode(request.data)
    # Run detection in the pool worker, which holds the loaded model
    # (see MyProcess below).
    points_and_lines = pool.apply(table_detect, (img_data,))
    return points_and_lines


flag = 0
model_path = "models/table-line.h5"


def _empty_table_result():
    """Return the result dict used when no table is found or detection fails."""
    return {"points": str([]), "split_lines": str([]),
            "bboxes": str([]), "outline_points": str([])}


def table_detect(img_data):
    """Run table-line detection on an encoded image.

    Intended to run inside a pool worker, where ``otr_model`` has been set by
    the pool initializer.

    Args:
        img_data: encoded image bytes (whatever ``cv2.imdecode`` accepts).

    Returns:
        dict with keys "points", "split_lines", "bboxes" and "outline_points".
        Each value is the ``str()`` of the corresponding result. All values
        are ``str([])`` when no lines are found or any exception is raised.
    """
    print("child process ", os.getpid())
    start_time = time.time()
    try:
        # Decode the binary stream into an np.ndarray (np.uint8 pixels).
        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)

        # Swap channel order BGR -> RGB. This is numerically the same
        # conversion code as COLOR_RGB2BGR.
        image_np = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Pick the prediction size closest to the image resolution to avoid
        # distortion.
        best_h, best_w = get_best_predict_size(img)

        # Run the model.
        rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h),
                                          hprob=0.5, vprob=0.5)

        if not rows or not cols:
            print("points", 0, "split_lines", 0, "bboxes", 0)
            return _empty_table_result()

        # Post-process: compute intersection points and split lines.
        points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
        split_lines, split_y = get_split_line(points, cols, image_np)

        # Assign points to rows/columns, then drop points that are too close.
        row_point_list = get_points_row(points, split_y, 0)
        col_point_list = get_points_col(points, split_y, 0)
        points = delete_close_points(points, row_point_list, col_point_list)

        # Repair the table outline.
        new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols,
                                                               points, split_y)
        if new_rows or new_cols:
            # Use the lines extended to meet the added border lines.
            if long_rows:
                rows = long_rows
            if long_cols:
                cols = long_cols
            # Append the newly added border lines.
            if new_rows:
                rows += new_rows
            if new_cols:
                cols += new_cols

            # Recompute points and split lines after the outline repair.
            points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
            split_lines, split_y = get_split_line(points, cols, image_np)

            # Re-assign points to rows/columns and drop close points again.
            row_point_list = get_points_row(points, split_y, 0)
            col_point_list = get_points_col(points, split_y, 0)
            points = delete_close_points(points, row_point_list, col_point_list)

        # NOTE(review): the original source was collapsed onto one line, so
        # the if-scope above was reconstructed. These final row/column lists
        # are recomputed from the final (de-duplicated) points in both paths,
        # so get_bbox sees data consistent with `points`. Confirm against
        # VCS history.
        row_point_list = get_points_row(points, split_y)
        col_point_list = get_points_col(points, split_y)

        # Extract cell bounding boxes.
        bboxes = get_bbox(image_np, points, split_y, row_point_list, col_point_list)

        # Top-left / bottom-right points of each table.
        outline_points = get_outline_point(points, split_y)

        if bboxes:
            print("bboxes number", len(bboxes))
            print("bboxes", bboxes)
        else:
            print("bboxes number", "None")
        print("use time: ", time.time() - start_time)
        return {"points": str(points), "split_lines": str(split_lines),
                "bboxes": str(bboxes), "outline_points": str(outline_points)}

    except Exception:
        # Best-effort service: report the failure with its traceback and
        # return an empty result. traceback.print_exc() returns None, so
        # format_exc() is used to include the traceback in the message.
        print("otr_interface cannot detected table!", traceback.format_exc())
        print("points", 0, "split_lines", 0, "bboxes", 0)
        return _empty_table_result()


class MyProcess(Process):
    """Used as the Pool ``initializer``. Constructing it loads the model into
    this module's global ``otr_model`` inside the worker process.

    The Process object itself is never started. Only the side effect in
    ``__init__`` matters.
    """

    def __init__(self):
        global otr_model
        otr_model = table_net((None, None, 3), 2)
        otr_model.load_weights(model_path)

        # Overriding __init__ replaces Process.__init__, so call it explicitly
        # to keep the base-class attributes initialized.
        super().__init__()


# Single worker process that owns the model; created at import time.
pool = mp.Pool(processes=1, initializer=MyProcess, initargs=())
# Placeholder in the parent process; workers get the real model from the
# initializer.
otr_model = 0

if __name__ == '__main__':
    # The Werkzeug debugger allows arbitrary code execution and must not be
    # exposed on 0.0.0.0 by default, so it is opt-in via OTR_DEBUG=1. The
    # reloader is disabled because it re-executes this module, which would
    # create a second Pool and load the model twice.
    debug_mode = os.environ.get("OTR_DEBUG", "0") == "1"
    # app.run blocks until shutdown, so log before starting the server.
    log("OTR running")
    app.run(host='0.0.0.0', port=15017, threaded=True, debug=debug_mode,
            use_reloader=False)