- # encoding=utf8
- import base64
- import io
- import json
- import os
- import pickle
- import threading
- import traceback
- # os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
- # os.environ['CUDA_VISIBLE_DEVICES'] = "0"
- import redis
- import tensorflow as tf
# Cap TensorFlow's GPU memory usage (2000 MB virtual device) so this process
# can share the GPU with others. Best effort: on any failure (no GPU, old TF,
# device already initialized) log the traceback and fall back to defaults.
try:
    gpus = tf.config.list_physical_devices('GPU')
    if len(gpus) > 0:
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2000)])
except Exception:
    # A bare `except:` would also trap SystemExit/KeyboardInterrupt; keep the
    # net no wider than Exception.
    traceback.print_exc()
- import sys
- sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../")
- import time
- import logging
- # from table_line import *
- import cv2
- import numpy as np
- from flask import Flask, request
- from format_convert.utils import request_post, judge_error_code, get_intranet_ip, log, get_md5_from_bytes, get_platform
- from otr.table_line import table_net, table_line, table_preprocess, table_postprocess
- from format_convert import _global
# Service setup.
app = Flask(__name__)

# Serializes access to the shared TF model in the Flask path (see _otr_flask);
# model.predict is not assumed re-entrant here.
lock = threading.RLock()

# Redis connection used as the work queue between producers and this consumer.
# Port is an int (redis-py convention), not the string '6379'.
# NOTE(review): host/port/password are hard-coded; move them to configuration
# or environment variables rather than committing credentials to source.
redis_db = redis.StrictRedis(host='127.0.0.1', port=6379,
                             db=1, password='bidi123456',
                             health_check_interval=300)
# @app.route('/otr', methods=['POST'])
def _otr_flask():
    """Flask handler (currently unrouted): run the OTR table-line model on a
    posted image batch.

    Expects form field "data" = base64(pickle({...})) with keys:
    inputs (npz-compressed ndarray), model_type, md5 (predictor_type is
    accepted but unused). Returns base64(pickle({"preds": <npz bytes>,
    "gpu_time": float, "elapse": float})); if the form is empty, returns
    JSON {"list_line": "[-9]"}. On any other error returns the same pickled
    payload with preds=None.
    """
    start_time = time.time()
    log("into _otr")
    _global._init()
    _global.update({"port": globals().get("port")})
    log("into _otr -> _global " + str(time.time()-start_time))
    start_time = time.time()
    try:
        if not request.form:
            log("otr no data!")
            return json.dumps({"list_line": str([-9])})
        log("judge request.form " + str(time.time()-start_time))
        start_time1 = time.time()
        # SECURITY: pickle.loads on request-supplied bytes executes arbitrary
        # code. Only safe if this endpoint is reachable by trusted callers;
        # consider a safer serialization (e.g. raw npz + json metadata).
        result = pickle.loads(base64.b64decode(request.form.get("data")))
        inputs = result.get("inputs")
        # Decompress the numpy payload (inverse of np.savez_compressed).
        decompressed_array = io.BytesIO(inputs)
        inputs = np.load(decompressed_array, allow_pickle=True)['arr_0']
        log("inputs.shape" + str(inputs.shape))
        model_type = result.get("model_type")
        _md5 = result.get("md5")
        _global.update({"md5": _md5})
        log("read data " + str(time.time()-start_time1))
        # Lazily build the model once per process and cache it in globals()
        # under its model_type key.
        model = globals().get(model_type)
        if model is None:
            start_time1 = time.time()
            log("=== init " + model_type + " model ===")
            model = OtrModels().get_model()
            globals().update({model_type: model})
            log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
        # Predict under the lock: the model instance is shared across requests.
        with lock:
            start_time1 = time.time()
            pred = model.predict(inputs)
            pred = pred[0]
            log("pred.shape " + str(pred.shape))
            # Compress the prediction before shipping it back.
            compressed_array = io.BytesIO()
            np.savez_compressed(compressed_array, pred)
            compressed_array.seek(0)
            pred = compressed_array.read()
            gpu_time = round(float(time.time()-start_time1), 2)
        finish_time = round(float(time.time()-start_time), 2)
        log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
        return base64.b64encode(pickle.dumps({"preds": pred, "gpu_time": gpu_time, "elapse": finish_time}))
    except Exception:
        # Never let the handler raise: report failure in-band with preds=None.
        finish_time = round(float(time.time()-start_time), 2)
        traceback.print_exc()
        return base64.b64encode(pickle.dumps({"preds": None, "gpu_time": 0., "elapse": finish_time}))
