otr_interface.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. import base64
  2. import json
  3. import multiprocessing as mp
  4. import os
  5. import sys
  6. import traceback
  7. # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
  8. # os.environ['CUDA_VISIBLE_DEVICES'] = "0"
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  10. from format_convert.max_compute_config import max_compute
  11. import tensorflow as tf
# Mirror the imported max_compute switch; the GPU setup below runs only when it is falsy.
MAX_COMPUTE = max_compute
if not MAX_COMPUTE:
    # TensorFlow GPU memory setup.
    try:
        # Restrict the first physical GPU to one virtual device with
        # memory_limit=2048 (TensorFlow documents this limit in MB).
        gpus = tf.config.list_physical_devices('GPU')
        if len(gpus) > 0:
            tf.config.experimental.set_virtual_device_configuration(
                gpus[0],
                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)])
    except:
        traceback.print_exc()
        # pass
        # gpus = tf.config.list_physical_devices('GPU')
        # for gpu in gpus:  # when several GPUs are used
        #     tf.config.experimental.set_memory_growth(gpu, True)
        # NOTE(review): the source this was recovered from had lost its
        # indentation. The statements below use the TF1-style session API and
        # are presumed to be the fallback for when the tf.config calls above
        # fail -- confirm their nesting against the original file.
        os.environ['CUDA_CACHE_MAXSIZE'] = str(2147483648)
        os.environ['CUDA_CACHE_DISABLE'] = str(0)
        # Let this process use at most 60% of GPU memory.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
  31. import time
  32. import logging
  33. import cv2
  34. import numpy as np
  35. from flask import Flask, request
  36. from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes, get_platform
  37. from otr.table_line import table_net
  38. from otr.table_line_new import table_line
  39. from format_convert import _global
