otr_interface.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. import base64
  2. import json
  3. import multiprocessing as mp
  4. import os
  5. # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
  6. import sys
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  8. import time
  9. import traceback
  10. from multiprocessing.context import Process
  11. import multiprocessing
  12. import logging
  13. # from table_line import *
  14. import cv2
  15. import numpy as np
  16. import tensorflow as tf
  17. from flask import Flask, request
  18. from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes
  19. from otr.table_line import get_best_predict_size, table_line, get_points, get_split_line, get_points_row, \
  20. get_points_col, \
  21. delete_close_points, fix_outline, get_bbox, get_outline_point, table_net, delete_contain_bbox, points_to_line, \
  22. fix_inner, merge_line, fix_corner, add_continue_bbox, delete_outline
  23. from format_convert import _global
# Interface configuration: the Flask application that serves the /otr endpoint below.
app = Flask(__name__)
  26. @app.route('/otr', methods=['POST'])
  27. def _otr():
  28. _global._init()
  29. _global.update({"port": globals().get("port")})
  30. log("into otr_interface _otr")
  31. try:
  32. if not request.form:
  33. log("otr no data!")
  34. return json.dumps({"list_line": str([-9])})
  35. otr_model = globals().get("global_otr_model")
  36. if otr_model is None:
  37. otr_model = OtrModels().get_model()
  38. globals().update({"global_otr_model": otr_model})
  39. data = request.form.get("data")
  40. is_from_pdf = request.form.get("is_from_pdf")
  41. img_data = base64.b64decode(data)
  42. _md5 = get_md5_from_bytes(img_data)[0]
  43. _global.update({"md5": _md5})
  44. if is_from_pdf:
  45. list_lines = line_detect(img_data, otr_model, prob=0.2)
  46. else:
  47. list_lines = line_detect(img_data, otr_model, prob=0.5)
  48. return json.dumps(list_lines)
  49. except TimeoutError:
  50. return json.dumps({"list_line": str([-5])})
  51. except:
  52. traceback.print_exc()
  53. return json.dumps({"list_line": str([-1])})
  54. def otr(data, otr_model, is_from_pdf):
  55. try:
  56. img_data = base64.b64decode(data)
  57. # points_and_lines = pool.apply(table_detect, (img_data,))
  58. if is_from_pdf:
  59. list_lines = line_detect(img_data, otr_model, prob=0.2)
  60. else:
  61. list_lines = line_detect(img_data, otr_model, prob=0.5)
  62. return list_lines
  63. except TimeoutError:
  64. raise TimeoutError
