# encoding=utf8
import base64
import io
import json
import pickle
import sys
import os
import threading
import zlib
import numpy as np
import redis
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
import time
import traceback
os.environ['FLAGS_eager_delete_tensor_gb'] = '0'
from format_convert.utils import request_post, test_gpu, get_intranet_ip, log, get_md5_from_bytes
from flask import Flask, request
from format_convert import _global
from ocr.tools.infer import utility
from ocr.ppocr.utils.logging import get_logger

logger = get_logger()

# service configuration
app = Flask(__name__)
lock = threading.RLock()
redis_db = redis.StrictRedis(host='127.0.0.1', port=6379,
                             db=1, password='bidi123456', health_check_interval=300)
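
# Both entry points below consume the same pickled job dict (fields taken
# from the code below; nothing here is an external contract):
#   inputs          numpy array (npz-compressed bytes on the HTTP path,
#                   raw ndarray on the Redis path)
#   args            predictor arguments forwarded to utility.create_predictor
#   predictor_type  key under which the predictor is cached in globals()
#   model_type      model name, used only in log messages
#   md5             request id used for log correlation
#   uuid            Redis path only: field in the "consumer_ocr" hash where
#                   the result is written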

@app.route('/ocr', methods=['POST'])
def _ocr_gpu_flask():
    start_time = time.time()
    log("into _ocr")
    _global._init()
    _global.update({"port": globals().get("port")})
    log("into _ocr -> _global " + str(time.time()-start_time))
    start_time = time.time()
    try:
        if not request.form:
            log("ocr no data!")
            return json.dumps({"text": str([-9]), "bbox": str([-9])})
        log("judge request.form " + str(time.time()-start_time))
        start_time1 = time.time()
        result = pickle.loads(base64.b64decode(request.form.get("data")))
        inputs = result.get("inputs")
        # decompress the npz-packed numpy input
        decompressed_array = io.BytesIO()
        decompressed_array.write(inputs)
        decompressed_array.seek(0)
        inputs = np.load(decompressed_array, allow_pickle=True)['arr_0']
        log("inputs.shape " + str(inputs.shape))
        args = result.get("args")
        predictor_type = result.get("predictor_type")
        model_type = result.get("model_type")
        _md5 = result.get("md5")
        _global.update({"md5": _md5})
        log("read data " + str(time.time()-start_time1))
        # fetch the matching predictor, creating and caching it on first use
        if globals().get(predictor_type) is None:
            start_time1 = time.time()
            log("=== init " + model_type + " " + predictor_type + " model ===")
            predictor, input_tensor, output_tensors = \
                utility.create_predictor(args, predictor_type, logger)
            globals().update({predictor_type: {"predictor": predictor,
                                               "input_tensor": input_tensor,
                                               "output_tensors": output_tensors}})
            log("=== init " + model_type + " " + predictor_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
        else:
            predictor = globals().get(predictor_type).get("predictor")
            input_tensor = globals().get(predictor_type).get("input_tensor")
            output_tensors = globals().get(predictor_type).get("output_tensors")
        # set the model input and run inference; the lock serializes GPU access
        input_tensor.copy_from_cpu(inputs)
        with lock:
            start_time1 = time.time()
            predictor.run()
            gpu_time = round(float(time.time()-start_time1), 2)
        # collect the model outputs
        outputs = []
        for output_tensor in output_tensors:
            output = output_tensor.copy_to_cpu()
            outputs.append(output)
        preds = outputs[0]
        # npz-compress the predictions for the response
        compressed_array = io.BytesIO()
        np.savez_compressed(compressed_array, preds)
        compressed_array.seek(0)
        preds = compressed_array.read()
        # release memory held by intermediate tensors
        predictor.clear_intermediate_tensor()
        predictor.try_shrink_memory()
        finish_time = round(float(time.time()-start_time), 2)
        log("ocr model predict time - " + str(predictor_type) + " - " + str(gpu_time) + " " + str(finish_time))
        return base64.b64encode(pickle.dumps({"preds": preds, "gpu_time": gpu_time, "elapse": finish_time}))
    except Exception as e:
        finish_time = round(float(time.time()-start_time), 2)
        traceback.print_exc()
        return base64.b64encode(pickle.dumps({"preds": None, "gpu_time": 0., "elapse": finish_time}))
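

# A minimal HTTP client sketch (an illustration, not part of the service):
# it mirrors the wire format _ocr_gpu_flask expects. The helper name and the
# default server address are assumptions; adjust them to your deployment.
# Requires the `requests` package.
def _example_http_client(inputs, args, predictor_type, model_type, _md5,
                         server_url="http://127.0.0.1:17000/ocr"):
    import requests

    # npz-compress the input exactly the way the handler decompresses it
    buf = io.BytesIO()
    np.savez_compressed(buf, inputs)
    buf.seek(0)
    payload = {"inputs": buf.read(), "args": args,
               "predictor_type": predictor_type,
               "model_type": model_type, "md5": _md5}
    data = base64.b64encode(pickle.dumps(payload))
    resp = requests.post(server_url, data={"data": data}, timeout=60)
    result = pickle.loads(base64.b64decode(resp.content))
    preds = result.get("preds")
    if preds is not None:
        # the handler returns preds npz-compressed; unpack them
        preds = np.load(io.BytesIO(preds), allow_pickle=True)['arr_0']
    return preds, result.get("gpu_time"), result.get("elapse")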


def _ocr_gpu_redis():
    start_time = time.time()
    log("into _ocr")
    _global._init()
    _global.update({"port": globals().get("port")})
    log("into _ocr -> _global " + str(time.time()-start_time))
    while True:
        start_time = time.time()
        try:
            if redis_db.llen("producer_ocr") == 0:
                # nothing queued yet; sleep briefly to avoid a tight busy-wait
                time.sleep(0.01)
                continue
            log("judge llen " + str(time.time()-start_time))
            _time = time.time()
            result = redis_db.lpop("producer_ocr")
            if result is None:
                continue
            result = pickle.loads(result)
            log("from producer_ocr time " + str(time.time() - _time))
            _time = time.time()
            # the Redis path ships the raw ndarray, so no npz decompression here
            inputs = result.get("inputs")
            args = result.get("args")
            _uuid = result.get("uuid")
            predictor_type = result.get("predictor_type")
            model_type = result.get("model_type")
            _md5 = result.get("md5")
            _global.update({"md5": _md5})
            log("read data " + str(time.time()-_time))
            # fetch the matching predictor, creating and caching it on first use
            if globals().get(predictor_type) is None:
                start_time1 = time.time()
                log("=== init " + model_type + " " + predictor_type + " model ===")
                predictor, input_tensor, output_tensors = \
                    utility.create_predictor(args, predictor_type, logger)
                globals().update({predictor_type: {"predictor": predictor,
                                                   "input_tensor": input_tensor,
                                                   "output_tensors": output_tensors}})
                log("=== init " + model_type + " " + predictor_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
            else:
                predictor = globals().get(predictor_type).get("predictor")
                input_tensor = globals().get(predictor_type).get("input_tensor")
                output_tensors = globals().get(predictor_type).get("output_tensors")
            # set the model input and run inference
            input_tensor.copy_from_cpu(inputs)
            start_time1 = time.time()
            predictor.run()
            gpu_time = round(float(time.time()-start_time1), 2)
            # collect the model outputs
            _time = time.time()
            outputs = []
            for output_tensor in output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)
            preds = outputs[0]
            log("output_tensors " + str(time.time()-_time))
            # write the result back to redis, keyed by the job's uuid
            # (preds go back uncompressed on this path)
            finish_time = round(float(time.time()-start_time), 2)
            _time = time.time()
            redis_db.hset("consumer_ocr", _uuid, pickle.dumps({"preds": preds, "gpu_time": gpu_time, "elapse": finish_time}))
            log("to consumer_ocr " + str(time.time()-_time))
            # release memory held by intermediate tensors
            predictor.clear_intermediate_tensor()
            predictor.try_shrink_memory()
            log("ocr model predict time - " + str(predictor_type) + " - " + str(gpu_time) + " " + str(finish_time))
        except Exception as e:
            traceback.print_exc()
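

# A minimal producer sketch for the Redis path (an illustration, not part of
# the service): it pushes a job onto "producer_ocr" and polls the
# "consumer_ocr" hash for the result, matching the consumer loop above. The
# helper name, the push side (rpush vs. lpush) and the timeout are assumptions.
def _example_redis_producer(inputs, args, predictor_type, model_type, _md5,
                            timeout=60):
    import uuid

    _uuid = uuid.uuid4().hex
    job = pickle.dumps({"inputs": inputs, "args": args, "uuid": _uuid,
                        "predictor_type": predictor_type,
                        "model_type": model_type, "md5": _md5})
    # the consumer lpops, so rpush gives FIFO ordering
    redis_db.rpush("producer_ocr", job)

    deadline = time.time() + timeout
    while time.time() < deadline:
        raw = redis_db.hget("consumer_ocr", _uuid)
        if raw is not None:
            redis_db.hdel("consumer_ocr", _uuid)
            return pickle.loads(raw)
        time.sleep(0.05)
    return None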


if __name__ == '__main__':
    if len(sys.argv) == 2:
        port = int(sys.argv[1])
        using_gpu_index = 0
    elif len(sys.argv) == 3:
        port = int(sys.argv[1])
        using_gpu_index = int(sys.argv[2])
    else:
        port = 17000
        using_gpu_index = 0
    _global._init()
    _global.update({"port": str(port)})
    globals().update({"port": str(port)})
    ip = get_intranet_ip()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(using_gpu_index)
    app.run(host='0.0.0.0', port=port, threaded=False, debug=False)
    # log("OCR running " + str(port))
    # _ocr_gpu_redis()  # alternative: run the Redis consumer loop instead of Flask