otr_interface.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import base64
  2. import multiprocessing as mp
  3. import os
  4. # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
  5. import sys
  6. import time
  7. import traceback
  8. from multiprocessing.context import Process
  9. import multiprocessing
  10. from flask import Flask, jsonify
  11. from flask import request
  12. import logging
  13. # from table_line import *
  14. import cv2
  15. import numpy as np
  16. import tensorflow as tf
  17. from table_line import get_best_predict_size, table_line, get_points, get_split_line, get_points_row, get_points_col, \
  18. delete_close_points, fix_outline, get_bbox, get_outline_point, table_net
  # Flask application object that the route decorators in this file attach to.
  app = Flask(__name__)
  # Ask Flask's JSON encoder not to escape non-ASCII characters.
  # NOTE(review): newer Flask releases configure this through `app.json` -- confirm
  # against the deployed Flask version.
  app.config.update(JSON_AS_ASCII=False)

  # Process-wide logging: INFO and above, prefixed with time, logger name and level.
  logging.basicConfig(
      format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
      level=logging.INFO,
  )
  logger = logging.getLogger(__name__)
  def log(msg):
      """Emit ``msg`` through the module-level logger at INFO level."""
      logger.log(logging.INFO, msg)
  28. # @app.before_first_request
  29. # def init():
  @app.route('/otr', methods=['POST'])
  def otr():
      """Handle ``POST /otr``: decode the base64 request body and detect its table.

      Returns:
          The string ``'no data'`` when the request body is empty; otherwise the
          dict returned by ``table_detect`` for the decoded image bytes.
      """
      # The route is registered for POST only, so the view needs no method check
      # (the old check's implicit fall-through would have returned None).
      payload = request.data
      if not payload:
          return 'no data'

      # Trace which process served the request, plus the tail of the payload.
      print("child process ", os.getpid(), payload[-6:])

      img_data = base64.b64decode(payload)
      return table_detect(img_data)
  # Module-level flag; nothing in this file reads it.
  # NOTE(review): kept because other modules may import it -- confirm before removing.
  flag = 0
  # Weights file handed to `otr_model.load_weights`; a relative path, so it is
  # resolved against the process's working directory.
  model_path = "models/table-line.h5"
  def _empty_detection():
      """Print the zero-count summary and return the all-empty result dict."""
      print("points", 0, "split_lines", 0, "bboxes", 0)
      return {"points": str([]), "split_lines": str([]),
              "bboxes": str([]), "outline_points": str([])}


  def _split_and_prune(points, cols, image_np):
      """Compute split lines for ``points`` and drop intersections that lie too close.

      Returns:
          ``(points, split_lines, split_y)`` where ``points`` is the output of
          ``delete_close_points``.
      """
      split_lines, split_y = get_split_line(points, cols, image_np)
      # Group the intersections by row and by column, then remove near-duplicates.
      row_point_list = get_points_row(points, split_y, 0)
      col_point_list = get_points_col(points, split_y, 0)
      points = delete_close_points(points, row_point_list, col_point_list)
      return points, split_lines, split_y


  def table_detect(img_data):
      """Detect table structure in an encoded image.

      Args:
          img_data: Encoded image bytes (whatever ``cv2.imdecode`` accepts).

      Returns:
          A dict with the keys ``"points"``, ``"split_lines"``, ``"bboxes"`` and
          ``"outline_points"``; every value is the ``str()`` of the corresponding
          result. All four are ``str([])`` when no lines or intersections are
          found, or when any step raises.
      """
      start_time = time.time()
      try:
          # Raw byte stream -> np.ndarray of uint8 pixels.
          img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)

          # Swap the R and B channels (the original comment describes this as
          # converting BGR to RGB).
          image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

          # Pick the prediction size closest to the image resolution to avoid distortion.
          best_h, best_w = get_best_predict_size(img)

          # Run the model.
          rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
          if not rows or not cols:
              return _empty_detection()

          # Post-process: intersections of the detected lines, then split lines.
          image_size = (image_np.shape[0], image_np.shape[1])
          points = get_points(rows, cols, image_size)
          if not points:
              return _empty_detection()
          points, split_lines, split_y = _split_and_prune(points, cols, image_np)

          # Repair the table outline.
          new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
                                                                 split_y)
          # NOTE(review): the listing this was recovered from had lost its
          # indentation; the recomputation below is assumed to belong inside
          # this branch -- confirm against the upstream file.
          if new_rows or new_cols:
              # Use the lines extended to reach the added outline lines.
              if long_rows:
                  rows = long_rows
              if long_cols:
                  cols = long_cols
              # Append the newly added outline lines.
              if new_rows:
                  rows += new_rows
              if new_cols:
                  cols += new_cols

              # The line set changed, so recompute intersections and split lines.
              points = get_points(rows, cols, image_size)
              points, split_lines, split_y = _split_and_prune(points, cols, image_np)

          row_point_list = get_points_row(points, split_y)
          col_point_list = get_points_col(points, split_y)

          # Cell bounding boxes.
          bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)

          # Top-left and bottom-right points of every table.
          outline_points = get_outline_point(points, split_y)

          if bboxes:
              print("bboxes number", len(bboxes))
          else:
              print("bboxes number", "None")
          print("use time: ", time.time()-start_time)
          return {"points": str(points), "split_lines": str(split_lines),
                  "bboxes": str(bboxes), "outline_points": str(outline_points)}
      except Exception:
          # Best-effort boundary: any failure is reported as "no table found".
          # format_exc() returns the traceback text; print_exc() returns None,
          # which previously showed up as a literal "None" after the message.
          print("otr_interface cannot detected table!", traceback.format_exc())
          return _empty_detection()
  # Build the network and load its weights at import time, so `table_detect`
  # can use the module-level `otr_model`.
  # NOTE(review): (None, None, 3) presumably means variable-size 3-channel input
  # and 2 the number of output channels -- confirm against `table_net`.
  otr_model = table_net((None, None, 3), 2)
  otr_model.load_weights(model_path)
  if __name__ == '__main__':
      # Optional single command-line argument: the port to listen on.
      port = int(sys.argv[1]) if len(sys.argv) == 2 else 15017

      # Log before app.run(): it blocks until the server stops, so a message
      # placed after it would only appear at shutdown.
      log("OTR running "+str(port))
      app.run(host='0.0.0.0', port=port, threaded=False, debug=False)
  122. # print("init model...")
  123. # g1 = tf.Graph()
  124. # tf.compat.v1.disable_eager_execution()
  125. # sess1 = tf.compat.v1.Session(graph=g1)
  126. # with sess1.as_default():
  127. # with g1.as_default():
  128. # _model = table_net((None, None, 3), 2)
  129. # _model.load_weights(model_path)
  130. # otr_model_list[0] = _model
  131. #
  132. # g2 = tf.Graph()
  133. # tf.compat.v1.disable_eager_execution()
  134. # sess2 = tf.compat.v1.Session(graph=g2)
  135. # with sess2.as_default():
  136. # with g2.as_default():
  137. # _model = table_net((None, None, 3), 2)
  138. # _model.load_weights(model_path)
  139. # otr_model_list[1] = _model
  140. #
  141. # otr_graph_list[0] = g1
  142. # otr_graph_list[1] = g2
  143. # print("init finish")
  144. #
  145. # p = MyProcess(15017)
  146. # p.start()
  147. #
  148. # p1 = MyProcess(15018)
  149. # p1.start()
  150. # p.join()
  151. # p1.join()
  152. # otr_model = table_net((None, None, 3), 2)
  153. # otr_model.load_weights(model_path)
  154. #
  155. # start_interface(15017)