"""Flask service exposing table-line detection (OTR) over HTTP.

The module loads the table-line model once at import time and serves a single
``POST /otr`` endpoint that accepts a base64-encoded image and returns the
detected table structure.
"""
import base64
import logging
import multiprocessing
import multiprocessing as mp
import os
import sys
import time
import traceback
from multiprocessing.context import Process

# Optional TensorFlow/XLA switch, left disabled as in the original file.
# NOTE(review): presumably needs to be set before TensorFlow is imported.
# os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'

import cv2
import numpy as np
import tensorflow as tf
from flask import Flask, jsonify
from flask import request

from table_line import get_best_predict_size, table_line, get_points, get_split_line, get_points_row, get_points_col, \
    delete_close_points, fix_outline, get_bbox, get_outline_point, table_net

app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def log(msg):
    """Log ``msg`` at INFO level through the module logger."""
    logger.info(msg)


@app.route('/otr', methods=['POST'])
def otr():
    """Handle ``POST /otr``.

    The request body is expected to be a base64-encoded image.

    Returns:
        The string ``'no data'`` when the body is empty, otherwise the dict
        produced by :func:`table_detect`.
    """
    # The route only accepts POST, so no explicit method check is needed.
    if not request.data:
        return 'no data'

    print("child process ", os.getpid(), request.data[-6:])
    img_data = base64.b64decode(request.data)
    return table_detect(img_data)


flag = 0
model_path = "models/table-line.h5"


def _empty_result():
    """Print the zero-count summary and return the all-empty result dict."""
    print("points", 0, "split_lines", 0, "bboxes", 0)
    return {"points": str([]), "split_lines": str([]),
            "bboxes": str([]), "outline_points": str([])}


def _remove_close_points(points, split_y):
    """Group ``points`` by row and column, then drop points that are too close."""
    row_point_list = get_points_row(points, split_y, 0)
    col_point_list = get_points_col(points, split_y, 0)
    return delete_close_points(points, row_point_list, col_point_list)


def table_detect(img_data):
    """Detect table lines in an encoded image and derive the table structure.

    Args:
        img_data: Raw encoded image bytes (already base64-decoded).

    Returns:
        A dict with the keys ``points``, ``split_lines``, ``bboxes`` and
        ``outline_points``; every value is the ``str()`` of the corresponding
        result. All four are ``str([])`` when the image cannot be decoded, no
        rows/columns/intersections are found, or any step raises.
    """
    start_time = time.time()
    try:
        # Decode the binary stream into an np.ndarray of uint8 pixels.
        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals an undecodable buffer by returning None.
            print("otr_interface cannot decode image!")
            return _empty_result()

        # Swap BGR -> RGB (same channel swap the original performed via
        # COLOR_RGB2BGR; the constant now matches the intent).
        image_np = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Pick the prediction size closest to the image resolution to limit
        # distortion.
        best_h, best_w = get_best_predict_size(img)

        # Run the model.
        rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
        if not rows or not cols:
            return _empty_result()

        # Compute intersections and split lines.
        points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
        if not points:
            return _empty_result()
        split_lines, split_y = get_split_line(points, cols, image_np)

        # Locate each intersection's row/column and drop near-duplicates.
        points = _remove_close_points(points, split_y)

        # Repair the table outline.
        new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points, split_y)
        if new_rows or new_cols:
            # Use the lines extended to meet the newly added outline lines.
            if long_rows:
                rows = long_rows
            if long_cols:
                cols = long_cols
            # Append the newly added outline lines.
            if new_rows:
                rows += new_rows
            if new_cols:
                cols += new_cols

            # Recompute intersections and split lines after the repair.
            points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
            split_lines, split_y = get_split_line(points, cols, image_np)

            # Locate each intersection's row/column and drop near-duplicates.
            points = _remove_close_points(points, split_y)

        row_point_list = get_points_row(points, split_y)
        col_point_list = get_points_col(points, split_y)

        # Build the cell bounding boxes.
        bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)

        # Get the top-left and bottom-right points of every table.
        outline_points = get_outline_point(points, split_y)

        if bboxes:
            print("bboxes number", len(bboxes))
        else:
            print("bboxes number", "None")
        print("use time: ", time.time() - start_time)
        return {"points": str(points), "split_lines": str(split_lines),
                "bboxes": str(bboxes), "outline_points": str(outline_points)}

    except Exception:
        # Best-effort endpoint: report the failure and fall back to an empty
        # result instead of propagating. print_exc() returns None, so it must
        # not be passed to print() as an argument.
        traceback.print_exc()
        print("otr_interface cannot detected table!")
        return _empty_result()


# Load the model once at import time so every request reuses it.
otr_model = table_net((None, None, 3), 2)
otr_model.load_weights(model_path)


if __name__ == '__main__':
    if len(sys.argv) == 2:
        port = int(sys.argv[1])
    else:
        port = 15017

    # app.run() blocks until the server stops, so log before starting it.
    log("OTR running " + str(port))
    app.run(host='0.0.0.0', port=port, threaded=False, debug=False)