# otr_interface.py

  1. import base64
  2. import multiprocessing as mp
  3. import os
  4. # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
  5. import sys
  6. import time
  7. import traceback
  8. from multiprocessing.context import Process
  9. import multiprocessing
  10. import logging
  11. # from table_line import *
  12. import cv2
  13. import numpy as np
  14. import tensorflow as tf
  15. from otr.table_line import get_best_predict_size, table_line, get_points, get_split_line, get_points_row, \
  16. get_points_col, \
  17. delete_close_points, fix_outline, get_bbox, get_outline_point, table_net, delete_contain_bbox, points_to_line, \
  18. fix_inner, merge_line, fix_corner, add_continue_bbox, delete_outline
# Configure the root logger at import time: INFO level, with timestamp,
# logger name and level in every record.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger; used by log() below.
logger = logging.getLogger(__name__)
  21. def log(msg):
  22. """
  23. @summary:打印信息
  24. """
  25. logger.info(msg)
  26. def otr(data, otr_model):
  27. try:
  28. img_data = base64.b64decode(data)
  29. # points_and_lines = pool.apply(table_detect, (img_data,))
  30. list_lines = line_detect(img_data, otr_model)
  31. return list_lines
  32. except TimeoutError:
  33. raise TimeoutError
# Module-level flag; nothing in the visible part of this file reads it.
# NOTE(review): may be used by importers of this module - confirm before removing.
flag = 0
# model_path = "models/table-line.h5"
  36. def table_detect2(img_data, otr_model):
  37. logging.info("into otr_interface table_detect")
  38. start_time = time.time()
  39. try:
  40. start_time1 = time.time()
  41. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  42. img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  43. # logging.info("into otr_interface table_detect 1")
  44. # cv2.imwrite("111111.jpg", img)
  45. # 将bgr转为rbg
  46. image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  47. # logging.info("into otr_interface table_detect 2")
  48. # 选择与图片最接近分辨率,以防失真
  49. # best_h, best_w = get_best_predict_size(img)
  50. print("image_np.shape", image_np.shape)
  51. best_h, best_w, _ = image_np.shape
  52. logging.info("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  53. # 调用模型
  54. # rows, cols = table_line(image_np, otr_model)
  55. start_time1 = time.time()
  56. rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
  57. logging.info("otr model predict time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  58. start_time1 = time.time()
  59. if not rows or not cols:
  60. print("points", 0, "split_lines", 0, "bboxes", 0)
  61. return {"points": str([]), "split_lines": str([]),
  62. "bboxes": str([]), "outline_points": str([]),
  63. "lines": str([])}
  64. # 查看是否正确输出rows,cols
  65. # for line in rows+cols:
  66. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  67. # (255, 0, 0), 2)
  68. # cv2.imshow("rows-cols1", img)
  69. # cv2.waitKey(0)
  70. # 处理结果
  71. # 合并错开线
  72. rows = merge_line(rows, axis=0)
  73. cols = merge_line(cols, axis=1)
  74. # 计算交点、分割线
  75. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  76. # logging.info("into otr_interface table_detect 5")
  77. if not points:
  78. print("points", 0, "split_lines", 0, "bboxes", 0)
  79. return {"points": str([]), "split_lines": str([]),
  80. "bboxes": str([]), "outline_points": str([]),
  81. "lines": str([])}
  82. # 清掉外围的没用的线
  83. rows, cols = delete_outline(rows, cols, points)
  84. split_lines, split_y = get_split_line(points, cols, image_np)
  85. # logging.info("into otr_interface table_detect 6")
  86. # 计算交点所在行列,剔除相近交点
  87. row_point_list = get_points_row(points, split_y, 5)
  88. col_point_list = get_points_col(points, split_y, 5)
  89. # logging.info("into otr_interface table_detect 7")
  90. points = delete_close_points(points, row_point_list, col_point_list)
  91. # logging.info("into otr_interface table_detect 8")
  92. # 查看是否正确输出点
  93. # for p in points:
  94. # cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
  95. # cv2.imshow("points", img)
  96. # cv2.waitKey(0)
  97. # 查看是否正确输出rows,cols
  98. # for line in rows+cols:
  99. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  100. # (0, 255, 0), 2)
  101. # cv2.imshow("rows-cols0", img)
  102. # cv2.waitKey(0)
  103. # 修复边框
  104. new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
  105. split_y)
  106. # print(new_cols, new_rows)
  107. if new_rows or new_cols:
  108. # 连接至补线的延长线
  109. if long_rows:
  110. rows = long_rows
  111. if long_cols:
  112. cols = long_cols
  113. # 新的补线
  114. if new_rows:
  115. rows += new_rows
  116. if new_cols:
  117. cols += new_cols
  118. # 修复边框后重新计算交点、分割线
  119. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  120. # logging.info("into otr_interface table_detect 10")
  121. split_lines, split_y = get_split_line(points, cols, image_np)
  122. # 计算交点所在行列,剔除相近交点
  123. row_point_list = get_points_row(points, split_y, 0)
  124. col_point_list = get_points_col(points, split_y, 0)
  125. # logging.info("into otr_interface table_detect 11")
  126. points = delete_close_points(points, row_point_list, col_point_list)
  127. # row_point_list = get_points_row(points, split_y)
  128. # col_point_list = get_points_col(points, split_y)
  129. # logging.info("into otr_interface table_detect 12")
  130. # 查看是否正确输出rows,cols
  131. # for line in rows+cols:
  132. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  133. # (255, 0, 0), 2)
  134. # cv2.imshow("rows-cols1", img)
  135. # cv2.waitKey(0)
  136. # 修复表格4个角
  137. rows, cols = fix_corner(rows, cols, split_y)
  138. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  139. # row_point_list = get_points_row(points, split_y, 5)
  140. # col_point_list = get_points_col(points, split_y, 5)
  141. # print("row_point_list", row_point_list)
  142. # print("col_point_list", col_point_list)
  143. # 修复内部缺线
  144. points = fix_inner(rows, cols, points, split_y)
  145. if not points:
  146. print("points", 0, "split_lines", 0, "bboxes", 0)
  147. return {"points": str([]), "split_lines": str([]),
  148. "bboxes": str([]), "outline_points": str([]),
  149. "lines": str([])}
  150. row_point_list = get_points_row(points, split_y, 5)
  151. col_point_list = get_points_col(points, split_y, 5)
  152. # 查看是否正确输出点
  153. # for p in points:
  154. # cv2.circle(img, (p[0], p[1]), 1, (0, 255, 0), 3)
  155. # cv2.imshow("points fix", img)
  156. # cv2.waitKey(0)
  157. # 查看是否正确输出rows,cols
  158. # for line in rows+cols:
  159. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  160. # (255, 0, 0), 2)
  161. # cv2.imshow("rows-cols2", img)
  162. # cv2.waitKey(0)
  163. # 根据分行分列重新得到rows、cols,避免线延长导致后续bbox生成失败
  164. # rows = points_to_line(row_point_list, axis=0)
  165. # cols = points_to_line(col_point_list, axis=1)
  166. # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  167. # row_point_list = get_points_row(points, split_y, 0)
  168. # col_point_list = get_points_col(points, split_y, 0)
  169. # 获取bbox 单元格
  170. bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)
  171. # logging.info("into otr_interface table_detect 13")
  172. # 删除包含bbox
  173. if bboxes:
  174. bboxes = delete_contain_bbox(bboxes)
  175. # 查看是否能输出正确框
  176. # for box in bboxes:
  177. # cv2.rectangle(img, box[0], box[1], (0, 0, 255), 3)
  178. # cv2.imshow("bbox", img)
  179. # cv2.waitKey(0)
  180. # 补充连续框
  181. # if bboxes:
  182. # bboxes = add_continue_bbox(bboxes)
  183. #
  184. # # 删除包含bbox
  185. # bboxes = delete_contain_bbox(bboxes)
  186. # 查看是否能输出正确框
  187. # cv2.namedWindow('bbox', 0)
  188. # for box in bboxes:
  189. # cv2.rectangle(img, box[0], box[1], (0, 255, 0), 3)
  190. # cv2.imshow("bbox", img)
  191. # cv2.waitKey(0)
  192. # 查看是否正确输出点
  193. # cv2.namedWindow('points', 0)
  194. # for p in points:
  195. # cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
  196. # cv2.imshow("points", img)
  197. # cv2.waitKey(0)
  198. # 查看是否正确输出区域分割线
  199. # cv2.namedWindow('split_lines', 0)
  200. # for line in split_lines:
  201. # cv2.line(img, line[0], line[1], (0, 0, 255), 2)
  202. # cv2.imshow("split_lines", img)
  203. # cv2.waitKey(0)
  204. # 获取每个表格的左上右下两个点
  205. outline_points = get_outline_point(points, split_y)
  206. # logging.info("into otr_interface table_detect 14")
  207. if bboxes:
  208. print("bboxes number", len(bboxes))
  209. # print("bboxes", bboxes)
  210. else:
  211. print("bboxes number", "None")
  212. logging.info("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  213. logging.info("use time: " + str(time.time()-start_time))
  214. return {"points": str(points), "split_lines": str(split_lines),
  215. "bboxes": str(bboxes), "outline_points": str(outline_points),
  216. "lines": str(rows+cols)}
  217. except TimeoutError:
  218. raise TimeoutError
  219. except Exception as e:
  220. logging.info("otr_interface cannot detected table!")
  221. print("otr_interface cannot detected table!", traceback.print_exc())
  222. print("points", 0, "split_lines", 0, "bboxes", 0)
  223. logging.info("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  224. return {"points": str([]), "split_lines": str([]), "bboxes": str([]),
  225. "outline_points": str([]), "lines": str([])}
  226. def line_detect(img_data, otr_model):
  227. logging.info("into otr_interface table_detect")
  228. start_time = time.time()
  229. try:
  230. start_time1 = time.time()
  231. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  232. img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  233. # logging.info("into otr_interface table_detect 1")
  234. # cv2.imwrite("111111.jpg", img)
  235. # 将bgr转为rbg
  236. image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  237. # logging.info("into otr_interface table_detect 2")
  238. # 选择与图片最接近分辨率,以防失真
  239. # best_h, best_w = get_best_predict_size(img)
  240. print("image_np.shape", image_np.shape)
  241. best_h, best_w, _ = image_np.shape
  242. logging.info("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  243. # 调用模型
  244. # rows, cols = table_line(image_np, otr_model)
  245. start_time1 = time.time()
  246. list_line = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
  247. return {"list_line": str(list_line)}
  248. except TimeoutError:
  249. raise TimeoutError
  250. except Exception as e:
  251. logging.info("otr_interface cannot detected table!")
  252. print("otr_interface cannot detected table!", traceback.print_exc())
  253. logging.info("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  254. return {"list_line": str([])}
  255. class OtrModels:
  256. def __init__(self):
  257. # python文件所在目录
  258. _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
  259. model_path = _dir + "/models/table-line.h5"
  260. self.otr_model = table_net((None, None, 3), 2)
  261. self.otr_model.load_weights(model_path)
  262. def get_model(self):
  263. return self.otr_model
  264. # otr_model = table_net((None, None, 3), 2)
  265. # otr_model.load_weights(model_path)
# Manual smoke run: build the model and push one payload through otr().
if __name__ == '__main__':
    # Disabled: earlier entry point that took a port from argv and started an app.
    # if len(sys.argv) == 2:
    #     port = int(sys.argv[1])
    # else:
    #     port = 15017
    # app.run(host='0.0.0.0', port=port, threaded=False, debug=False)
    # log("OTR running "+str(port))
    otr_model = OtrModels().get_model()
    # NOTE(review): "11" looks like a placeholder - it is not correctly padded
    # base64, so base64.b64decode inside otr() will presumably raise; confirm
    # and replace with a real base64-encoded image before relying on this run.
    otr("11", otr_model)