def _otr():
    """Blocking worker loop: pop pickled requests from the redis list
    "producer_otr", run the OTR table-line model, and write the pickled
    result into the hash "consumer_otr" keyed by the request's uuid.

    Each request dict carries: inputs (ndarray), uuid, model_type, md5
    (predictor_type is accepted but unused). Runs forever; per-request
    failures are logged and the loop continues.
    """
    start_time = time.time()
    log("into _otr")
    _global._init()
    _global.update({"port": globals().get("port")})
    log("into _otr -> _global " + str(time.time()-start_time))
    while True:
        start_time = time.time()
        try:
            if redis_db.llen("producer_otr") == 0:
                # Back off briefly instead of busy-spinning on an empty queue
                # (a bare `continue` here pegs a CPU core).
                time.sleep(0.01)
                continue
            log("judge llen " + str(time.time()-start_time))
            _time = time.time()
            result = redis_db.lpop("producer_otr")
            if result is None:
                # Another consumer may have drained the queue between
                # llen() and lpop().
                continue
            # SECURITY: pickle.loads executes arbitrary code; only safe while
            # this redis instance is written exclusively by trusted producers.
            result = pickle.loads(result)
            log("from producer_otr time " + str(time.time() - _time))
            _time = time.time()
            inputs = result.get("inputs")
            _uuid = result.get("uuid")
            model_type = result.get("model_type")
            _md5 = result.get("md5")
            _global.update({"md5": _md5})
            log("read data " + str(time.time()-_time))
            # Lazily build the model once per process and cache it in
            # globals() under its model_type key.
            model = globals().get(model_type)
            if model is None:
                start_time1 = time.time()
                log("=== init " + model_type + " model ===")
                model = OtrModels().get_model()
                globals().update({model_type: model})
                log("=== init " + model_type + " model " + str(round(time.time()-start_time1, 2)) + " ===")
            # Run inference and publish the (uncompressed) prediction.
            start_time1 = time.time()
            pred = model.predict(inputs)
            pred = pred[0]
            log("pred.shape " + str(pred.shape))
            gpu_time = round(float(time.time()-start_time1), 2)
            finish_time = round(float(time.time()-start_time), 2)
            redis_db.hset("consumer_otr", _uuid, pickle.dumps({"preds": pred, "gpu_time": gpu_time, "elapse": finish_time}))
            log("to consumer_otr " + str(time.time()-_time))
            log("otr model predict time " + str(gpu_time) + " " + str(finish_time))
        except Exception:
            # Keep the worker alive on any per-request failure.
            traceback.print_exc()
class OtrModels:
    """Loads the OTR table-line Keras model from disk and hands it out."""

    def __init__(self):
        # Resolve the weights file relative to this source file
        # (dirname already yields an absolute path after abspath(__file__);
        # the original double abspath was redundant).
        _dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(_dir, "models", "table-line.h5")
        # (None, None, 3): variable-size 3-channel input; the second argument
        # is presumably the output class count — confirm against table_net.
        self.otr_model = table_net((None, None, 3), 2)
        self.otr_model.load_weights(model_path)

    def get_model(self):
        """Return the loaded table-line model instance."""
        return self.otr_model
if __name__ == '__main__':
    # Optional CLI arguments: [port] [gpu_index]; defaults are 18000 and 0.
    # `port` is read back by the worker via globals().get("port").
    if len(sys.argv) == 2:
        port = int(sys.argv[1])
        using_gpu_index = 0
    elif len(sys.argv) == 3:
        port = int(sys.argv[1])
        using_gpu_index = int(sys.argv[2])
    else:
        port = 18000
        using_gpu_index = 0
    # The Flask entry point (_otr_flask / app.run) is intentionally disabled;
    # this process runs as a redis queue consumer instead.
    _otr()