123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- import base64
- import multiprocessing as mp
- import os
- # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
- import sys
- import time
- import traceback
- from multiprocessing.context import Process
- from flask import Flask, jsonify
- from flask import request
- import logging
- from table_line import *
# Flask application object for the OTR (table-line detection) HTTP service.
app = Flask(__name__)
# Ask Flask not to ASCII-escape JSON responses (only honoured by Flask
# versions that still support this configuration key).
app.config.update(JSON_AS_ASCII=False)

# Root logging configuration: INFO level with timestamp, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
def log(msg):
    """Write ``msg`` to the module-level ``logger`` at INFO level."""
    logger.info(msg)
@app.route('/otr', methods=['POST'])
def otr():
    """Handle ``POST /otr``: run table detection on a base64-encoded image.

    The request body is expected to be the base64 encoding of an image file.
    The decoded bytes are handed to ``table_detect`` through the module-level
    worker ``pool`` and the resulting dict is returned as JSON.

    Returns:
        * the plain string ``'no data'`` when the request body is empty;
        * ``('invalid base64 data', 400)`` when the body cannot be decoded;
        * otherwise a JSON response built from the dict that
          ``table_detect`` returns.
    """
    # The route is registered for POST only, so no method check is needed;
    # the previous check left an implicit ``return None`` path behind.
    if not request.data:
        return 'no data'

    try:
        img_data = base64.b64decode(request.data)
    except ValueError:
        # binascii.Error (a ValueError subclass) is raised for malformed
        # base64; report a client error instead of an unhandled exception.
        return 'invalid base64 data', 400

    # pool.apply blocks until the worker process has produced the result.
    points_and_lines = pool.apply(table_detect, (img_data,))
    # jsonify works on every Flask version, whereas returning a bare dict
    # from a view is only accepted by newer releases.
    return jsonify(points_and_lines)
# Module-level flag; it is not read anywhere in the visible part of this file.
flag = 0

# Weights file that MyProcess.__init__ loads into the table-line model.
model_path = "models/table-line.h5"
def _empty_table_result():
    """Return the payload reported when no table could be detected."""
    return {"points": str([]), "split_lines": str([]),
            "bboxes": str([]), "outline_points": str([])}


def _intersections_and_splits(rows, cols, image_np):
    """Compute de-duplicated intersection points plus split lines for rows/cols."""
    points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
    split_lines, split_y = get_split_line(points, cols, image_np)

    # Group the points by row and by column, then drop near-duplicate points.
    # NOTE(review): the trailing 0 is passed through unchanged from the
    # original code; its meaning is defined by the table_line helpers.
    row_point_list = get_points_row(points, split_y, 0)
    col_point_list = get_points_col(points, split_y, 0)
    points = delete_close_points(points, row_point_list, col_point_list)
    return points, split_lines, split_y


def table_detect(img_data):
    """Detect table lines in an encoded image and describe the table layout.

    Runs inside the worker process, where the module-level ``otr_model`` is
    expected to have been set by the pool initializer.

    Args:
        img_data: raw bytes of an encoded image (as accepted by
            ``cv2.imdecode``).

    Returns:
        A dict with the keys ``"points"``, ``"split_lines"``, ``"bboxes"``
        and ``"outline_points"``; every value is the ``str()`` of the
        corresponding Python object. When no rows/columns are found, or when
        any exception occurs, every value is ``"[]"``. Exceptions are never
        propagated to the caller.
    """
    print("child process ", os.getpid())
    start_time = time.time()

    try:
        # Decode the byte stream into an 8-bit image array.
        img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals undecodable input by returning None.
            raise ValueError("image data could not be decoded")

        # imdecode yields BGR; swap the channels to RGB. (OpenCV defines
        # COLOR_RGB2BGR and COLOR_BGR2RGB as the same conversion code, so this
        # only corrects the misleading name used before.)
        image_np = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Pick the prediction size closest to the image resolution to limit
        # distortion.
        best_h, best_w = get_best_predict_size(img)

        # Run the model.
        rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h),
                                          hprob=0.5, vprob=0.5)
        if not rows or not cols:
            print("points", 0, "split_lines", 0, "bboxes", 0)
            return _empty_table_result()

        # Intersections and split lines for the detected rows/columns.
        points, split_lines, split_y = _intersections_and_splits(rows, cols, image_np)

        # Repair the table outline.
        new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
                                                               split_y)
        if new_rows or new_cols:
            # Use the lines extended up to the newly added outline lines.
            if long_rows:
                rows = long_rows
            if long_cols:
                cols = long_cols
            # Append the newly added outline lines.
            if new_rows:
                rows += new_rows
            if new_cols:
                cols += new_cols

            # The line set changed, so recompute intersections and split lines.
            points, split_lines, split_y = _intersections_and_splits(rows, cols, image_np)

        row_point_list = get_points_row(points, split_y)
        col_point_list = get_points_col(points, split_y)

        # Cell bounding boxes.
        bboxes = get_bbox(image_np, points, split_y, row_point_list, col_point_list)

        # Top-left / bottom-right corner points of every table.
        outline_points = get_outline_point(points, split_y)

        if bboxes:
            print("bboxes number", len(bboxes))
            print("bboxes", bboxes)
        else:
            print("bboxes number", "None")
        print("use time: ", time.time()-start_time)
        return {"points": str(points), "split_lines": str(split_lines),
                "bboxes": str(bboxes), "outline_points": str(outline_points)}

    except Exception:
        # Best-effort contract: report the failure and return an empty result.
        # print_exc() writes the traceback itself and returns None, so it must
        # not be passed to print() (that printed a stray "None").
        traceback.print_exc()
        print("otr_interface cannot detected table!")
        print("points", 0, "split_lines", 0, "bboxes", 0)
        return _empty_table_result()
class MyProcess(Process):
    """Process subclass whose constructor builds the table-line model.

    Instantiating the class stores a freshly built model, with the weights
    from ``model_path`` loaded, in the module-level ``otr_model``. This file
    passes the class itself as the ``initializer`` of its worker pool.
    """

    def __init__(self):
        global otr_model

        otr_model = table_net((None, None, 3), 2)
        otr_model.load_weights(model_path)

        # Defining __init__ here replaces the parent's version, so invoke the
        # parent initializer explicitly to keep Process's own attributes.
        super().__init__()

    # run() is the hook Process offers for subclasses to override; it is
    # deliberately left untouched here.
# Single-worker pool. MyProcess is passed as the initializer, so the worker
# calls MyProcess() when it starts, which loads the model into the worker's
# own ``otr_model`` global.
pool = mp.Pool(processes=1, initializer=MyProcess)

# Placeholder value for ``otr_model`` in the process executing this module.
otr_model = 0
if __name__ == '__main__':
    # app.run() blocks until the server stops, so announce start-up first;
    # logging after it would only fire once the service had shut down.
    log("OTR running")

    # NOTE(review): debug=True together with host 0.0.0.0 exposes the
    # interactive Werkzeug debugger on every interface and enables the
    # reloader (which re-imports this module and so creates the pool again).
    # Confirm this is intended before deploying.
    app.run(host='0.0.0.0', port=15017, threaded=True, debug=True)

    # Manual smoke test without the HTTP server (kept for reference):
    # with open("开标记录表3_page_0.png", "rb") as f:
    #     temp_img = f.read()
    # otr_model = table_net((None, None, 3), 2)
    # otr_model.load_weights(model_path)
    # table_detect(temp_img)
|