# otr_gpu_interface.py — OTR (table-line) GPU inference service
  1. # encoding=utf8
  2. import base64
  3. import io
  4. import json
  5. import os
  6. import pickle
  7. import threading
  8. import traceback
  9. # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
  10. # os.environ['CUDA_VISIBLE_DEVICES'] = "0"
  11. import redis
  12. import tensorflow as tf
  13. try:
  14. gpus = tf.config.list_physical_devices('GPU')
  15. if len(gpus) > 0:
  16. tf.config.experimental.set_virtual_device_configuration(
  17. gpus[0],
  18. [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)])
  19. except:
  20. traceback.print_exc()
  21. pass
  22. import sys
  23. sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
  24. import time
  25. import logging
  26. # from table_line import *
  27. import cv2
  28. import numpy as np
  29. from flask import Flask, request
  30. from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes, get_platform, \
  31. to_share_memory, from_share_memory, get_np_type, get_share_memory_list, release_share_memory, get_share_memory, \
  32. close_share_memory_list
  33. from otr.table_line import table_net, table_line, table_preprocess, table_postprocess
  34. from format_convert import _global
  35. # 接口配置
  36. app = Flask(__name__)
  37. lock = threading.RLock()
  38. # redis_db = redis.StrictRedis(host='127.0.0.1', port='6379',
  39. # db=1, password='bidi123456', health_check_interval=300)
  40. redis_db = None
  41. # @app.route('/otr', methods=['POST'])
  42. def _otr_flask():
  43. start_time = time.time()
  44. log("into _otr")
  45. _global._init()
  46. _global.update({"port": globals().get("port")})
  47. log("into _otr -> _global " + str(time.time()-start_time))
  48. start_time = time.time()
  49. try:
  50. if not request.form:
  51. log("otr no data!")
  52. return json.dumps({"list_line": str([-9])})
  53. log("judge request.form " + str(time.time()-start_time))
  54. start_time1 = time.time()
  55. # 反序列化
  56. result = pickle.loads(base64.b64decode(request.form.get("data")))
  57. inputs = result.get("inputs")
  58. # 解压numpy
  59. decompressed_array = io.BytesIO()
  60. decompressed_array.write(inputs)
  61. decompressed_array.seek(0)
  62. inputs = np.load(decompressed_array, allow_pickle=True)['arr_0']
  63. log("inputs.shape" + str(inputs.shape))
  64. predictor_type = result.get("predictor_type")
  65. model_type = result.get("model_type")
  66. _md5 = result.get("md5")
  67. _global.update({"md5": _md5})
  68. log("read data " + str(time.time()-start_time1))
  69. # 获取模型
  70. model = globals().get(model_type)
  71. if model is None:
  72. start_time1 = time.time()
  73. log("=== init " + model_type + " model ===")
  74. model = OtrModels().get_model()
  75. globals().update({model_type: model})
  76. log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
  77. # 运行
  78. with lock:
  79. start_time1 = time.time()
  80. pred = model.predict(inputs)
  81. pred = pred[0]
  82. log("pred.shape " + str(pred.shape))
  83. # 压缩numpy
  84. compressed_array = io.BytesIO()
  85. np.savez_compressed(compressed_array, pred)
  86. compressed_array.seek(0)
  87. pred = compressed_array.read()
  88. gpu_time = round(float(time.time()-start_time1), 2)
  89. finish_time = round(float(time.time()-start_time), 2)
  90. log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
  91. return base64.b64encode(pickle.dumps({"preds": pred, "gpu_time": gpu_time, "elapse": finish_time}))
  92. except Exception as e:
  93. finish_time = round(float(time.time()-start_time), 2)
  94. traceback.print_exc()
  95. return base64.b64encode(pickle.dumps({"preds": None, "gpu_time": 0., "elapse": finish_time}))
  96. def _otr_redis():
  97. start_time = time.time()
  98. log("into _otr")
  99. _global._init()
  100. _global.update({"port": globals().get("port")})
  101. log("into _otr -> _global " + str(time.time()-start_time))
  102. while True:
  103. start_time = time.time()
  104. try:
  105. if redis_db.llen("producer_otr") == 0:
  106. continue
  107. log("judge llen " + str(time.time()-start_time))
  108. _time = time.time()
  109. result = redis_db.lpop("producer_otr")
  110. if result is None:
  111. continue
  112. result = pickle.loads(result)
  113. log("from producer_otr time " + str(time.time() - _time))
  114. _time = time.time()
  115. inputs = result.get("inputs")
  116. # # 解压numpy
  117. # decompressed_array = io.BytesIO()
  118. # decompressed_array.write(inputs)
  119. # decompressed_array.seek(0)
  120. # inputs = np.load(decompressed_array, allow_pickle=True)['arr_0']
  121. # log("inputs.shape " + str(inputs.shape))
  122. # log("numpy decompress " + str(time.time()-_time))
  123. predictor_type = result.get("predictor_type")
  124. _uuid = result.get("uuid")
  125. model_type = result.get("model_type")
  126. _md5 = result.get("md5")
  127. _global.update({"md5": _md5})
  128. log("read data " + str(time.time()-_time))
  129. # 获取模型
  130. model = globals().get(model_type)
  131. if model is None:
  132. start_time1 = time.time()
  133. log("=== init " + model_type + " model ===")
  134. model = OtrModels().get_model()
  135. globals().update({model_type: model})
  136. log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
  137. # 运行
  138. start_time1 = time.time()
  139. pred = model.predict(inputs)
  140. pred = pred[0]
  141. log("pred.shape " + str(pred.shape))
  142. # # 压缩numpy
  143. # _time = time.time()
  144. # compressed_array = io.BytesIO()
  145. # np.savez_compressed(compressed_array, pred)
  146. # compressed_array.seek(0)
  147. # pred = compressed_array.read()
  148. # log("numpy compress " + str(time.time()-_time))
  149. # 写入redis
  150. gpu_time = round(float(time.time()-start_time1), 2)
  151. finish_time = round(float(time.time()-start_time), 2)
  152. redis_db.hset("consumer_otr", _uuid, pickle.dumps({"preds": pred, "gpu_time": gpu_time, "elapse": finish_time}))
  153. log("to consumer_otr " + str(time.time()-_time))
  154. log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
  155. except Exception as e:
  156. traceback.print_exc()
  157. @app.route('/otr', methods=['POST'])
  158. def _otr_flask_sm():
  159. start_time = time.time()
  160. log("into _otr")
  161. _global._init()
  162. _global.update({"port": globals().get("port")})
  163. log("into _otr -> _global " + str(time.time()-start_time))
  164. start_time = time.time()
  165. try:
  166. if not request.form:
  167. log("otr no data!")
  168. return json.dumps({"list_line": str([-9])})
  169. log("judge request.form " + str(time.time()-start_time))
  170. _time = time.time()
  171. result = json.loads(request.form.get("data"))
  172. model_type = result.get("model_type")
  173. args = result.get("args")
  174. _md5 = result.get("md5")
  175. sm_name = result.get("sm_name")
  176. sm_shape = result.get("sm_shape")
  177. sm_dtype = result.get("sm_dtype")
  178. sm_dtype = get_np_type(sm_dtype)
  179. _global.update({"md5": _md5})
  180. log("read data " + str(time.time()-_time))
  181. # 读取共享内存
  182. _time = time.time()
  183. if sm_name:
  184. inputs = from_share_memory(sm_name, sm_shape, sm_dtype)
  185. else:
  186. log("from_share_memory failed!")
  187. raise Exception
  188. log("data from share memory " + sm_name + " " + str(time.time()-_time))
  189. # 获取模型
  190. model = globals().get(model_type)
  191. if model is None:
  192. start_time1 = time.time()
  193. log("=== init " + model_type + " model ===")
  194. model = OtrModels().get_model()
  195. globals().update({model_type: model})
  196. log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
  197. # 运行
  198. _time = time.time()
  199. with lock:
  200. pred = model.predict(inputs)
  201. pred = pred[0]
  202. _shape = pred.shape
  203. _dtype = str(pred.dtype)
  204. log("pred.shape " + str(pred.shape))
  205. gpu_time = round(float(time.time()-_time), 2)
  206. # 判断前一个读取完
  207. _time = time.time()
  208. while True:
  209. shm = globals().get("shm")
  210. if shm is None:
  211. break
  212. last_shape = globals().get("last_shape")
  213. sm_data = np.ndarray(last_shape, dtype=sm_dtype, buffer=shm.buf)
  214. if (sm_data == np.zeros(last_shape)).all():
  215. try:
  216. _time1 = time.time()
  217. shm.close()
  218. shm.unlink()
  219. log("release share memory " + str(time.time()-_time1))
  220. except FileNotFoundError:
  221. log("share memory " + shm.name + " not exists!")
  222. break
  223. log("wait for share memory being read " + str(time.time()-_time))
  224. # 数据放入共享内存
  225. _time = time.time()
  226. shm = to_share_memory(pred)
  227. globals().update({"shm": shm})
  228. globals().update({"last_shape": _shape})
  229. log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
  230. finish_time = round(float(time.time()-start_time), 2)
  231. log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
  232. return json.dumps({"gpu_time": gpu_time, "elapse": finish_time,
  233. "sm_name": shm.name, "sm_shape": _shape, "sm_dtype": _dtype})
  234. except Exception as e:
  235. finish_time = round(float(time.time()-start_time), 2)
  236. traceback.print_exc()
  237. return json.dumps({"gpu_time": 0., "elapse": finish_time,
  238. "sm_name": None, "sm_shape": None, "sm_dtype": None})
  239. def _otr():
  240. start_time = time.time()
  241. log("into _ocr")
  242. _global._init()
  243. _global.update({"port": globals().get("port")})
  244. log("into _ocr -> _global " + str(time.time()-start_time))
  245. start_time = time.time()
  246. try:
  247. # 循环判断是否有新数据需处理
  248. while True:
  249. try:
  250. full_sm_list = get_share_memory_list(sm_list_name="sml_otr_"+str(globals().get("port")))
  251. except FileNotFoundError:
  252. full_sm_list = get_share_memory_list(sm_list_name="sml_otr_"+str(globals().get("port")), list_size=10)
  253. try:
  254. if full_sm_list[0] == "1" and full_sm_list[-1] == "1":
  255. log("empty_sm_list[0] " + full_sm_list[0])
  256. log("empty_sm_list[-1] " + full_sm_list[-1])
  257. log("empty_sm_list[1] " + full_sm_list[1])
  258. log("wait for " + str(time.time()-start_time))
  259. break
  260. except ValueError:
  261. continue
  262. start_time = time.time()
  263. _time = time.time()
  264. _md5 = full_sm_list[1]
  265. model_type = full_sm_list[2]
  266. sm_name = full_sm_list[5]
  267. sm_shape = full_sm_list[6]
  268. sm_shape = eval(sm_shape)
  269. sm_dtype = full_sm_list[7]
  270. sm_dtype = get_np_type(sm_dtype)
  271. _global.update({"md5": _md5})
  272. log("read data " + str(time.time()-_time))
  273. # 读取共享内存
  274. _time = time.time()
  275. if sm_name:
  276. inputs = from_share_memory(sm_name, sm_shape, sm_dtype)
  277. else:
  278. log("from_share_memory failed!")
  279. raise Exception
  280. log("data from share memory " + sm_name + " " + str(time.time()-_time))
  281. # 获取模型
  282. model = globals().get(model_type)
  283. if model is None:
  284. start_time1 = time.time()
  285. log("=== init " + model_type + " model ===")
  286. model = OtrModels().get_model()
  287. globals().update({model_type: model})
  288. log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
  289. # 运行
  290. _time = time.time()
  291. with lock:
  292. pred = model.predict(inputs)
  293. preds = pred[0]
  294. log("preds.shape " + str(preds.shape))
  295. gpu_time = round(float(time.time()-_time), 2)
  296. # 数据放入共享内存
  297. _time = time.time()
  298. # 先释放之前的同名share memory
  299. release_share_memory(get_share_memory(sm_name))
  300. # 写入共享内存
  301. shm = to_share_memory(preds, sm_name)
  302. full_sm_list[5] = shm.name
  303. full_sm_list[6] = str(preds.shape)
  304. full_sm_list[7] = str(preds.dtype)
  305. full_sm_list[8] = str(gpu_time)
  306. full_sm_list[-1] = "0"
  307. log("data into share memory " + str(shm.name) + " " + str(time.time()-_time))
  308. close_share_memory_list(full_sm_list)
  309. finish_time = round(float(time.time()-start_time), 2)
  310. log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
  311. except Exception as e:
  312. finish_time = round(float(time.time()-start_time), 2)
  313. traceback.print_exc()
  314. raise
  315. class OtrModels:
  316. def __init__(self):
  317. # python文件所在目录
  318. _dir = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
  319. model_path = _dir + "/models/table-line.h5"
  320. self.otr_model = table_net((None, None, 3), 2)
  321. self.otr_model.load_weights(model_path)
  322. def get_model(self):
  323. return self.otr_model
  324. if __name__ == '__main__':
  325. if len(sys.argv) == 2:
  326. port = int(sys.argv[1])
  327. using_gpu_index = 0
  328. elif len(sys.argv) == 3:
  329. port = int(sys.argv[1])
  330. using_gpu_index = int(sys.argv[2])
  331. else:
  332. port = 18000
  333. using_gpu_index = 0
  334. # _global._init()
  335. # _global.update({"port": str(port)})
  336. globals().update({"port": str(port)})
  337. # app.run(host='0.0.0.0', port=port, processes=1, threaded=False, debug=False)
  338. # app.run()
  339. # log("OTR running "+str(port))
  340. while True:
  341. _otr()