otr_interface.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. import base64
  2. import multiprocessing as mp
  3. import os
  4. # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
  5. import sys
  6. import time
  7. import traceback
  8. from multiprocessing.context import Process
  9. import multiprocessing
  10. import logging
  11. # from table_line import *
  12. import cv2
  13. import numpy as np
  14. import tensorflow as tf
  15. from otr.table_line import get_best_predict_size, table_line, get_points, get_split_line, get_points_row, \
  16. get_points_col, \
  17. delete_close_points, fix_outline, get_bbox, get_outline_point, table_net, delete_contain_bbox, points_to_line, \
  18. fix_inner, merge_line, fix_corner, add_continue_bbox, delete_outline
  19. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  20. logger = logging.getLogger(__name__)
  21. def log(msg):
  22. """
  23. @summary:打印信息
  24. """
  25. logger.info(msg)
  26. def otr(data, otr_model, is_from_pdf):
  27. try:
  28. img_data = base64.b64decode(data)
  29. # points_and_lines = pool.apply(table_detect, (img_data,))
  30. if is_from_pdf:
  31. list_lines = line_detect(img_data, otr_model, prob=0.2)
  32. else:
  33. list_lines = line_detect(img_data, otr_model, prob=0.5)
  34. return list_lines
  35. except TimeoutError:
  36. raise TimeoutError
