import base64
import json
import multiprocessing as mp
import os
# os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
import sys
# Make the parent directory importable so the project packages below resolve.
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
import time
import traceback
from multiprocessing.context import Process
import multiprocessing
import logging
# from table_line import *
import cv2
import numpy as np
import tensorflow as tf
from flask import Flask, request
from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes
from otr.table_line import get_best_predict_size, table_line, get_points, get_split_line, get_points_row, \
    get_points_col, \
    delete_close_points, fix_outline, get_bbox, get_outline_point, table_net, delete_contain_bbox, points_to_line, \
    fix_inner, merge_line, fix_corner, add_continue_bbox, delete_outline
from format_convert import _global


# Interface configuration
app = Flask(__name__)

# Line-probability thresholds passed to table_line() depending on the source.
_PDF_LINE_PROB = 0.2
_IMAGE_LINE_PROB = 0.5

# Form values that must be read as "false" when a boolean travels as text.
_FALSE_FLAG_STRINGS = frozenset(("", "0", "false", "none", "null", "no", "off"))


def _flag_is_true(value):
    """Interpret a flag that may arrive as a real bool or as form-encoded text.

    ``request.form`` only yields strings, so a client posting ``False`` is
    received as the non-empty (truthy) string ``"False"``. Strings are compared
    case-insensitively against a small set of false spellings; any other
    string stays true. Non-string values use normal truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in _FALSE_FLAG_STRINGS
    return bool(value)


def _line_prob(is_from_pdf):
    """Return the ``prob`` threshold to use for PDF-derived vs. other images."""
    return _PDF_LINE_PROB if is_from_pdf else _IMAGE_LINE_PROB


def _elapsed(since):
    """Format the seconds elapsed since ``since`` as e.g. ``"0.1234s"``."""
    return str(round(float(time.time() - since), 4)) + "s"


def _decode_image(img_data):
    """Decode encoded image bytes and swap the channel order (BGR <-> RGB).

    Raises:
        ValueError: if OpenCV cannot decode ``img_data``. Callers run this
            inside their ``try`` block, so the failure is reported through
            their normal error path.
    """
    # Convert the binary stream to np.ndarray [np.uint8: 8-bit pixels]
    img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("image bytes could not be decoded")
    # cv2.imwrite("111111.jpg", img)
    # Convert BGR to RGB
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def _empty_table_result():
    """Return the result dict table_detect2 uses when no table is found."""
    return {"points": str([]), "split_lines": str([]),
            "bboxes": str([]), "outline_points": str([]),
            "lines": str([])}


@app.route('/otr', methods=['POST'])
def _otr():
    """Flask endpoint: detect table lines in a base64-encoded image.

    Reads the form fields ``data`` (base64 image) and ``is_from_pdf``.

    Returns:
        A JSON string. On success it is the dict produced by ``line_detect``.
        On failure ``list_line`` holds an error code: ``[-9]`` when the form
        is empty, ``[-5]`` on ``TimeoutError``, ``[-1]`` on any other error.
    """
    _global._init()
    _global.update({"port": globals().get("port")})

    log("into otr_interface _otr")
    try:
        if not request.form:
            log("otr no data!")
            return json.dumps({"list_line": str([-9])})

        # Load the model once per process and cache it in module globals.
        otr_model = globals().get("global_otr_model")
        if otr_model is None:
            otr_model = OtrModels().get_model()
            globals().update({"global_otr_model": otr_model})

        data = request.form.get("data")
        # Form values are strings: parse the flag instead of relying on
        # truthiness, otherwise "False" would select the PDF threshold.
        is_from_pdf = _flag_is_true(request.form.get("is_from_pdf"))
        img_data = base64.b64decode(data)
        _md5 = get_md5_from_bytes(img_data)[0]
        _global.update({"md5": _md5})

        list_lines = line_detect(img_data, otr_model, prob=_line_prob(is_from_pdf))
        return json.dumps(list_lines)
    except TimeoutError:
        return json.dumps({"list_line": str([-5])})
    except Exception:
        # Top-level request boundary: report the failure as an error code.
        traceback.print_exc()
        return json.dumps({"list_line": str([-1])})


def otr(data, otr_model, is_from_pdf):
    """Detect table lines for a base64-encoded image without going through HTTP.

    Args:
        data: base64-encoded image.
        otr_model: model object forwarded to ``line_detect``.
        is_from_pdf: truthy to use the PDF threshold (0.2), otherwise 0.5.

    Returns:
        The dict returned by ``line_detect``.

    Raises:
        TimeoutError: propagated unchanged from ``line_detect``.
    """
    img_data = base64.b64decode(data)
    # points_and_lines = pool.apply(table_detect, (img_data,))
    return line_detect(img_data, otr_model, prob=_line_prob(is_from_pdf))


flag = 0
# model_path = "models/table-line.h5"


def table_detect2(img_data, otr_model):
    """Run the full table pipeline: lines, intersections, cells and outlines.

    Args:
        img_data: encoded image bytes.
        otr_model: model object forwarded to ``table_line``.

    Returns:
        A dict with the keys ``points``, ``split_lines``, ``bboxes``,
        ``outline_points`` and ``lines``; every value is the ``str()`` of the
        corresponding result. All values are ``str([])`` when no table is
        found or when processing fails.

    Raises:
        TimeoutError: re-raised unchanged.
    """
    log("into otr_interface table_detect")
    start_time1 = time.time()
    try:
        image_np = _decode_image(img_data)

        # Choose the resolution closest to the image to avoid distortion
        # best_h, best_w = get_best_predict_size(img)
        print("image_np.shape", image_np.shape)
        best_h, best_w, _ = image_np.shape
        log("otr preprocess time: " + _elapsed(start_time1))

        # Run the model
        # rows, cols = table_line(image_np, otr_model)
        rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)

        start_time1 = time.time()
        if not rows or not cols:
            print("points", 0, "split_lines", 0, "bboxes", 0)
            return _empty_table_result()

        # Debug: check that rows/cols are output correctly
        # for line in rows+cols:
        #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
        #              (255, 0, 0), 2)
        # cv2.imshow("rows-cols1", img)
        # cv2.waitKey(0)

        # Post-process the result
        # Merge staggered lines
        rows = merge_line(rows, axis=0)
        cols = merge_line(cols, axis=1)

        # Compute intersections and split lines
        image_hw = (image_np.shape[0], image_np.shape[1])
        points = get_points(rows, cols, image_hw)
        if not points:
            print("points", 0, "split_lines", 0, "bboxes", 0)
            return _empty_table_result()

        # Remove useless outer lines
        rows, cols = delete_outline(rows, cols, points)

        split_lines, split_y = get_split_line(points, cols, image_np)

        # Group intersections by row/column and drop points that are too close
        row_point_list = get_points_row(points, split_y, 5)
        col_point_list = get_points_col(points, split_y, 5)
        points = delete_close_points(points, row_point_list, col_point_list)

        # Debug: check that points are output correctly
        # for p in points:
        #     cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
        # cv2.imshow("points", img)
        # cv2.waitKey(0)

        # Debug: check that rows/cols are output correctly
        # for line in rows+cols:
        #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
        #              (0, 255, 0), 2)
        # cv2.imshow("rows-cols0", img)
        # cv2.waitKey(0)

        # Fix the outline
        new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points, split_y)
        # print(new_cols, new_rows)
        if new_rows or new_cols:
            # Extended lines that connect to the supplemented lines
            if long_rows:
                rows = long_rows
            if long_cols:
                cols = long_cols
            # Newly supplemented lines
            if new_rows:
                rows += new_rows
            if new_cols:
                cols += new_cols

            # NOTE(review): the recomputation below is assumed to belong to
            # this branch (it is described as happening "after fixing the
            # outline") - confirm against the original indentation.
            # Recompute intersections and split lines after fixing the outline
            points = get_points(rows, cols, image_hw)
            split_lines, split_y = get_split_line(points, cols, image_np)

            # Group intersections by row/column and drop points that are too close
            row_point_list = get_points_row(points, split_y, 0)
            col_point_list = get_points_col(points, split_y, 0)
            points = delete_close_points(points, row_point_list, col_point_list)
            # row_point_list = get_points_row(points, split_y)
            # col_point_list = get_points_col(points, split_y)

        # Debug: check that rows/cols are output correctly
        # for line in rows+cols:
        #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
        #              (255, 0, 0), 2)
        # cv2.imshow("rows-cols1", img)
        # cv2.waitKey(0)

        # Fix the four corners of the table
        rows, cols = fix_corner(rows, cols, split_y)
        points = get_points(rows, cols, image_hw)
        # row_point_list = get_points_row(points, split_y, 5)
        # col_point_list = get_points_col(points, split_y, 5)

        # print("row_point_list", row_point_list)
        # print("col_point_list", col_point_list)

        # Fix missing inner lines
        points = fix_inner(rows, cols, points, split_y)
        if not points:
            print("points", 0, "split_lines", 0, "bboxes", 0)
            return _empty_table_result()
        row_point_list = get_points_row(points, split_y, 5)
        col_point_list = get_points_col(points, split_y, 5)

        # Debug: check that points are output correctly
        # for p in points:
        #     cv2.circle(img, (p[0], p[1]), 1, (0, 255, 0), 3)
        # cv2.imshow("points fix", img)
        # cv2.waitKey(0)

        # Debug: check that rows/cols are output correctly
        # for line in rows+cols:
        #     cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
        #              (255, 0, 0), 2)
        # cv2.imshow("rows-cols2", img)
        # cv2.waitKey(0)

        # Rebuild rows/cols from the row/column grouping so that extended
        # lines do not break the later bbox generation
        # rows = points_to_line(row_point_list, axis=0)
        # cols = points_to_line(col_point_list, axis=1)
        # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
        # row_point_list = get_points_row(points, split_y, 0)
        # col_point_list = get_points_col(points, split_y, 0)

        # Get the bboxes (cells)
        bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)

        # Delete bboxes that contain other bboxes
        if bboxes:
            bboxes = delete_contain_bbox(bboxes)

        # Debug: check that boxes are output correctly
        # for box in bboxes:
        #     cv2.rectangle(img, box[0], box[1], (0, 0, 255), 3)
        # cv2.imshow("bbox", img)
        # cv2.waitKey(0)

        # Add continuous boxes
        # if bboxes:
        #     bboxes = add_continue_bbox(bboxes)
        #
        #     # Delete bboxes that contain other bboxes
        #     bboxes = delete_contain_bbox(bboxes)

        # Debug: check that boxes are output correctly
        # cv2.namedWindow('bbox', 0)
        # for box in bboxes:
        #     cv2.rectangle(img, box[0], box[1], (0, 255, 0), 3)
        # cv2.imshow("bbox", img)
        # cv2.waitKey(0)

        # Debug: check that points are output correctly
        # cv2.namedWindow('points', 0)
        # for p in points:
        #     cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
        # cv2.imshow("points", img)
        # cv2.waitKey(0)

        # Debug: check that region split lines are output correctly
        # cv2.namedWindow('split_lines', 0)
        # for line in split_lines:
        #     cv2.line(img, line[0], line[1], (0, 0, 255), 2)
        # cv2.imshow("split_lines", img)
        # cv2.waitKey(0)

        # Get the top-left and bottom-right points of every table
        outline_points = get_outline_point(points, split_y)

        if bboxes:
            print("bboxes number", len(bboxes))
            # print("bboxes", bboxes)
        else:
            print("bboxes number", "None")
        log("otr postprocess time: " + _elapsed(start_time1))
        log("otr finish: " + _elapsed(start_time1))
        return {"points": str(points), "split_lines": str(split_lines),
                "bboxes": str(bboxes), "outline_points": str(outline_points),
                "lines": str(rows+cols)}

    except TimeoutError:
        # Keep the original exception (message and traceback) intact.
        raise
    except Exception:
        log("otr_interface cannot detected table!")
        print("otr_interface cannot detected table!")
        traceback.print_exc()
        print("points", 0, "split_lines", 0, "bboxes", 0)
        log("otr postprocess time: " + _elapsed(start_time1))
        return _empty_table_result()


def line_detect(img_data, otr_model, prob=0.2):
    """Run the model on an encoded image and return the detected lines.

    Args:
        img_data: encoded image bytes.
        otr_model: model object forwarded to ``table_line``.
        prob: threshold forwarded to ``table_line`` as ``prob``.

    Returns:
        ``{"list_line": str(list_line)}`` where ``list_line`` is whatever
        ``table_line`` returns, or ``{"list_line": str([])}`` when decoding or
        detection fails.

    Raises:
        TimeoutError: re-raised unchanged.
    """
    log("into otr_interface table_detect")
    start_time1 = time.time()
    try:
        image_np = _decode_image(img_data)

        # Choose the resolution closest to the image to avoid distortion
        # best_h, best_w = get_best_predict_size(img)
        log("image_np.shape" + str(image_np.shape))
        best_h, best_w, _ = image_np.shape
        log("otr preprocess time: " + _elapsed(start_time1))

        # Run the model
        # rows, cols = table_line(image_np, otr_model)
        start_time1 = time.time()
        list_line = table_line(image_np, otr_model, size=(best_w, best_h), prob=prob)
        log("otr finish " + _elapsed(start_time1))
        return {"list_line": str(list_line)}

    except TimeoutError:
        # Keep the original exception (message and traceback) intact.
        raise
    except Exception:
        log("otr_interface cannot detected table!")
        print("otr_interface cannot detected table!")
        traceback.print_exc()
        log("otr postprocess time: " + _elapsed(start_time1))
        return {"list_line": str([])}


class OtrModels:
    """Builds the table-line network and loads its weights from disk."""

    def __init__(self):
        # Directory containing this Python file
        _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
        model_path = os.path.join(_dir, "models", "table-line.h5")
        self.otr_model = table_net((None, None, 3), 2)
        self.otr_model.load_weights(model_path)

    def get_model(self):
        """Return the loaded model."""
        return self.otr_model


def test_otr_model():
    """Manual smoke test: post a local image to a running OTR service.

    Uses a hard-coded file path and service URL, so it only works on a
    machine where both exist.
    """
    file_path = "C:/Users/Administrator/Desktop/error2.png"
    with open(file_path, "rb") as f:
        file_bytes = f.read()
    file_base64 = base64.b64encode(file_bytes)
    file_json = {"data": file_base64, "is_from_pdf": False}

    _url = "http://192.168.2.103:18000/otr"
    r = json.loads(request_post(_url, file_json))
    print(r)


def _parse_cli_args(argv):
    """Return ``(port, gpu_index)`` parsed from ``argv``.

    ``script PORT`` sets the port, ``script PORT GPU`` sets both; any other
    argument count falls back to port 18000 and GPU index 0. The GPU index
    always has a default so it is never left undefined.
    """
    port, gpu_index = 18000, 0
    if len(argv) in (2, 3):
        port = int(argv[1])
    if len(argv) == 3:
        gpu_index = int(argv[2])
    return port, gpu_index


# otr_model = table_net((None, None, 3), 2)
# otr_model.load_weights(model_path)
if __name__ == '__main__':
    listen_port, using_gpu_index = _parse_cli_args(sys.argv)

    _global._init()
    _global.update({"port": str(listen_port)})
    # _otr() reads the port back through globals().get("port").
    globals().update({"port": str(listen_port)})

    # Log format settings
    # ip = get_intranet_ip()
    # logging.basicConfig(level=logging.INFO,
    #                     format='%(asctime)s - %(name)s - %(levelname)s - '
    #                            + ip + ' - ' + str(port) + ' - %(message)s')

    # Limit TensorFlow GPU memory
    memory_limit_scale = 0.3
    os.environ['CUDA_VISIBLE_DEVICES'] = str(using_gpu_index)
    os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
    os.environ['CUDA_CACHE_DISABLE'] = str(0)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=memory_limit_scale)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    # Log before app.run(): it blocks until the server shuts down.
    log("OTR running "+str(listen_port))
    app.run(host='0.0.0.0', port=listen_port, processes=1, threaded=False, debug=False)
    # app.run(port=port)

    # test_otr_model()

    # print(json.dumps([-2]))

    # otr_model = OtrModels().get_model()
    # otr("11", otr_model)