otr_interface_new.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import base64
  2. import multiprocessing as mp
  3. import os
  4. # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
  5. import sys
  6. import time
  7. import traceback
  8. from multiprocessing.context import Process
  9. import multiprocessing
  10. from flask import Flask, jsonify
  11. from flask import request
  12. import logging
  13. # from table_line import *
  14. import cv2
  15. import numpy as np
  16. import tensorflow as tf
  17. from table_line import get_best_predict_size, table_line, table_net, line_fix
  18. from table_predict import predict
  19. from table_postprocess import *
  20. app = Flask(__name__)
  21. app.config['JSON_AS_ASCII'] = False
  22. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  23. logger = logging.getLogger(__name__)
  24. def log(msg):
  25. """
  26. @summary:打印信息
  27. """
  28. logger.info(msg)
  29. @app.route('/otr', methods=['POST'])
  30. def otr():
  31. start_time = time.time()
  32. if request.method == "POST":
  33. # 检测是否有数据
  34. if not request.data:
  35. return 'no data'
  36. img_data = base64.b64decode(request.data)
  37. # points_and_lines = pool.apply(table_detect, (img_data,))
  38. points_and_lines = table_detect(img_data)
  39. return points_and_lines
  40. flag = 0
  41. model_path = "models/table-line.h5"
  42. def table_detect(img_data):
  43. print("child process ", os.getpid())
  44. start_time = time.time()
  45. try:
  46. # p = "test_files/train_463.jpg"
  47. p = "test_files/8.png"
  48. # p = "test_files/无边框1.jpg"
  49. # p = "test_files/part1.png"
  50. # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00e959a0bc9011ebaf5a00163e0ae709" + \
  51. # "\\00e95f7cbc9011ebaf5a00163e0ae709_pdf_page0.png"
  52. # p = "D:\\Project\\format_conversion\\appendix_test\\temp\\00fb3e52bc7e11eb836000163e0ae709" + \
  53. # "\\00fb43acbc7e11eb836000163e0ae709.png"
  54. # p = "test_files/table.jpg"
  55. # p = "C:\\Users\\Administrator\\Desktop\\table2-1.jpg"
  56. # p = r"D:\Project\format_conversion_maxcompute\format_convert\temp" \
  57. # r"\3cf0bfae0a1d11ec9622f02f748bb81b\3cf0e6a40a1d11ecb866f02f748bb81b_pdf_page1.png"
  58. # p = "data_process/create_data/0.jpg"
  59. # p = "../format_conversion/temp/f1fe9c4ac8e511eb81d700163e0857b6/f1fea1e0c8e511eb81d700163e0857b6.png"
  60. # p = "../format_conversion/1.png"
  61. image_np = cv2.imread(p)
  62. # 选择与图片最接近分辨率,以防失真
  63. best_h, best_w = get_best_predict_size(image_np)
  64. # 调用模型
  65. rows, cols, image_np = predict(image_np, otr_model, size=(best_w, best_h))
  66. print("len(cols), len(rows)", len(cols), len(rows))
  67. cols = merge_line(cols, 1)
  68. rows = merge_line(rows, 0)
  69. rows, cols = fix_gap(rows, cols)
  70. print("len(cols), len(rows)", len(cols), len(rows))
  71. print("rows", rows)
  72. print("cols", cols)
  73. # 创建空图
  74. test_img = np.zeros((image_np.shape), np.uint8)
  75. test_img.fill(255)
  76. for box in rows+cols:
  77. cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1)
  78. cv2.imwrite("predict.jpg", test_img)
  79. cv2.imshow("my post process", test_img)
  80. cv2.waitKey(0)
  81. # test_img = line_fix(test_img)
  82. # cv2.imshow("line fix", test_img)
  83. # cv2.waitKey(0)
  84. # 调用模型
  85. # rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
  86. #
  87. # if not rows or not cols:
  88. # print("points", 0, "split_lines", 0, "bboxes", 0)
  89. # return {"points": str([]), "split_lines": str([]),
  90. # "bboxes": str([]), "outline_points": str([])}
  91. #
  92. # print("len(rows)", len(rows))
  93. # print("len(cols)", len(cols))
  94. # 创建空图
  95. # test_img = np.zeros((image_np.shape), np.uint8)
  96. # test_img.fill(255)
  97. # for box in rows+cols:
  98. # cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1)
  99. # cv2.imshow("his post process", test_img)
  100. # cv2.waitKey(0)
  101. # cv2.imwrite("temp.jpg", test_img)
  102. # 处理结果
  103. # 合并直线
  104. # 计算交点
  105. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  106. # 计算分割线
  107. split_y = get_split_line(cols, (image_np.shape[0], image_np.shape[1]))
  108. print("split_y", split_y)
  109. print("image shape", test_img.shape)
  110. for y in split_y:
  111. cv2.line(test_img, (0, y), (test_img.shape[1], y), (0, 0, 255), 1)
  112. cv2.imshow("split line", test_img)
  113. cv2.waitKey(0)
  114. cv2.imwrite("temp.jpg", test_img)
  115. # 点、线分为多个区域
  116. points_area_list = get_point_area(points, split_y)
  117. rows_area_list = get_line_area(rows, split_y)
  118. cols_area_list = get_line_area(cols, split_y)
  119. print("points_area_list", points_area_list)
  120. print("rows_area_list", rows_area_list)
  121. print("cols_area_list", cols_area_list)
  122. # 对每个区域进行处理
  123. for i in range(len(points_area_list)):
  124. print("i", i)
  125. points_area = points_area_list[i]
  126. rows_area = rows_area_list[i]
  127. cols_area = cols_area_list[i]
  128. if not points_area or not rows_area or not cols_area:
  129. continue
  130. print("len(rows_area)", len(rows_area))
  131. print("len(cols_area)", len(cols_area))
  132. # 修复表格边框
  133. new_rows, new_cols, longer_rows, longer_cols = \
  134. fix_outline_area(rows_area, cols_area, points_area)
  135. if new_rows:
  136. cols_area = longer_cols
  137. rows_area += new_rows
  138. if new_cols:
  139. rows_area = longer_rows
  140. cols_area += new_cols
  141. # 单区域线
  142. test_img = np.zeros((image_np.shape), np.uint8)
  143. test_img.fill(255)
  144. for box in rows_area+cols_area:
  145. cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 0), 1)
  146. for box in new_rows+new_cols:
  147. cv2.line(test_img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 255, 0), 2)
  148. cv2.imshow("fix_outline", test_img)
  149. cv2.waitKey(0)
  150. # cv2.imwrite("temp.jpg", test_img)
  151. print("use time: ", time.time()-start_time)
  152. return {"points": str([]), "split_lines": str([]),
  153. "bboxes": str([]), "outline_points": str([])}
  154. except Exception as e:
  155. print("otr_interface cannot detected table!", traceback.print_exc())
  156. print("points", 0, "split_lines", 0, "bboxes", 0)
  157. return {"points": str([]), "split_lines": str([]), "bboxes": str([]),
  158. "outline_points": str([])}
  159. otr_model = table_net((None, None, 3), 2)
  160. otr_model.load_weights(model_path)
  161. if __name__ == '__main__':
  162. # if len(sys.argv) == 2:
  163. # port = int(sys.argv[1])
  164. # else:
  165. # port = 15017
  166. # app.run(host='0.0.0.0', port=port, threaded=True, debug=True)
  167. # log("OTR running "+str(port))
  168. table_detect("")