# otr_interface.py
  1. import base64
  2. import multiprocessing as mp
  3. import os
  4. # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
  5. import sys
  6. import time
  7. import traceback
  8. from multiprocessing.context import Process
  9. from flask import Flask, jsonify
  10. from flask import request
  11. import logging
  12. from table_line import *
  13. app = Flask(__name__)
  14. app.config['JSON_AS_ASCII'] = False
  15. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  16. logger = logging.getLogger(__name__)
  17. def log(msg):
  18. """
  19. @summary:打印信息
  20. """
  21. logger.info(msg)
  22. @app.route('/otr', methods=['POST'])
  23. def otr():
  24. start_time = time.time()
  25. if request.method == "POST":
  26. # 检测是否有数据
  27. if not request.data:
  28. return 'no data'
  29. img_data = base64.b64decode(request.data)
  30. points_and_lines = pool.apply(table_detect, (img_data,))
  31. return points_and_lines
  32. flag = 0
  33. model_path = "models/table-line.h5"
  34. def table_detect(img_data):
  35. print("child process ", os.getpid())
  36. start_time = time.time()
  37. try:
  38. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  39. img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  40. # cv2.imwrite("111111.jpg", img)
  41. # 将bgr转为rbg
  42. image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  43. # 选择与图片最接近分辨率,以防失真
  44. best_h, best_w = get_best_predict_size(img)
  45. # 调用模型
  46. # rows, cols = table_line(image_np, otr_model)
  47. rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
  48. if not rows or not cols:
  49. print("points", 0, "split_lines", 0, "bboxes", 0)
  50. return {"points": str([]), "split_lines": str([]),
  51. "bboxes": str([]), "outline_points": str([])}
  52. # 处理结果
  53. # 计算交点、分割线
  54. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  55. split_lines, split_y = get_split_line(points, cols, image_np)
  56. # 计算交点所在行列,剔除相近交点
  57. row_point_list = get_points_row(points, split_y, 0)
  58. col_point_list = get_points_col(points, split_y, 0)
  59. points = delete_close_points(points, row_point_list, col_point_list)
  60. # 修复边框
  61. new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
  62. split_y)
  63. if new_rows or new_cols:
  64. # 连接至补线的延长线
  65. if long_rows:
  66. rows = long_rows
  67. if long_cols:
  68. cols = long_cols
  69. # 新的补线
  70. if new_rows:
  71. rows += new_rows
  72. if new_cols:
  73. cols += new_cols
  74. # 修复边框后重新计算交点、分割线
  75. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  76. split_lines, split_y = get_split_line(points, cols, image_np)
  77. # 计算交点所在行列,剔除相近交点
  78. row_point_list = get_points_row(points, split_y, 0)
  79. col_point_list = get_points_col(points, split_y, 0)
  80. points = delete_close_points(points, row_point_list, col_point_list)
  81. row_point_list = get_points_row(points, split_y)
  82. col_point_list = get_points_col(points, split_y)
  83. # 获取bbox 单元格
  84. bboxes = get_bbox(image_np, points, split_y, row_point_list, col_point_list)
  85. # 获取每个表格的左上右下两个点
  86. outline_points = get_outline_point(points, split_y)
  87. if bboxes:
  88. print("bboxes number", len(bboxes))
  89. print("bboxes", bboxes)
  90. else:
  91. print("bboxes number", "None")
  92. print("use time: ", time.time()-start_time)
  93. return {"points": str(points), "split_lines": str(split_lines),
  94. "bboxes": str(bboxes), "outline_points": str(outline_points)}
  95. except Exception as e:
  96. print("otr_interface cannot detected table!", traceback.print_exc())
  97. print("points", 0, "split_lines", 0, "bboxes", 0)
  98. return {"points": str([]), "split_lines": str([]), "bboxes": str([]),
  99. "outline_points": str([])}
  100. class MyProcess(Process):
  101. def __init__(self):
  102. global otr_model
  103. otr_model = table_net((None, None, 3), 2)
  104. otr_model.load_weights(model_path)
  105. # 自己写__init__(self)会将父类的__init__覆盖,为了不丢失父类的一些属性,需要用super()加载
  106. super().__init__()
  107. # run()是Process类专门留出来让你重写的接口函数
  108. # def run(self):
  109. pool = mp.Pool(processes=1, initializer=MyProcess, initargs=())
  110. otr_model = 0
  111. if __name__ == '__main__':
  112. app.run(host='0.0.0.0', port=15017, threaded=True, debug=True)
  113. log("OTR running")
  114. # with open("开标记录表3_page_0.png", "rb") as f:
  115. # temp_img = f.read()
  116. # otr_model = table_net((None, None, 3), 2)
  117. # otr_model.load_weights(model_path)
  118. # table_detect(temp_img)