otr_interface.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. import base64
  2. import multiprocessing as mp
  3. import os
  4. # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
  5. import sys
  6. import time
  7. import traceback
  8. from multiprocessing.context import Process
  9. import multiprocessing
  10. import logging
  11. # from table_line import *
  12. import cv2
  13. import numpy as np
  14. import tensorflow as tf
  15. from otr.table_line import get_best_predict_size, table_line, get_points, get_split_line, get_points_row, \
  16. get_points_col, \
  17. delete_close_points, fix_outline, get_bbox, get_outline_point, table_net, delete_contain_bbox, points_to_line, \
  18. fix_inner, merge_line, fix_corner, add_continue_bbox, delete_outline
  19. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  20. logger = logging.getLogger(__name__)
  21. def log(msg):
  22. """
  23. @summary:打印信息
  24. """
  25. logger.info(msg)
  26. def otr(data, otr_model):
  27. try:
  28. img_data = base64.b64decode(data)
  29. # points_and_lines = pool.apply(table_detect, (img_data,))
  30. points_and_lines = table_detect(img_data, otr_model)
  31. return points_and_lines
  32. except TimeoutError:
  33. raise TimeoutError
  34. flag = 0
  35. # model_path = "models/table-line.h5"
  36. def table_detect(img_data, otr_model):
  37. logging.info("into otr_interface table_detect")
  38. start_time = time.time()
  39. try:
  40. start_time1 = time.time()
  41. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  42. img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  43. # logging.info("into otr_interface table_detect 1")
  44. # cv2.imwrite("111111.jpg", img)
  45. # 将bgr转为rbg
  46. image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  47. # logging.info("into otr_interface table_detect 2")
  48. # 选择与图片最接近分辨率,以防失真
  49. # best_h, best_w = get_best_predict_size(img)
  50. print("image_np.shape", image_np.shape)
  51. best_h, best_w, _ = image_np.shape
  52. logging.info("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  53. # 调用模型
  54. # rows, cols = table_line(image_np, otr_model)
  55. start_time1 = time.time()
  56. rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
  57. logging.info("otr model predict time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  58. start_time1 = time.time()
  59. if not rows or not cols:
  60. print("points", 0, "split_lines", 0, "bboxes", 0)
  61. return {"points": str([]), "split_lines": str([]),
  62. "bboxes": str([]), "outline_points": str([])}
  63. # 查看是否正确输出rows,cols
  64. # for line in rows+cols:
  65. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  66. # (255, 0, 0), 2)
  67. # cv2.imshow("rows-cols1", img)
  68. # cv2.waitKey(0)
  69. # 处理结果
  70. # 合并错开线
  71. rows = merge_line(rows, axis=0)
  72. cols = merge_line(cols, axis=1)
  73. # 计算交点、分割线
  74. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  75. # logging.info("into otr_interface table_detect 5")
  76. if not points:
  77. print("points", 0, "split_lines", 0, "bboxes", 0)
  78. return {"points": str([]), "split_lines": str([]),
  79. "bboxes": str([]), "outline_points": str([])}
  80. # 清掉外围的没用的线
  81. rows, cols = delete_outline(rows, cols, points)
  82. split_lines, split_y = get_split_line(points, cols, image_np)
  83. # logging.info("into otr_interface table_detect 6")
  84. # 计算交点所在行列,剔除相近交点
  85. row_point_list = get_points_row(points, split_y, 5)
  86. col_point_list = get_points_col(points, split_y, 5)
  87. # logging.info("into otr_interface table_detect 7")
  88. points = delete_close_points(points, row_point_list, col_point_list)
  89. # logging.info("into otr_interface table_detect 8")
  90. # 查看是否正确输出点
  91. # for p in points:
  92. # cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
  93. # cv2.imshow("points", img)
  94. # cv2.waitKey(0)
  95. # 查看是否正确输出rows,cols
  96. # for line in rows+cols:
  97. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  98. # (0, 255, 0), 2)
  99. # cv2.imshow("rows-cols0", img)
  100. # cv2.waitKey(0)
  101. # 修复边框
  102. new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
  103. split_y)
  104. # print(new_cols, new_rows)
  105. if new_rows or new_cols:
  106. # 连接至补线的延长线
  107. if long_rows:
  108. rows = long_rows
  109. if long_cols:
  110. cols = long_cols
  111. # 新的补线
  112. if new_rows:
  113. rows += new_rows
  114. if new_cols:
  115. cols += new_cols
  116. # 修复边框后重新计算交点、分割线
  117. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  118. # logging.info("into otr_interface table_detect 10")
  119. split_lines, split_y = get_split_line(points, cols, image_np)
  120. # 计算交点所在行列,剔除相近交点
  121. row_point_list = get_points_row(points, split_y, 0)
  122. col_point_list = get_points_col(points, split_y, 0)
  123. # logging.info("into otr_interface table_detect 11")
  124. points = delete_close_points(points, row_point_list, col_point_list)
  125. # row_point_list = get_points_row(points, split_y)
  126. # col_point_list = get_points_col(points, split_y)
  127. # logging.info("into otr_interface table_detect 12")
  128. # 查看是否正确输出rows,cols
  129. # for line in rows+cols:
  130. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  131. # (255, 0, 0), 2)
  132. # cv2.imshow("rows-cols1", img)
  133. # cv2.waitKey(0)
  134. # 修复表格4个角
  135. rows, cols = fix_corner(rows, cols, split_y)
  136. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  137. # row_point_list = get_points_row(points, split_y, 5)
  138. # col_point_list = get_points_col(points, split_y, 5)
  139. # print("row_point_list", row_point_list)
  140. # print("col_point_list", col_point_list)
  141. # 修复内部缺线
  142. points = fix_inner(rows, cols, points, split_y)
  143. if not points:
  144. print("points", 0, "split_lines", 0, "bboxes", 0)
  145. return {"points": str([]), "split_lines": str([]),
  146. "bboxes": str([]), "outline_points": str([])}
  147. row_point_list = get_points_row(points, split_y, 5)
  148. col_point_list = get_points_col(points, split_y, 5)
  149. # 查看是否正确输出点
  150. # for p in points:
  151. # cv2.circle(img, (p[0], p[1]), 1, (0, 255, 0), 3)
  152. # cv2.imshow("points fix", img)
  153. # cv2.waitKey(0)
  154. # 查看是否正确输出rows,cols
  155. # for line in rows+cols:
  156. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  157. # (255, 0, 0), 2)
  158. # cv2.imshow("rows-cols2", img)
  159. # cv2.waitKey(0)
  160. # 根据分行分列重新得到rows、cols,避免线延长导致后续bbox生成失败
  161. # rows = points_to_line(row_point_list, axis=0)
  162. # cols = points_to_line(col_point_list, axis=1)
  163. # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  164. # row_point_list = get_points_row(points, split_y, 0)
  165. # col_point_list = get_points_col(points, split_y, 0)
  166. # 获取bbox 单元格
  167. bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)
  168. # logging.info("into otr_interface table_detect 13")
  169. # 删除包含bbox
  170. if bboxes:
  171. bboxes = delete_contain_bbox(bboxes)
  172. # 查看是否能输出正确框
  173. # for box in bboxes:
  174. # cv2.rectangle(img, box[0], box[1], (0, 0, 255), 3)
  175. # cv2.imshow("bbox", img)
  176. # cv2.waitKey(0)
  177. # 补充连续框
  178. # if bboxes:
  179. # bboxes = add_continue_bbox(bboxes)
  180. #
  181. # # 删除包含bbox
  182. # bboxes = delete_contain_bbox(bboxes)
  183. # 查看是否能输出正确框
  184. # cv2.namedWindow('bbox', 0)
  185. # for box in bboxes:
  186. # cv2.rectangle(img, box[0], box[1], (0, 255, 0), 3)
  187. # cv2.imshow("bbox", img)
  188. # cv2.waitKey(0)
  189. # 查看是否正确输出点
  190. # cv2.namedWindow('points', 0)
  191. # for p in points:
  192. # cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
  193. # cv2.imshow("points", img)
  194. # cv2.waitKey(0)
  195. # 查看是否正确输出区域分割线
  196. # cv2.namedWindow('split_lines', 0)
  197. # for line in split_lines:
  198. # cv2.line(img, line[0], line[1], (0, 0, 255), 2)
  199. # cv2.imshow("split_lines", img)
  200. # cv2.waitKey(0)
  201. # 获取每个表格的左上右下两个点
  202. outline_points = get_outline_point(points, split_y)
  203. # logging.info("into otr_interface table_detect 14")
  204. if bboxes:
  205. print("bboxes number", len(bboxes))
  206. # print("bboxes", bboxes)
  207. else:
  208. print("bboxes number", "None")
  209. logging.info("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  210. logging.info("use time: " + str(time.time()-start_time))
  211. return {"points": str(points), "split_lines": str(split_lines),
  212. "bboxes": str(bboxes), "outline_points": str(outline_points)}
  213. except TimeoutError:
  214. raise TimeoutError
  215. except Exception as e:
  216. logging.info("otr_interface cannot detected table!")
  217. print("otr_interface cannot detected table!", traceback.print_exc())
  218. print("points", 0, "split_lines", 0, "bboxes", 0)
  219. logging.info("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  220. return {"points": str([]), "split_lines": str([]), "bboxes": str([]),
  221. "outline_points": str([])}
  222. class OtrModels:
  223. def __init__(self):
  224. # python文件所在目录
  225. _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
  226. model_path = _dir + "/models/table-line.h5"
  227. self.otr_model = table_net((None, None, 3), 2)
  228. self.otr_model.load_weights(model_path)
  229. def get_model(self):
  230. return self.otr_model
  231. # otr_model = table_net((None, None, 3), 2)
  232. # otr_model.load_weights(model_path)
  233. if __name__ == '__main__':
  234. # if len(sys.argv) == 2:
  235. # port = int(sys.argv[1])
  236. # else:
  237. # port = 15017
  238. # app.run(host='0.0.0.0', port=port, threaded=False, debug=False)
  239. # log("OTR running "+str(port))
  240. otr_model = OtrModels().get_model()
  241. otr("11", otr_model)