# otr_gpu_interface.py
  1. # encoding=utf8
  2. import base64
  3. import io
  4. import json
  5. import os
  6. import pickle
  7. import threading
  8. import traceback
  9. # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
  10. # os.environ['CUDA_VISIBLE_DEVICES'] = "0"
  11. import redis
  12. import tensorflow as tf
  13. try:
  14. gpus = tf.config.list_physical_devices('GPU')
  15. if len(gpus) > 0:
  16. tf.config.experimental.set_virtual_device_configuration(
  17. gpus[0],
  18. [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)])
  19. except:
  20. traceback.print_exc()
  21. pass
  22. import sys
  23. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  24. import time
  25. import logging
  26. # from table_line import *
  27. import cv2
  28. import numpy as np
  29. from flask import Flask, request
  30. from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes, get_platform
  31. from otr.table_line import table_net, table_line, table_preprocess, table_postprocess
  32. from format_convert import _global
  33. # 接口配置
  34. app = Flask(__name__)
  35. lock = threading.RLock()
  36. redis_db = redis.StrictRedis(host='127.0.0.1', port='6379',
  37. db=1, password='bidi123456', health_check_interval=300)
  38. # @app.route('/otr', methods=['POST'])
  39. def _otr_flask():
  40. start_time = time.time()
  41. log("into _otr")
  42. _global._init()
  43. _global.update({"port": globals().get("port")})
  44. log("into _otr -> _global " + str(time.time()-start_time))
  45. start_time = time.time()
  46. try:
  47. if not request.form:
  48. log("otr no data!")
  49. return json.dumps({"list_line": str([-9])})
  50. log("judge request.form " + str(time.time()-start_time))
  51. start_time1 = time.time()
  52. # 反序列化
  53. result = pickle.loads(base64.b64decode(request.form.get("data")))
  54. inputs = result.get("inputs")
  55. # 解压numpy
  56. decompressed_array = io.BytesIO()
  57. decompressed_array.write(inputs)
  58. decompressed_array.seek(0)
  59. inputs = np.load(decompressed_array, allow_pickle=True)['arr_0']
  60. log("inputs.shape" + str(inputs.shape))
  61. predictor_type = result.get("predictor_type")
  62. model_type = result.get("model_type")
  63. _md5 = result.get("md5")
  64. _global.update({"md5": _md5})
  65. log("read data " + str(time.time()-start_time1))
  66. # 获取模型
  67. model = globals().get(model_type)
  68. if model is None:
  69. start_time1 = time.time()
  70. log("=== init " + model_type + " model ===")
  71. model = OtrModels().get_model()
  72. globals().update({model_type: model})
  73. log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
  74. # 运行
  75. with lock:
  76. start_time1 = time.time()
  77. pred = model.predict(inputs)
  78. pred = pred[0]
  79. log("pred.shape " + str(pred.shape))
  80. # 压缩numpy
  81. compressed_array = io.BytesIO()
  82. np.savez_compressed(compressed_array, pred)
  83. compressed_array.seek(0)
  84. pred = compressed_array.read()
  85. gpu_time = round(float(time.time()-start_time1), 2)
  86. finish_time = round(float(time.time()-start_time), 2)
  87. log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
  88. return base64.b64encode(pickle.dumps({"preds": pred, "gpu_time": gpu_time, "elapse": finish_time}))
  89. except Exception as e:
  90. finish_time = round(float(time.time()-start_time), 2)
  91. traceback.print_exc()
  92. return base64.b64encode(pickle.dumps({"preds": None, "gpu_time": 0., "elapse": finish_time}))
  93. def _otr():
  94. start_time = time.time()
  95. log("into _otr")
  96. _global._init()
  97. _global.update({"port": globals().get("port")})
  98. log("into _otr -> _global " + str(time.time()-start_time))
  99. while True:
  100. start_time = time.time()
  101. try:
  102. if redis_db.llen("producer_otr") == 0:
  103. continue
  104. log("judge llen " + str(time.time()-start_time))
  105. _time = time.time()
  106. result = redis_db.lpop("producer_otr")
  107. if result is None:
  108. continue
  109. result = pickle.loads(result)
  110. log("from producer_otr time " + str(time.time() - _time))
  111. _time = time.time()
  112. inputs = result.get("inputs")
  113. # # 解压numpy
  114. # decompressed_array = io.BytesIO()
  115. # decompressed_array.write(inputs)
  116. # decompressed_array.seek(0)
  117. # inputs = np.load(decompressed_array, allow_pickle=True)['arr_0']
  118. # log("inputs.shape " + str(inputs.shape))
  119. # log("numpy decompress " + str(time.time()-_time))
  120. predictor_type = result.get("predictor_type")
  121. _uuid = result.get("uuid")
  122. model_type = result.get("model_type")
  123. _md5 = result.get("md5")
  124. _global.update({"md5": _md5})
  125. log("read data " + str(time.time()-_time))
  126. # 获取模型
  127. model = globals().get(model_type)
  128. if model is None:
  129. start_time1 = time.time()
  130. log("=== init " + model_type + " model ===")
  131. model = OtrModels().get_model()
  132. globals().update({model_type: model})
  133. log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
  134. # 运行
  135. start_time1 = time.time()
  136. pred = model.predict(inputs)
  137. pred = pred[0]
  138. log("pred.shape " + str(pred.shape))
  139. # # 压缩numpy
  140. # _time = time.time()
  141. # compressed_array = io.BytesIO()
  142. # np.savez_compressed(compressed_array, pred)
  143. # compressed_array.seek(0)
  144. # pred = compressed_array.read()
  145. # log("numpy compress " + str(time.time()-_time))
  146. # 写入redis
  147. gpu_time = round(float(time.time()-start_time1), 2)
  148. finish_time = round(float(time.time()-start_time), 2)
  149. redis_db.hset("consumer_otr", _uuid, pickle.dumps({"preds": pred, "gpu_time": gpu_time, "elapse": finish_time}))
  150. log("to consumer_otr " + str(time.time()-_time))
  151. log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
  152. except Exception as e:
  153. traceback.print_exc()
  154. class OtrModels:
  155. def __init__(self):
  156. # python文件所在目录
  157. _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
  158. model_path = _dir + "/models/table-line.h5"
  159. self.otr_model = table_net((None, None, 3), 2)
  160. self.otr_model.load_weights(model_path)
  161. def get_model(self):
  162. return self.otr_model
  163. if __name__ == '__main__':
  164. if len(sys.argv) == 2:
  165. port = int(sys.argv[1])
  166. using_gpu_index = 0
  167. elif len(sys.argv) == 3:
  168. port = int(sys.argv[1])
  169. using_gpu_index = int(sys.argv[2])
  170. else:
  171. port = 18000
  172. using_gpu_index = 0
  173. # _global._init()
  174. # _global.update({"port": str(port)})
  175. # globals().update({"port": str(port)})
  176. # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
  177. # app.run()
  178. # log("OTR running "+str(port))
  179. _otr()