ocr_interface.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import base64
  2. import io
  3. import json
  4. import multiprocessing as mp
  5. import socket
  6. import sys
  7. import os
  8. from PIL import Image
  9. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  10. import time
  11. import traceback
  12. from multiprocessing.context import Process
  13. import cv2
  14. import requests
  15. import logging
  16. import numpy as np
  17. os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
  18. from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes, bytes2np
  19. from flask import Flask, request
  20. from format_convert import _global
# Interface configuration
# Flask application object; the /ocr route in this file is registered on it.
app = Flask(__name__)
# Passed as `use_angle_cls=` when PaddleOCR is constructed and as `cls=` on
# every `ocr_model.ocr(...)` call made in this file.
use_angle_cls = False
  24. @app.route('/ocr', methods=['POST'])
  25. def _ocr():
  26. _global._init()
  27. _global.update({"port": globals().get("port")})
  28. start_time = time.time()
  29. log("into ocr_interface _ocr")
  30. try:
  31. if not request.form:
  32. log("ocr no data!")
  33. return json.dumps({"text": str([-9]), "bbox": str([-9])})
  34. data = request.form.get("data")
  35. _md5 = request.form.get("md5")
  36. only_rec = request.form.get("only_rec")
  37. if only_rec is None:
  38. only_rec = 0
  39. else:
  40. only_rec = int(only_rec)
  41. _global.update({"md5": _md5})
  42. ocr_model = globals().get("global_ocr_model")
  43. if ocr_model is None:
  44. log("----------- init ocr_model ------------")
  45. ocr_model = OcrModels().get_model()
  46. globals().update({"global_ocr_model": ocr_model})
  47. text = ocr(data, ocr_model, only_rec)
  48. return json.dumps(text)
  49. except TimeoutError:
  50. return json.dumps({"text": str([-5]), "bbox": str([-5])})
  51. except:
  52. traceback.print_exc()
  53. return json.dumps({"text": str([-1]), "bbox": str([-1])})
  54. finally:
  55. log("ocr interface finish time " + str(time.time()-start_time))
  56. def ocr(data, ocr_model, only_rec=0):
  57. log("into ocr_interface ocr")
  58. try:
  59. img_data = base64.b64decode(data)
  60. text = picture2text(img_data, ocr_model, only_rec)
  61. return text
  62. except TimeoutError:
  63. return {"text": str([-5]), "bbox": str([-5])}
  64. def picture2text(img_data, ocr_model, only_rec=0):
  65. log("into ocr_interface picture2text")
  66. try:
  67. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  68. img = bytes2np(img_data)
  69. # cv2.imwrite('ocr.jpg', img)
  70. # 预测
  71. if only_rec:
  72. results = ocr_model.ocr(img, det=False, rec=True, cls=use_angle_cls)
  73. else:
  74. results = ocr_model.ocr(img, det=True, rec=True, cls=use_angle_cls)
  75. # 循环每张图片识别结果
  76. text_list = []
  77. bbox_list = []
  78. if only_rec:
  79. text_list = [results[0][0]]
  80. bbox_list = []
  81. else:
  82. for line in results:
  83. text_list.append(line[-1][0])
  84. bbox_list.append(line[0])
  85. return {"text": str(text_list), "bbox": str(bbox_list)}
  86. except TimeoutError:
  87. raise TimeoutError
  88. except Exception:
  89. log("picture2text error!")
  90. traceback.print_exc()
  91. return {"text": str([]), "bbox": str([])}
  92. def get_best_predict_size(image_np):
  93. sizes = [1280, 1152, 1024, 896, 768, 640, 512, 384, 256, 128]
  94. min_len = 10000
  95. best_height = sizes[0]
  96. for height in sizes:
  97. if abs(image_np.shape[0] - height) < min_len:
  98. min_len = abs(image_np.shape[0] - height)
  99. best_height = height
  100. min_len = 10000
  101. best_width = sizes[0]
  102. for width in sizes:
  103. if abs(image_np.shape[1] - width) < min_len:
  104. min_len = abs(image_np.shape[1] - width)
  105. best_width = width
  106. return best_height, best_width
  107. class OcrModels:
  108. def __init__(self):
  109. from ocr.paddleocr import PaddleOCR
  110. try:
  111. log('----------- init ocr model ---------------')
  112. self.ocr_model = PaddleOCR(use_angle_cls=use_angle_cls, lang="ch")
  113. except:
  114. print(traceback.print_exc())
  115. raise RuntimeError
  116. def get_model(self):
  117. return self.ocr_model
  118. def test_ocr_model(from_remote=True):
  119. file_path = "error8.png"
  120. file_path = "C:/Users/Administrator/Downloads/dbf46fe38862ac03209f1b2c12b1adc1.jpg"
  121. with open(file_path, "rb") as f:
  122. file_bytes = f.read()
  123. src = """
  124. 
  125. """
  126. image_data = src.split('data:image/png;base64,')[1]
  127. # 解码 base64 字符串
  128. file_bytes = base64.b64decode(image_data)
  129. file_base64 = base64.b64encode(file_bytes)
  130. _md5 = get_md5_from_bytes(file_bytes)[0]
  131. only_rec = 0
  132. _global._init()
  133. _global.update({"port": 15010, "md5": _md5})
  134. if from_remote:
  135. file_json = {"data": file_base64, "md5": _md5, 'only_rec': only_rec}
  136. # _url = "http://192.168.2.102:17000/ocr"
  137. # _url = "http://127.0.0.1:17000/ocr"
  138. _url = "http://120.132.118.205:17000/ocr"
  139. print(json.loads(request_post(_url, file_json)))
  140. else:
  141. ocr_model = OcrModels().get_model()
  142. result = ocr(file_base64, ocr_model, only_rec=only_rec)
  143. text = result.get('text')
  144. bbox = result.get('bbox')
  145. print('bbox', bbox)
  146. print(result)
if __name__ == '__main__':
    # Running this file as a script executes the manual smoke test with its
    # default arguments.
    test_ocr_model()

    # Commented-out scratch code: decodes an inline base64 PNG and shows or
    # saves it with OpenCV.
    # src = """
    #
    # """
    #
    # image_data = src.split('data:image/png;base64,')[1]
    #
    # # Decode the base64 string
    # image_bytes = base64.b64decode(image_data)
    #
    # # Convert the bytes to an image
    # # image = Image.open(io.BytesIO(image_bytes))
    #
    # # image.show('img')
    #
    # # with open(r'C:\Users\Administrator\Desktop\test_image\error16.jpg', 'rb') as f:
    # # image_bytes = f.read()
    #
    # image = bytes2np(image_bytes)
    #
    # cv2.imshow('img', image)
    # cv2.imwrite('./1.png', image)
    # cv2.waitKey(0)