ocr_interface.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import base64
  2. import json
  3. import multiprocessing as mp
  4. import socket
  5. import sys
  6. import os
  7. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  8. import time
  9. import traceback
  10. from multiprocessing.context import Process
  11. import cv2
  12. import requests
  13. import logging
  14. import numpy as np
  15. os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
  16. from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes, bytes2np
  17. from flask import Flask, request
  18. from format_convert import _global
  19. # Interface configuration
  20. app = Flask(__name__)  # Flask application; the /ocr route below is registered on it
  21. use_angle_cls = False  # forwarded as use_angle_cls= to PaddleOCR and as cls= to ocr_model.ocr
  22. @app.route('/ocr', methods=['POST'])
  23. def _ocr():
  24. _global._init()
  25. _global.update({"port": globals().get("port")})
  26. start_time = time.time()
  27. log("into ocr_interface _ocr")
  28. try:
  29. if not request.form:
  30. log("ocr no data!")
  31. return json.dumps({"text": str([-9]), "bbox": str([-9])})
  32. data = request.form.get("data")
  33. _md5 = request.form.get("md5")
  34. only_rec = request.form.get("only_rec")
  35. if only_rec is None:
  36. only_rec = 0
  37. else:
  38. only_rec = int(only_rec)
  39. _global.update({"md5": _md5})
  40. ocr_model = globals().get("global_ocr_model")
  41. if ocr_model is None:
  42. log("----------- init ocr_model ------------")
  43. ocr_model = OcrModels().get_model()
  44. globals().update({"global_ocr_model": ocr_model})
  45. text = ocr(data, ocr_model, only_rec)
  46. return json.dumps(text)
  47. except TimeoutError:
  48. return json.dumps({"text": str([-5]), "bbox": str([-5])})
  49. except:
  50. traceback.print_exc()
  51. return json.dumps({"text": str([-1]), "bbox": str([-1])})
  52. finally:
  53. log("ocr interface finish time " + str(time.time()-start_time))
  54. def ocr(data, ocr_model, only_rec=0):
  55. log("into ocr_interface ocr")
  56. try:
  57. img_data = base64.b64decode(data)
  58. text = picture2text(img_data, ocr_model, only_rec)
  59. return text
  60. except TimeoutError:
  61. return {"text": str([-5]), "bbox": str([-5])}
  62. def picture2text(img_data, ocr_model, only_rec=0):
  63. log("into ocr_interface picture2text")
  64. try:
  65. # 二进制数据流转np.ndarray [np.uint8: 8位像素]
  66. img = bytes2np(img_data)
  67. # cv2.imwrite('ocr.jpg', img)
  68. # 预测
  69. if only_rec:
  70. results = ocr_model.ocr(img, det=False, rec=True, cls=use_angle_cls)
  71. else:
  72. results = ocr_model.ocr(img, det=True, rec=True, cls=use_angle_cls)
  73. # 循环每张图片识别结果
  74. text_list = []
  75. bbox_list = []
  76. if only_rec:
  77. text_list = [results[0][0]]
  78. bbox_list = []
  79. else:
  80. for line in results:
  81. text_list.append(line[-1][0])
  82. bbox_list.append(line[0])
  83. return {"text": str(text_list), "bbox": str(bbox_list)}
  84. except TimeoutError:
  85. raise TimeoutError
  86. except Exception:
  87. log("picture2text error!")
  88. traceback.print_exc()
  89. return {"text": str([]), "bbox": str([])}
  90. def get_best_predict_size(image_np):
  91. sizes = [1280, 1152, 1024, 896, 768, 640, 512, 384, 256, 128]
  92. min_len = 10000
  93. best_height = sizes[0]
  94. for height in sizes:
  95. if abs(image_np.shape[0] - height) < min_len:
  96. min_len = abs(image_np.shape[0] - height)
  97. best_height = height
  98. min_len = 10000
  99. best_width = sizes[0]
  100. for width in sizes:
  101. if abs(image_np.shape[1] - width) < min_len:
  102. min_len = abs(image_np.shape[1] - width)
  103. best_width = width
  104. return best_height, best_width
  105. class OcrModels:
  106. def __init__(self):
  107. from ocr.paddleocr import PaddleOCR
  108. try:
  109. log('----------- init ocr model ---------------')
  110. self.ocr_model = PaddleOCR(use_angle_cls=use_angle_cls, lang="ch")
  111. except:
  112. print(traceback.print_exc())
  113. raise RuntimeError
  114. def get_model(self):
  115. return self.ocr_model
  116. def test_ocr_model(from_remote=True):
  117. file_path = "error8.png"
  118. with open(file_path, "rb") as f:
  119. file_bytes = f.read()
  120. file_base64 = base64.b64encode(file_bytes)
  121. _md5 = get_md5_from_bytes(file_bytes)[0]
  122. only_rec = False
  123. _global._init()
  124. _global.update({"port": 15010, "md5": _md5})
  125. if from_remote:
  126. file_json = {"data": file_base64, "md5": _md5, 'only_rec': only_rec}
  127. # _url = "http://192.168.2.102:17000/ocr"
  128. _url = "http://127.0.0.1:17000/ocr"
  129. print(json.loads(request_post(_url, file_json)))
  130. else:
  131. ocr_model = OcrModels().get_model()
  132. result = ocr(file_base64, ocr_model, only_rec=only_rec)
  133. print(result)
  134. if __name__ == '__main__':
  135. test_ocr_model(False)