# NOTE(review): `flag` is not read anywhere in the visible part of this file;
# it may be referenced by other modules -- confirm before removing.
flag = 0
# model_path = "models/table-line.h5"  (superseded: the path is now built in OtrModels.__init__)
  67. def table_detect2(img_data, otr_model):
  68. log("into otr_interface table_detect")
  69. start_time = time.time()
  70. try:
  71. start_time1 = time.time()
  72. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  73. img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  74. # log("into otr_interface table_detect 1")
  75. # cv2.imwrite("111111.jpg", img)
  76. # 将bgr转为rbg
  77. image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  78. # log("into otr_interface table_detect 2")
  79. # 选择与图片最接近分辨率,以防失真
  80. # best_h, best_w = get_best_predict_size(img)
  81. print("image_np.shape", image_np.shape)
  82. best_h, best_w, _ = image_np.shape
  83. log("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  84. # 调用模型
  85. # rows, cols = table_line(image_np, otr_model)
  86. rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
  87. start_time1 = time.time()
  88. if not rows or not cols:
  89. print("points", 0, "split_lines", 0, "bboxes", 0)
  90. return {"points": str([]), "split_lines": str([]),
  91. "bboxes": str([]), "outline_points": str([]),
  92. "lines": str([])}
  93. # 查看是否正确输出rows,cols
  94. # for line in rows+cols:
  95. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  96. # (255, 0, 0), 2)
  97. # cv2.imshow("rows-cols1", img)
  98. # cv2.waitKey(0)
  99. # 处理结果
  100. # 合并错开线
  101. rows = merge_line(rows, axis=0)
  102. cols = merge_line(cols, axis=1)
  103. # 计算交点、分割线
  104. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  105. # log("into otr_interface table_detect 5")
  106. if not points:
  107. print("points", 0, "split_lines", 0, "bboxes", 0)
  108. return {"points": str([]), "split_lines": str([]),
  109. "bboxes": str([]), "outline_points": str([]),
  110. "lines": str([])}
  111. # 清掉外围的没用的线
  112. rows, cols = delete_outline(rows, cols, points)
  113. split_lines, split_y = get_split_line(points, cols, image_np)
  114. # log("into otr_interface table_detect 6")
  115. # 计算交点所在行列,剔除相近交点
  116. row_point_list = get_points_row(points, split_y, 5)
  117. col_point_list = get_points_col(points, split_y, 5)
  118. # log("into otr_interface table_detect 7")
  119. points = delete_close_points(points, row_point_list, col_point_list)
  120. # log("into otr_interface table_detect 8")
  121. # 查看是否正确输出点
  122. # for p in points:
  123. # cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
  124. # cv2.imshow("points", img)
  125. # cv2.waitKey(0)
  126. # 查看是否正确输出rows,cols
  127. # for line in rows+cols:
  128. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  129. # (0, 255, 0), 2)
  130. # cv2.imshow("rows-cols0", img)
  131. # cv2.waitKey(0)
  132. # 修复边框
  133. new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
  134. split_y)
  135. # print(new_cols, new_rows)
  136. if new_rows or new_cols:
  137. # 连接至补线的延长线
  138. if long_rows:
  139. rows = long_rows
  140. if long_cols:
  141. cols = long_cols
  142. # 新的补线
  143. if new_rows:
  144. rows += new_rows
  145. if new_cols:
  146. cols += new_cols
  147. # 修复边框后重新计算交点、分割线
  148. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  149. # log("into otr_interface table_detect 10")
  150. split_lines, split_y = get_split_line(points, cols, image_np)
  151. # 计算交点所在行列,剔除相近交点
  152. row_point_list = get_points_row(points, split_y, 0)
  153. col_point_list = get_points_col(points, split_y, 0)
  154. # log("into otr_interface table_detect 11")
  155. points = delete_close_points(points, row_point_list, col_point_list)
  156. # row_point_list = get_points_row(points, split_y)
  157. # col_point_list = get_points_col(points, split_y)
  158. # log("into otr_interface table_detect 12")
  159. # 查看是否正确输出rows,cols
  160. # for line in rows+cols:
  161. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  162. # (255, 0, 0), 2)
  163. # cv2.imshow("rows-cols1", img)
  164. # cv2.waitKey(0)
  165. # 修复表格4个角
  166. rows, cols = fix_corner(rows, cols, split_y)
  167. points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  168. # row_point_list = get_points_row(points, split_y, 5)
  169. # col_point_list = get_points_col(points, split_y, 5)
  170. # print("row_point_list", row_point_list)
  171. # print("col_point_list", col_point_list)
  172. # 修复内部缺线
  173. points = fix_inner(rows, cols, points, split_y)
  174. if not points:
  175. print("points", 0, "split_lines", 0, "bboxes", 0)
  176. return {"points": str([]), "split_lines": str([]),
  177. "bboxes": str([]), "outline_points": str([]),
  178. "lines": str([])}
  179. row_point_list = get_points_row(points, split_y, 5)
  180. col_point_list = get_points_col(points, split_y, 5)
  181. # 查看是否正确输出点
  182. # for p in points:
  183. # cv2.circle(img, (p[0], p[1]), 1, (0, 255, 0), 3)
  184. # cv2.imshow("points fix", img)
  185. # cv2.waitKey(0)
  186. # 查看是否正确输出rows,cols
  187. # for line in rows+cols:
  188. # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  189. # (255, 0, 0), 2)
  190. # cv2.imshow("rows-cols2", img)
  191. # cv2.waitKey(0)
  192. # 根据分行分列重新得到rows、cols,避免线延长导致后续bbox生成失败
  193. # rows = points_to_line(row_point_list, axis=0)
  194. # cols = points_to_line(col_point_list, axis=1)
  195. # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  196. # row_point_list = get_points_row(points, split_y, 0)
  197. # col_point_list = get_points_col(points, split_y, 0)
  198. # 获取bbox 单元格
  199. bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)
  200. # log("into otr_interface table_detect 13")
  201. # 删除包含bbox
  202. if bboxes:
  203. bboxes = delete_contain_bbox(bboxes)
  204. # 查看是否能输出正确框
  205. # for box in bboxes:
  206. # cv2.rectangle(img, box[0], box[1], (0, 0, 255), 3)
  207. # cv2.imshow("bbox", img)
  208. # cv2.waitKey(0)
  209. # 补充连续框
  210. # if bboxes:
  211. # bboxes = add_continue_bbox(bboxes)
  212. #
  213. # # 删除包含bbox
  214. # bboxes = delete_contain_bbox(bboxes)
  215. # 查看是否能输出正确框
  216. # cv2.namedWindow('bbox', 0)
  217. # for box in bboxes:
  218. # cv2.rectangle(img, box[0], box[1], (0, 255, 0), 3)
  219. # cv2.imshow("bbox", img)
  220. # cv2.waitKey(0)
  221. # 查看是否正确输出点
  222. # cv2.namedWindow('points', 0)
  223. # for p in points:
  224. # cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
  225. # cv2.imshow("points", img)
  226. # cv2.waitKey(0)
  227. # 查看是否正确输出区域分割线
  228. # cv2.namedWindow('split_lines', 0)
  229. # for line in split_lines:
  230. # cv2.line(img, line[0], line[1], (0, 0, 255), 2)
  231. # cv2.imshow("split_lines", img)
  232. # cv2.waitKey(0)
  233. # 获取每个表格的左上右下两个点
  234. outline_points = get_outline_point(points, split_y)
  235. # log("into otr_interface table_detect 14")
  236. if bboxes:
  237. print("bboxes number", len(bboxes))
  238. # print("bboxes", bboxes)
  239. else:
  240. print("bboxes number", "None")
  241. log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  242. log("otr finish: " + str(round(float(time.time()-start_time1), 4)) + "s")
  243. return {"points": str(points), "split_lines": str(split_lines),
  244. "bboxes": str(bboxes), "outline_points": str(outline_points),
  245. "lines": str(rows+cols)}
  246. except TimeoutError:
  247. raise TimeoutError
  248. except Exception as e:
  249. log("otr_interface cannot detected table!")
  250. print("otr_interface cannot detected table!", traceback.print_exc())
  251. print("points", 0, "split_lines", 0, "bboxes", 0)
  252. log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  253. return {"points": str([]), "split_lines": str([]), "bboxes": str([]),
  254. "outline_points": str([]), "lines": str([])}
  255. def line_detect(img_data, otr_model, prob=0.2):
  256. log("into otr_interface table_detect")
  257. start_time = time.time()
  258. try:
  259. start_time1 = time.time()
  260. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  261. img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  262. # log("into otr_interface table_detect 1")
  263. # cv2.imwrite("111111.jpg", img)
  264. # 将bgr转为rbg
  265. image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  266. # log("into otr_interface table_detect 2")
  267. # 选择与图片最接近分辨率,以防失真
  268. # best_h, best_w = get_best_predict_size(img)
  269. log("image_np.shape" + str(image_np.shape))
  270. best_h, best_w, _ = image_np.shape
  271. log("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  272. # 调用模型
  273. # rows, cols = table_line(image_np, otr_model)
  274. start_time1 = time.time()
  275. list_line = table_line(image_np, otr_model, size=(best_w, best_h), prob=prob)
  276. log("otr finish " + str(round(float(time.time()-start_time1), 4)) + "s")
  277. return {"list_line": str(list_line)}
  278. except TimeoutError:
  279. raise TimeoutError
  280. except Exception as e:
  281. log("otr_interface cannot detected table!")
  282. print("otr_interface cannot detected table!", traceback.print_exc())
  283. log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  284. return {"list_line": str([])}
  285. class OtrModels:
  286. def __init__(self):
  287. # python文件所在目录
  288. _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
  289. model_path = _dir + "/models/table-line.h5"
  290. self.otr_model = table_net((None, None, 3), 2)
  291. self.otr_model.load_weights(model_path)
  292. def get_model(self):
  293. return self.otr_model
  294. def test_otr_model():
  295. file_path = "C:/Users/Administrator/Desktop/error2.png"
  296. with open(file_path, "rb") as f:
  297. file_bytes = f.read()
  298. file_base64 = base64.b64encode(file_bytes)
  299. file_json = {"data": file_base64, "is_from_pdf": False}
  300. _url = "http://192.168.2.103:18000/otr"
  301. r = json.loads(request_post(_url, file_json))
  302. print(r)
  303. # otr_model = table_net((None, None, 3), 2)
  304. # otr_model.load_weights(model_path)
  305. if __name__ == '__main__':
  306. if len(sys.argv) == 2:
  307. port = int(sys.argv[1])
  308. elif len(sys.argv) == 3:
  309. port = int(sys.argv[1])
  310. using_gpu_index = int(sys.argv[2])
  311. else:
  312. port = 18000
  313. using_gpu_index = 0
  314. _global._init()
  315. _global.update({"port": str(port)})
  316. globals().update({"port": str(port)})
  317. # 日志格式设置
  318. # ip = get_intranet_ip()
  319. # logging.basicConfig(level=logging.INFO,
  320. # format='%(asctime)s - %(name)s - %(levelname)s - '
  321. # + ip + ' - ' + str(port) + ' - %(message)s')
  322. # 限制tensorflow显存
  323. memory_limit_scale = 0.3
  324. os.environ['CUDA_VISIBLE_DEVICES'] = str(using_gpu_index)
  325. os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
  326. os.environ['CUDA_CACHE_DISABLE'] = str(0)
  327. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=memory_limit_scale)
  328. sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
  329. app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
  330. # app.run(port=port)
  331. log("OTR running "+str(port))
  332. # test_otr_model()
  333. # print(json.dumps([-2]))
  334. # otr_model = OtrModels().get_model()
  335. # otr("11", otr_model)