ocr_interface.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import base64
  2. import io
  3. import json
  4. import multiprocessing as mp
  5. import socket
  6. import sys
  7. import os
  8. os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
  9. from PIL import Image
  10. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  11. import time
  12. import traceback
  13. from multiprocessing.context import Process
  14. import cv2
  15. import requests
  16. import logging
  17. import numpy as np
  18. os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
  19. from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes, bytes2np
  20. from flask import Flask, request
  21. from format_convert import _global
  22. # 接口配置
  23. app = Flask(__name__)
  24. use_angle_cls = False
  25. @app.route('/ocr', methods=['POST'])
  26. def _ocr():
  27. _global._init()
  28. _global.update({"port": globals().get("port")})
  29. start_time = time.time()
  30. log("into ocr_interface _ocr")
  31. try:
  32. if not request.form:
  33. log("ocr no data!")
  34. return json.dumps({"text": str([-9]), "bbox": str([-9])})
  35. data = request.form.get("data")
  36. _md5 = request.form.get("md5")
  37. only_rec = request.form.get("only_rec")
  38. if only_rec is None:
  39. only_rec = 0
  40. else:
  41. only_rec = int(only_rec)
  42. _global.update({"md5": _md5})
  43. ocr_model = globals().get("global_ocr_model")
  44. if ocr_model is None:
  45. log("----------- init ocr_model ------------")
  46. ocr_model = OcrModels().get_model()
  47. globals().update({"global_ocr_model": ocr_model})
  48. text = ocr(data, ocr_model, only_rec)
  49. return json.dumps(text)
  50. except TimeoutError:
  51. return json.dumps({"text": str([-5]), "bbox": str([-5])})
  52. except:
  53. traceback.print_exc()
  54. return json.dumps({"text": str([-1]), "bbox": str([-1])})
  55. finally:
  56. log("ocr interface finish time " + str(time.time()-start_time))
  57. def ocr(data, ocr_model, only_rec=0):
  58. log("into ocr_interface ocr")
  59. try:
  60. img_data = base64.b64decode(data)
  61. text = picture2text(img_data, ocr_model, only_rec)
  62. return text
  63. except TimeoutError:
  64. return {"text": str([-5]), "bbox": str([-5])}
  65. def picture2text(img_data, ocr_model, only_rec=0):
  66. log("into ocr_interface picture2text")
  67. try:
  68. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  69. img = bytes2np(img_data)
  70. # cv2.imwrite('ocr.jpg', img)
  71. # 预测
  72. if only_rec:
  73. results = ocr_model.ocr(img, det=False, rec=True, cls=use_angle_cls)
  74. else:
  75. results = ocr_model.ocr(img, det=True, rec=True, cls=use_angle_cls)
  76. # 循环每张图片识别结果
  77. text_list = []
  78. bbox_list = []
  79. if only_rec:
  80. if results:
  81. text_list = [results[0][0]]
  82. else:
  83. text_list = []
  84. bbox_list = []
  85. else:
  86. for line in results:
  87. text_list.append(line[-1][0])
  88. bbox_list.append(line[0])
  89. return {"text": str(text_list), "bbox": str(bbox_list)}
  90. except TimeoutError:
  91. raise TimeoutError
  92. except Exception:
  93. log("picture2text error!")
  94. traceback.print_exc()
  95. return {"text": str([]), "bbox": str([])}
  96. def get_best_predict_size(image_np):
  97. sizes = [1280, 1152, 1024, 896, 768, 640, 512, 384, 256, 128]
  98. min_len = 10000
  99. best_height = sizes[0]
  100. for height in sizes:
  101. if abs(image_np.shape[0] - height) < min_len:
  102. min_len = abs(image_np.shape[0] - height)
  103. best_height = height
  104. min_len = 10000
  105. best_width = sizes[0]
  106. for width in sizes:
  107. if abs(image_np.shape[1] - width) < min_len:
  108. min_len = abs(image_np.shape[1] - width)
  109. best_width = width
  110. return best_height, best_width
  111. class OcrModels:
  112. def __init__(self):
  113. from ocr.paddleocr import PaddleOCR
  114. try:
  115. log('----------- init ocr model ---------------')
  116. self.ocr_model = PaddleOCR(use_angle_cls=use_angle_cls, lang="ch")
  117. except:
  118. print(traceback.print_exc())
  119. raise RuntimeError
  120. def get_model(self):
  121. return self.ocr_model
  122. def test_ocr_model(from_remote=True):
  123. file_path = "error8.png"
  124. file_path = "C:/Users/Administrator/Downloads/dbf46fe38862ac03209f1b2c12b1adc1.jpg"
  125. with open(file_path, "rb") as f:
  126. file_bytes = f.read()
  127. src = """
  128. 
  129. """
  130. image_data = src.split('data:image/png;base64,')[1]
  131. # 解码 base64 字符串
  132. file_bytes = base64.b64decode(image_data)
  133. file_base64 = base64.b64encode(file_bytes)
  134. _md5 = get_md5_from_bytes(file_bytes)[0]
  135. only_rec = 0
  136. _global._init()
  137. _global.update({"port": 15010, "md5": _md5})
  138. if from_remote:
  139. file_json = {"data": file_base64, "md5": _md5, 'only_rec': only_rec}
  140. # _url = "http://192.168.2.102:17000/ocr"
  141. # _url = "http://127.0.0.1:17000/ocr"
  142. _url = "http://120.132.118.205:17000/ocr"
  143. print(json.loads(request_post(_url, file_json)))
  144. else:
  145. ocr_model = OcrModels().get_model()
  146. result = ocr(file_base64, ocr_model, only_rec=only_rec)
  147. text = result.get('text')
  148. bbox = result.get('bbox')
  149. print('bbox', bbox)
  150. print(result)
  151. if __name__ == '__main__':
  152. # test_ocr_model()
  153. app.run(host='127.0.0.1', port=17000, debug=False)