# Interface configuration: the Flask application object for this module.
app = Flask(__name__)
  42. @app.route('/otr', methods=['POST'])
  43. def _otr():
  44. _global._init()
  45. _global.update({"port": globals().get("port")})
  46. start_time = time.time()
  47. log("into otr_interface _otr")
  48. try:
  49. if not request.form:
  50. log("otr no data!")
  51. return json.dumps({"list_line": str([-9])})
  52. otr_model = globals().get("global_otr_model")
  53. if otr_model is None:
  54. otr_model = OtrModels().get_model()
  55. globals().update({"global_otr_model": otr_model})
  56. data = request.form.get("data")
  57. is_from_pdf = request.form.get("is_from_pdf")
  58. _md5 = request.form.get("md5")
  59. _global.update({"md5": _md5})
  60. list_lines = otr(data, otr_model, is_from_pdf)
  61. return json.dumps(list_lines)
  62. except TimeoutError:
  63. return json.dumps({"list_line": str([-5])})
  64. except:
  65. traceback.print_exc()
  66. return json.dumps({"list_line": str([-1])})
  67. finally:
  68. log("otr interface finish time " + str(time.time()-start_time))
  69. def otr(data, otr_model, is_from_pdf):
  70. log("into otr_interface otr")
  71. log('otr is_from_pdf ' + str(is_from_pdf))
  72. try:
  73. img_data = base64.b64decode(data)
  74. if str(is_from_pdf) == 'True':
  75. # log('otr is_from_pdf ' + str(type(is_from_pdf)))
  76. list_lines = line_detect(img_data, otr_model, prob=0.2)
  77. else:
  78. # log('otr is_from_pdf ' + str(type(is_from_pdf)))
  79. list_lines = line_detect(img_data, otr_model, prob=0.5)
  80. return list_lines
  81. except TimeoutError:
  82. raise TimeoutError
  83. # def table_detect2(img_data, otr_model):
  84. # log("into otr_interface table_detect")
  85. # start_time = time.time()
  86. # try:
  87. # start_time1 = time.time()
  88. # # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  89. # img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  90. # # log("into otr_interface table_detect 1")
  91. # # cv2.imwrite("111111.jpg", img)
  92. #
  93. # # 将bgr转为rbg
  94. # image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  95. # # log("into otr_interface table_detect 2")
  96. #
  97. # # 选择与图片最接近分辨率,以防失真
  98. # # best_h, best_w = get_best_predict_size(img)
  99. # print("image_np.shape", image_np.shape)
  100. # best_h, best_w, _ = image_np.shape
  101. # log("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  102. #
  103. # # 调用模型
  104. # # rows, cols = table_line(image_np, otr_model)
  105. # rows, cols, image_np = table_line(image_np, otr_model, size=(best_w, best_h), hprob=0.5, vprob=0.5)
  106. #
  107. # start_time1 = time.time()
  108. # if not rows or not cols:
  109. # print("points", 0, "split_lines", 0, "bboxes", 0)
  110. # return {"points": str([]), "split_lines": str([]),
  111. # "bboxes": str([]), "outline_points": str([]),
  112. # "lines": str([])}
  113. #
  114. # # 查看是否正确输出rows,cols
  115. # # for line in rows+cols:
  116. # # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  117. # # (255, 0, 0), 2)
  118. # # cv2.imshow("rows-cols1", img)
  119. # # cv2.waitKey(0)
  120. #
  121. # # 处理结果
  122. # # 合并错开线
  123. # rows = merge_line(rows, axis=0)
  124. # cols = merge_line(cols, axis=1)
  125. #
  126. # # 计算交点、分割线
  127. # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  128. # # log("into otr_interface table_detect 5")
  129. # if not points:
  130. # print("points", 0, "split_lines", 0, "bboxes", 0)
  131. # return {"points": str([]), "split_lines": str([]),
  132. # "bboxes": str([]), "outline_points": str([]),
  133. # "lines": str([])}
  134. #
  135. # # 清掉外围的没用的线
  136. # rows, cols = delete_outline(rows, cols, points)
  137. #
  138. # split_lines, split_y = get_split_line(points, cols, image_np)
  139. # # log("into otr_interface table_detect 6")
  140. #
  141. # # 计算交点所在行列,剔除相近交点
  142. # row_point_list = get_points_row(points, split_y, 5)
  143. # col_point_list = get_points_col(points, split_y, 5)
  144. # # log("into otr_interface table_detect 7")
  145. # points = delete_close_points(points, row_point_list, col_point_list)
  146. # # log("into otr_interface table_detect 8")
  147. #
  148. # # 查看是否正确输出点
  149. # # for p in points:
  150. # # cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
  151. # # cv2.imshow("points", img)
  152. # # cv2.waitKey(0)
  153. #
  154. # # 查看是否正确输出rows,cols
  155. # # for line in rows+cols:
  156. # # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  157. # # (0, 255, 0), 2)
  158. # # cv2.imshow("rows-cols0", img)
  159. # # cv2.waitKey(0)
  160. #
  161. # # 修复边框
  162. # new_rows, new_cols, long_rows, long_cols = fix_outline(image_np, rows, cols, points,
  163. # split_y)
  164. # # print(new_cols, new_rows)
  165. # if new_rows or new_cols:
  166. # # 连接至补线的延长线
  167. # if long_rows:
  168. # rows = long_rows
  169. # if long_cols:
  170. # cols = long_cols
  171. # # 新的补线
  172. # if new_rows:
  173. # rows += new_rows
  174. # if new_cols:
  175. # cols += new_cols
  176. #
  177. # # 修复边框后重新计算交点、分割线
  178. # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  179. #
  180. # # log("into otr_interface table_detect 10")
  181. # split_lines, split_y = get_split_line(points, cols, image_np)
  182. #
  183. # # 计算交点所在行列,剔除相近交点
  184. # row_point_list = get_points_row(points, split_y, 0)
  185. # col_point_list = get_points_col(points, split_y, 0)
  186. # # log("into otr_interface table_detect 11")
  187. # points = delete_close_points(points, row_point_list, col_point_list)
  188. # # row_point_list = get_points_row(points, split_y)
  189. # # col_point_list = get_points_col(points, split_y)
  190. # # log("into otr_interface table_detect 12")
  191. #
  192. # # 查看是否正确输出rows,cols
  193. # # for line in rows+cols:
  194. # # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  195. # # (255, 0, 0), 2)
  196. # # cv2.imshow("rows-cols1", img)
  197. # # cv2.waitKey(0)
  198. #
  199. # # 修复表格4个角
  200. # rows, cols = fix_corner(rows, cols, split_y)
  201. # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  202. # # row_point_list = get_points_row(points, split_y, 5)
  203. # # col_point_list = get_points_col(points, split_y, 5)
  204. #
  205. # # print("row_point_list", row_point_list)
  206. # # print("col_point_list", col_point_list)
  207. #
  208. # # 修复内部缺线
  209. # points = fix_inner(rows, cols, points, split_y)
  210. # if not points:
  211. # print("points", 0, "split_lines", 0, "bboxes", 0)
  212. # return {"points": str([]), "split_lines": str([]),
  213. # "bboxes": str([]), "outline_points": str([]),
  214. # "lines": str([])}
  215. # row_point_list = get_points_row(points, split_y, 5)
  216. # col_point_list = get_points_col(points, split_y, 5)
  217. #
  218. # # 查看是否正确输出点
  219. # # for p in points:
  220. # # cv2.circle(img, (p[0], p[1]), 1, (0, 255, 0), 3)
  221. # # cv2.imshow("points fix", img)
  222. # # cv2.waitKey(0)
  223. #
  224. # # 查看是否正确输出rows,cols
  225. # # for line in rows+cols:
  226. # # cv2.line(img, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  227. # # (255, 0, 0), 2)
  228. # # cv2.imshow("rows-cols2", img)
  229. # # cv2.waitKey(0)
  230. #
  231. # # 根据分行分列重新得到rows、cols,避免线延长导致后续bbox生成失败
  232. # # rows = points_to_line(row_point_list, axis=0)
  233. # # cols = points_to_line(col_point_list, axis=1)
  234. # # points = get_points(rows, cols, (image_np.shape[0], image_np.shape[1]))
  235. # # row_point_list = get_points_row(points, split_y, 0)
  236. # # col_point_list = get_points_col(points, split_y, 0)
  237. #
  238. # # 获取bbox 单元格
  239. # bboxes = get_bbox(image_np, row_point_list, col_point_list, split_y, rows, cols)
  240. # # log("into otr_interface table_detect 13")
  241. #
  242. # # 删除包含bbox
  243. # if bboxes:
  244. # bboxes = delete_contain_bbox(bboxes)
  245. #
  246. # # 查看是否能输出正确框
  247. # # for box in bboxes:
  248. # # cv2.rectangle(img, box[0], box[1], (0, 0, 255), 3)
  249. # # cv2.imshow("bbox", img)
  250. # # cv2.waitKey(0)
  251. #
  252. # # 补充连续框
  253. # # if bboxes:
  254. # # bboxes = add_continue_bbox(bboxes)
  255. # #
  256. # # # 删除包含bbox
  257. # # bboxes = delete_contain_bbox(bboxes)
  258. #
  259. # # 查看是否能输出正确框
  260. # # cv2.namedWindow('bbox', 0)
  261. # # for box in bboxes:
  262. # # cv2.rectangle(img, box[0], box[1], (0, 255, 0), 3)
  263. # # cv2.imshow("bbox", img)
  264. # # cv2.waitKey(0)
  265. #
  266. # # 查看是否正确输出点
  267. # # cv2.namedWindow('points', 0)
  268. # # for p in points:
  269. # # cv2.circle(img, (p[0], p[1]), 3, (0, 0, 255))
  270. # # cv2.imshow("points", img)
  271. # # cv2.waitKey(0)
  272. #
  273. # # 查看是否正确输出区域分割线
  274. # # cv2.namedWindow('split_lines', 0)
  275. # # for line in split_lines:
  276. # # cv2.line(img, line[0], line[1], (0, 0, 255), 2)
  277. # # cv2.imshow("split_lines", img)
  278. # # cv2.waitKey(0)
  279. #
  280. # # 获取每个表格的左上右下两个点
  281. # outline_points = get_outline_point(points, split_y)
  282. # # log("into otr_interface table_detect 14")
  283. #
  284. # if bboxes:
  285. # print("bboxes number", len(bboxes))
  286. # # print("bboxes", bboxes)
  287. # else:
  288. # print("bboxes number", "None")
  289. # log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  290. # log("otr finish: " + str(round(float(time.time()-start_time1), 4)) + "s")
  291. # return {"points": str(points), "split_lines": str(split_lines),
  292. # "bboxes": str(bboxes), "outline_points": str(outline_points),
  293. # "lines": str(rows+cols)}
  294. #
  295. # except TimeoutError:
  296. # raise TimeoutError
  297. # except Exception as e:
  298. # log("otr_interface cannot detected table!")
  299. # print("otr_interface cannot detected table!", traceback.print_exc())
  300. # print("points", 0, "split_lines", 0, "bboxes", 0)
  301. # log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  302. # return {"points": str([]), "split_lines": str([]), "bboxes": str([]),
  303. # "outline_points": str([]), "lines": str([])}
  304. def line_detect(img_data, otr_model, prob=0.2):
  305. log("into otr_interface table_detect")
  306. log('line_detect prob ' + str(prob))
  307. start_time = time.time()
  308. try:
  309. start_time1 = time.time()
  310. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  311. img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  312. # log("into otr_interface table_detect 1")
  313. # cv2.imwrite("111111.jpg", img)
  314. # 将bgr转为rbg
  315. image_np = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  316. # log("into otr_interface table_detect 2")
  317. # 选择与图片最接近分辨率,以防失真
  318. # best_h, best_w = get_best_predict_size(img)
  319. log("image_np.shape" + str(image_np.shape))
  320. best_h, best_w, _ = image_np.shape
  321. log("otr preprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  322. # 调用模型
  323. # rows, cols = table_line(image_np, otr_model)
  324. start_time1 = time.time()
  325. list_line = table_line(image_np, otr_model, size=(best_w, best_h), prob=prob)
  326. log("otr finish " + str(round(float(time.time()-start_time1), 4)) + "s")
  327. return {"list_line": str(list_line)}
  328. except TimeoutError:
  329. raise TimeoutError
  330. except Exception as e:
  331. log("otr_interface cannot detected table!")
  332. print("otr_interface cannot detected table!", traceback.print_exc())
  333. log("otr postprocess time: " + str(round(float(time.time()-start_time1), 4)) + "s")
  334. return {"list_line": str([])}
  335. class OtrModels:
  336. def __init__(self):
  337. # python文件所在目录
  338. _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
  339. model_path = _dir + "/models/table-line.h5"
  340. self.otr_model = table_net((None, None, 3), 2)
  341. self.otr_model.load_weights(model_path)
  342. def get_model(self):
  343. return self.otr_model
  344. def test_otr_model(from_remote=True):
  345. _global._init()
  346. from format_convert.convert_image import get_best_predict_size, image_process
  347. if get_platform() == "Windows":
  348. file_path = "C:/Users/Administrator/Desktop/error2.png"
  349. file_path = "C:/Users/Administrator/Downloads/1652672734044.jpg"
  350. else:
  351. file_path = "1.jpg"
  352. image_np = cv2.imread(file_path)
  353. best_h, best_w = get_best_predict_size(image_np)
  354. image_resize = cv2.resize(image_np, (best_w, best_h), interpolation=cv2.INTER_AREA)
  355. cv2.imwrite(file_path, image_resize)
  356. with open(file_path, "rb") as f:
  357. file_bytes = f.read()
  358. file_base64 = base64.b64encode(file_bytes)
  359. _md5 = get_md5_from_bytes(file_bytes)[0]
  360. _global.update({"port": 15010, "md5": _md5})
  361. if from_remote:
  362. file_json = {"data": file_base64, "is_from_pdf": False, "md5": _md5}
  363. # _url = "http://192.168.2.104:18000/otr"
  364. _url = "http://127.0.0.1:18000/otr"
  365. r = json.loads(request_post(_url, file_json))
  366. else:
  367. # otr_model = OtrModels().get_model()
  368. # r = otr(file_base64, otr_model, is_from_pdf=False)
  369. r = image_process(image_resize, file_path)
  370. print(r)
  371. if __name__ == '__main__':
  372. if len(sys.argv) == 2:
  373. port = int(sys.argv[1])
  374. elif len(sys.argv) == 3:
  375. port = int(sys.argv[1])
  376. using_gpu_index = int(sys.argv[2])
  377. else:
  378. port = 18000
  379. using_gpu_index = 0
  380. # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
  381. app.run()
  382. log("OTR running "+str(port))
  383. # test_otr_model(False)
  384. # print(json.dumps([-2]))
  385. # otr_model = OtrModels().get_model()
  386. # otr("11", otr_model)