ocr_interface.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import base64
  2. import json
  3. import multiprocessing as mp
  4. import socket
  5. import sys
  6. import os
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  8. import time
  9. import traceback
  10. from multiprocessing.context import Process
  11. import cv2
  12. import requests
  13. import logging
  14. import numpy as np
  15. os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
  16. from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes
  17. from flask import Flask, request
  18. from format_convert import _global
# Interface configuration: the Flask application whose routes (e.g. POST /ocr
# below) expose the OCR model over HTTP.
app = Flask(__name__)
  21. @app.route('/ocr', methods=['POST'])
  22. def _ocr():
  23. _global._init()
  24. _global.update({"port": globals().get("port")})
  25. start_time = time.time()
  26. log("into ocr_interface _ocr")
  27. try:
  28. if not request.form:
  29. log("ocr no data!")
  30. return json.dumps({"text": str([-9]), "bbox": str([-9])})
  31. data = request.form.get("data")
  32. # img_data = base64.b64decode(data)
  33. # _md5 = get_md5_from_bytes(img_data)[0]
  34. _md5 = request.form.get("md5")
  35. _global.update({"md5": _md5})
  36. ocr_model = globals().get("global_ocr_model")
  37. if ocr_model is None:
  38. log("----------- init ocr_model ------------")
  39. ocr_model = OcrModels().get_model()
  40. globals().update({"global_ocr_model": ocr_model})
  41. # text = picture2text(img_data, ocr_model)
  42. text = ocr(data, ocr_model)
  43. return json.dumps(text)
  44. except TimeoutError:
  45. return json.dumps({"text": str([-5]), "bbox": str([-5])})
  46. except:
  47. traceback.print_exc()
  48. return json.dumps({"text": str([-1]), "bbox": str([-1])})
  49. finally:
  50. log("ocr interface finish time " + str(time.time()-start_time))
  51. def ocr(data, ocr_model):
  52. log("into ocr_interface ocr")
  53. try:
  54. img_data = base64.b64decode(data)
  55. text = picture2text(img_data, ocr_model)
  56. return text
  57. except TimeoutError:
  58. return {"text": str([-5]), "bbox": str([-5])}