# NOTE(review): module-level flag; nothing in the visible part of this file
# reads it -- confirm whether another module depends on it.
flag = 0
# model_path = "models/table-line.h5"
  39. def table_detect2(img_data, otr_model):
  40. logging.info("into otr_interface table_detect")
  41. start_time = time.time()
  42. try:
  43. start_time1 = time.time()
  44. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  45. img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  46. # logging.info("into otr_interface table_detect 1")
  47. # cv2.imwrite("111111.jpg", img)
  48. # 将bgr转为rbg
  49. image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  50. # logging.info("into otr_interface table_detect 2")
  51. # 选择与图片最接近分辨率,以防失真
  52. # best_h, best_w = get_best_predict_size(img)
  53. print("image_np.shape", image_np.shape)
  54. best_h, best_w, _ = image_np.shape
  55. logging.info("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  56. # 调用模型
  57. # rows, cols = table_line(image_np, otr_model)
  58. start_time1 = time.time()
  59. rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
  60. logging.info("otr model predict time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  61. start_time1 = time.time()
  62. if not rows or not cols:
  63. print("points", 0, "split_lines", 0, "bboxes", 0)
  64. return {"points": str([]), "split_lines": str([]),
  65. "bboxes": str([]), "outline_points": str([]),
  66. "lines": str([])}
  67. # 查看是否正确输出rows,cols
  68. # for line in rows+cols:
  69. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  70. # (255, 0, 0), 2)
  71. # cv2.imshow("rows-cols1", img)
  72. # cv2.waitKey(0)
  73. # 处理结果
  74. # 合并错开线
  75. rows = merge_line(rows, axis=0)
  76. cols = merge_line(cols, axis=1)
  77. # 计算交点、分割线
  78. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  79. # logging.info("into otr_interface table_detect 5")
  80. if not points:
  81. print("points", 0, "split_lines", 0, "bboxes", 0)
  82. return {"points": str([]), "split_lines": str([]),
  83. "bboxes": str([]), "outline_points": str([]),
  84. "lines": str([])}
  85. # 清掉外围的没用的线
  86. rows, cols = delete_outline(rows, cols, points)
  87. split_lines, split_y = get_split_line(points, cols, image_np)
  88. # logging.info("into otr_interface table_detect 6")
  89. # 计算交点所在行列,剔除相近交点
  90. row_point_list = get_points_row(points, split_y, 5)
  91. col_point_list = get_points_col(points, split_y, 5)
  92. # logging.info("into otr_interface table_detect 7")
  93. points = delete_close_points(points, row_point_list, col_point_list)
  94. # logging.info("into otr_interface table_detect 8")
  95. # 查看是否正确输出点
  96. # for p in points:
  97. # cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
  98. # cv2.imshow("points", img)
  99. # cv2.waitKey(0)
  100. # 查看是否正确输出rows,cols
  101. # for line in rows+cols:
  102. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  103. # (0, 255, 0), 2)
  104. # cv2.imshow("rows-cols0", img)
  105. # cv2.waitKey(0)
  106. # 修复边框
  107. new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
  108. split_y)
  109. # print(new_cols, new_rows)
  110. if new_rows or new_cols:
  111. # 连接至补线的延长线
  112. if long_rows:
  113. rows = long_rows
  114. if long_cols:
  115. cols = long_cols
  116. # 新的补线
  117. if new_rows:
  118. rows += new_rows
  119. if new_cols:
  120. cols += new_cols
  121. # 修复边框后重新计算交点、分割线
  122. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  123. # logging.info("into otr_interface table_detect 10")
  124. split_lines, split_y = get_split_line(points, cols, image_np)
  125. # 计算交点所在行列,剔除相近交点
  126. row_point_list = get_points_row(points, split_y, 0)
  127. col_point_list = get_points_col(points, split_y, 0)
  128. # logging.info("into otr_interface table_detect 11")
  129. points = delete_close_points(points, row_point_list, col_point_list)
  130. # row_point_list = get_points_row(points, split_y)
  131. # col_point_list = get_points_col(points, split_y)
  132. # logging.info("into otr_interface table_detect 12")
  133. # 查看是否正确输出rows,cols
  134. # for line in rows+cols:
  135. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  136. # (255, 0, 0), 2)
  137. # cv2.imshow("rows-cols1", img)
  138. # cv2.waitKey(0)
  139. # 修复表格4个角
  140. rows, cols = fix_corner(rows, cols, split_y)
  141. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  142. # row_point_list = get_points_row(points, split_y, 5)
  143. # col_point_list = get_points_col(points, split_y, 5)
  144. # print("row_point_list", row_point_list)
  145. # print("col_point_list", col_point_list)
  146. # 修复内部缺线
  147. points = fix_inner(rows, cols, points, split_y)
  148. if not points:
  149. print("points", 0, "split_lines", 0, "bboxes", 0)
  150. return {"points": str([]), "split_lines": str([]),
  151. "bboxes": str([]), "outline_points": str([]),
  152. "lines": str([])}
  153. row_point_list = get_points_row(points, split_y, 5)
  154. col_point_list = get_points_col(points, split_y, 5)
  155. # 查看是否正确输出点
  156. # for p in points:
  157. # cv2.circle(img, (p[0], p[1]), 1, (0, 255, 0), 3)
  158. # cv2.imshow("points fix", img)
  159. # cv2.waitKey(0)
  160. # 查看是否正确输出rows,cols
  161. # for line in rows+cols:
  162. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  163. # (255, 0, 0), 2)
  164. # cv2.imshow("rows-cols2", img)
  165. # cv2.waitKey(0)
  166. # 根据分行分列重新得到rows、cols,避免线延长导致后续bbox生成失败
  167. # rows = points_to_line(row_point_list, axis=0)
  168. # cols = points_to_line(col_point_list, axis=1)
  169. # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  170. # row_point_list = get_points_row(points, split_y, 0)
  171. # col_point_list = get_points_col(points, split_y, 0)
  172. # 获取bbox 单元格
  173. bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)
  174. # logging.info("into otr_interface table_detect 13")
  175. # 删除包含bbox
  176. if bboxes:
  177. bboxes = delete_contain_bbox(bboxes)
  178. # 查看是否能输出正确框
  179. # for box in bboxes:
  180. # cv2.rectangle(img, box[0], box[1], (0, 0, 255), 3)
  181. # cv2.imshow("bbox", img)
  182. # cv2.waitKey(0)
  183. # 补充连续框
  184. # if bboxes:
  185. # bboxes = add_continue_bbox(bboxes)
  186. #
  187. # # 删除包含bbox
  188. # bboxes = delete_contain_bbox(bboxes)
  189. # 查看是否能输出正确框
  190. # cv2.namedWindow('bbox', 0)
  191. # for box in bboxes:
  192. # cv2.rectangle(img, box[0], box[1], (0, 255, 0), 3)
  193. # cv2.imshow("bbox", img)
  194. # cv2.waitKey(0)
  195. # 查看是否正确输出点
  196. # cv2.namedWindow('points', 0)
  197. # for p in points:
  198. # cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
  199. # cv2.imshow("points", img)
  200. # cv2.waitKey(0)
  201. # 查看是否正确输出区域分割线
  202. # cv2.namedWindow('split_lines', 0)
  203. # for line in split_lines:
  204. # cv2.line(img, line[0], line[1], (0, 0, 255), 2)
  205. # cv2.imshow("split_lines", img)
  206. # cv2.waitKey(0)
  207. # 获取每个表格的左上右下两个点
  208. outline_points = get_outline_point(points, split_y)
  209. # logging.info("into otr_interface table_detect 14")
  210. if bboxes:
  211. print("bboxes number", len(bboxes))
  212. # print("bboxes", bboxes)
  213. else:
  214. print("bboxes number", "None")
  215. logging.info("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  216. logging.info("use time: " + str(time.time()-start_time))
  217. return {"points": str(points), "split_lines": str(split_lines),
  218. "bboxes": str(bboxes), "outline_points": str(outline_points),
  219. "lines": str(rows+cols)}
  220. except TimeoutError:
  221. raise TimeoutError
  222. except Exception as e:
  223. logging.info("otr_interface cannot detected table!")
  224. print("otr_interface cannot detected table!", traceback.print_exc())
  225. print("points", 0, "split_lines", 0, "bboxes", 0)
  226. logging.info("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  227. return {"points": str([]), "split_lines": str([]), "bboxes": str([]),
  228. "outline_points": str([]), "lines": str([])}
  229. def line_detect(img_data, otr_model, prob=0.2):
  230. logging.info("into otr_interface table_detect")
  231. start_time = time.time()
  232. try:
  233. start_time1 = time.time()
  234. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  235. img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  236. # logging.info("into otr_interface table_detect 1")
  237. # cv2.imwrite("111111.jpg", img)
  238. # 将bgr转为rbg
  239. image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  240. # logging.info("into otr_interface table_detect 2")
  241. # 选择与图片最接近分辨率,以防失真
  242. # best_h, best_w = get_best_predict_size(img)
  243. logging.info("image_np.shape" + str(image_np.shape))
  244. best_h, best_w, _ = image_np.shape
  245. logging.info("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  246. # 调用模型
  247. # rows, cols = table_line(image_np, otr_model)
  248. start_time1 = time.time()
  249. list_line = table_line(image_np, otr_model, size=(best_w, best_h), prob=prob)
  250. logging.info("otr finish " + str(round(float(time.time()-start_time1), 4)) + "s")
  251. return {"list_line": str(list_line)}
  252. except TimeoutError:
  253. raise TimeoutError
  254. except Exception as e:
  255. logging.info("otr_interface cannot detected table!")
  256. print("otr_interface cannot detected table!", traceback.print_exc())
  257. logging.info("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  258. return {"list_line": str([])}
  259. class OtrModels:
  260. def __init__(self):
  261. # python文件所在目录
  262. _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
  263. model_path = _dir + "/models/table-line.h5"
  264. self.otr_model = table_net((None, None, 3), 2)
  265. self.otr_model.load_weights(model_path)
  266. def get_model(self):
  267. return self.otr_model
# Legacy module-level model loading (the same two calls now live in OtrModels.__init__):
# otr_model = table_net((None, None, 3), 2)
# otr_model.load_weights(model_path)
  270. if __name__ == '__main__':
  271. # if len(sys.argv) == 2:
  272. # port = int(sys.argv[1])
  273. # else:
  274. # port = 15017
  275. # app.run(host='0.0.0.0', port=port, threaded=False, debug=False)
  276. # log("OTR running "+str(port))
  277. otr_model = OtrModels().get_model()
  278. otr("11", otr_model)