# NOTE(review): module-level flag that is not read or written anywhere else in
# the visible source — confirm whether another module imports it before
# removing.
flag = 0
  60. def picture2text(img_data, ocr_model):
  61. log("into ocr_interface picture2text")
  62. try:
  63. start_time = time.time()
  64. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  65. img = cv2.imdecode(np.frombuffer(img_data, np.uint8), cv2.IMREAD_COLOR)
  66. # 将bgr转为rbg
  67. try:
  68. np_images = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  69. except cv2.error as e:
  70. if "src.empty()" in str(e):
  71. log("ocr_interface picture2text image is empty!")
  72. return {"text": str([]), "bbox": str([])}
  73. # resize
  74. # cv2.imshow("before resize", np_images)
  75. # print("np_images.shape", np_images.shape)
  76. # best_h, best_w = get_best_predict_size(np_images)
  77. # np_images = cv2.resize(np_images, (best_w, best_h), interpolation=cv2.INTER_AREA)
  78. # cv2.imshow("after resize", np_images)
  79. # print("np_images.shape", np_images.shape)
  80. # cv2.waitKey(0)
  81. # 预测
  82. results = ocr_model.ocr(np_images, det=True, rec=True, cls=False)
  83. # 循环每张图片识别结果
  84. text_list = []
  85. bbox_list = []
  86. for line in results:
  87. # print("ocr_interface line", line)
  88. text_list.append(line[-1][0])
  89. bbox_list.append(line[0])
  90. # 查看bbox
  91. # img = np.zeros((np_images.shape[1], np_images.shape[0]), np.uint8)
  92. # img.fill(255)
  93. # for box in bbox_list:
  94. # print(box)
  95. # cv2.rectangle(img, (int(box[0][0]), int(box[0][1])),
  96. # (int(box[2][0]), int(box[2][1])), (0, 0, 255), 1)
  97. # cv2.imshow("bbox", img)
  98. # cv2.waitKey(0)
  99. # log("ocr model use time: " + str(time.time()-start_time))
  100. return {"text": str(text_list), "bbox": str(bbox_list)}
  101. except TimeoutError:
  102. raise TimeoutError
  103. except Exception as e:
  104. log("picture2text error!")
  105. print("picture2text", traceback.print_exc())
  106. return {"text": str([]), "bbox": str([])}
  107. def get_best_predict_size(image_np):
  108. sizes = [1280, 1152, 1024, 896, 768, 640, 512, 384, 256, 128]
  109. min_len = 10000
  110. best_height = sizes[0]
  111. for height in sizes:
  112. if abs(image_np.shape[0] - height) < min_len:
  113. min_len = abs(image_np.shape[0] - height)
  114. best_height = height
  115. min_len = 10000
  116. best_width = sizes[0]
  117. for width in sizes:
  118. if abs(image_np.shape[1] - width) < min_len:
  119. min_len = abs(image_np.shape[1] - width)
  120. best_width = width
  121. return best_height, best_width
  122. class OcrModels:
  123. def __init__(self):
  124. from ocr.paddleocr import PaddleOCR
  125. try:
  126. self.ocr_model = PaddleOCR(use_angle_cls=True, lang="ch")
  127. except:
  128. print(traceback.print_exc())
  129. raise RuntimeError
  130. def get_model(self):
  131. return self.ocr_model
  132. def test_ocr_model(from_remote=True):
  133. file_path = "C:/Users/Administrator/Desktop/error2.png"
  134. with open(file_path, "rb") as f:
  135. file_bytes = f.read()
  136. file_base64 = base64.b64encode(file_bytes)
  137. _md5 = get_md5_from_bytes(file_bytes)[0]
  138. _global._init()
  139. _global.update({"port": 15010, "md5": _md5})
  140. if from_remote:
  141. file_json = {"data": file_base64, "md5": _md5}
  142. # _url = "http://192.168.2.102:17000/ocr"
  143. _url = "http://127.0.0.1:17000/ocr"
  144. print(json.loads(request_post(_url, file_json)))
  145. else:
  146. ocr_model = OcrModels().get_model()
  147. result = ocr(file_base64, ocr_model)
  148. print(result)
  149. if __name__ == '__main__':
  150. # if len(sys.argv) == 2:
  151. # port = int(sys.argv[1])
  152. # elif len(sys.argv) == 3:
  153. # port = int(sys.argv[1])
  154. # using_gpu_index = int(sys.argv[2])
  155. # else:
  156. # port = 17000
  157. # using_gpu_index = 0
  158. # _global._init()
  159. # _global.update({"port": str(port)})
  160. # globals().update({"port": str(port)})
  161. #
  162. # # ip = get_intranet_ip()
  163. # # logging.basicConfig(level=logging.INFO,
  164. # # format='%(asctime)s - %(name)s - %(levelname)s - '
  165. # # + ip + ' - ' + str(port) + ' - %(message)s')
  166. #
  167. # os.environ['CUDA_VISIBLE_DEVICES'] = str(using_gpu_index)
  168. #
  169. # # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
  170. # app.run()
  171. # log("OCR running "+str(port))
  172. # test_ocr_model(False)
  173. #
  174. # log("OCR running")
  175. file_path = "C:/Users/Administrator/Desktop/test_image/error19.png"
  176. with open(file_path, "rb") as f:
  177. file_bytes = f.read()
  178. file_base64 = base64.b64encode(file_bytes)
  179. ocr_model = OcrModels().get_model()
  180. result = ocr(file_base64, ocr_model)
  181. text_list = eval(result.get("text"))
  182. box_list = eval(result.get("bbox"))
  183. print(text_list)
  184. #
  185. # new_list = []
  186. # for i in range(len(text_list)):
  187. # new_list.append([text_list[i], box_list[i]])
  188. #
  189. # # print(new_list[0][1])
  190. # new_list.sort(key=lambda x: (x[1][1][0], x[1][0][0]))
  191. #
  192. # for t in new_list:
  193. # print(t[0])