# run_extract_server.py (4.7 KB)
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Fri Jun 1 18:03:03 2018
  4. @author: DONG
  5. """
  6. import sys
  7. import os
  8. from flask import Flask, jsonify
  9. from flask import abort
  10. from flask import request
  11. sys.path.append(os.path.dirname(__file__)+"/..")
  12. os.environ["KERAS_BACKEND"] = "tensorflow"
  13. app = Flask(__name__)
  14. app.config['JSON_AS_ASCII'] = False
  15. limit_num = "4"
  16. os.environ["OMP_NUM_THREADS"] = limit_num # 1为一个核,设置为5的时候,系统显示用了10个核,不太清楚之间的具体数量关系
  17. os.environ["OMP_NUM_THREADS"] = limit_num # export OMP_NUM_THREADS=1
  18. os.environ["OPENBLAS_NUM_THREADS"] = limit_num # export OPENBLAS_NUM_THREADS=1
  19. os.environ["MKL_NUM_THREADS"] = limit_num # export MKL_NUM_THREADS=1
  20. os.environ["VECLIB_MAXIMUM_THREADS"] = limit_num # export VECLIB_MAXIMUM_THREADS=1
  21. os.environ["NUMEXPR_NUM_THREADS"] = limit_num # export NUMEXPR_NUM_THREADS=1
  22. import time
  23. import uuid
  24. from BiddingKG.dl.common.Utils import log
  25. from BiddingKG.dl.interface.extract import predict
  26. import numpy as np
  27. import ctypes
  28. import inspect
  29. from threading import Thread
  30. import traceback
  31. import json
  32. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  33. os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  34. sys.path.append(os.path.abspath("."))
  35. #自定义jsonEncoder
  36. class MyEncoder(json.JSONEncoder):
  37. def default(self, obj):
  38. if isinstance(obj, np.ndarray):
  39. return obj.tolist()
  40. elif isinstance(obj, bytes):
  41. return str(obj, encoding='utf-8')
  42. elif isinstance(obj, (np.float_, np.float16, np.float32,
  43. np.float64)):
  44. return float(obj)
  45. return json.JSONEncoder.default(self, obj)
  46. def _async_raise(tid, exctype):
  47. """raises the exception, performs cleanup if needed"""
  48. tid = ctypes.c_long(tid)
  49. if not inspect.isclass(exctype):
  50. exctype = type(exctype)
  51. res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
  52. if res == 0:
  53. raise ValueError("invalid thread id")
  54. elif res != 1:
  55. ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
  56. raise SystemError("PyThreadState_SetAsyncExc failed")
def stop_thread(thread):
    """Force *thread* to terminate by injecting SystemExit into it.

    NOTE(review): this relies on PyThreadState_SetAsyncExc, which only takes
    effect when the thread next executes Python bytecode — a thread blocked
    inside a C extension call may not stop immediately.
    """
    _async_raise(thread.ident, SystemExit)
  59. def run_thread(data,list_result):
  60. # data = data.decode("utf8")
  61. # data = json.loads(data,encoding="utf8")
  62. k = str(uuid.uuid4())
  63. cost_time = dict()
  64. _doc_id = data.get("doc_id","")
  65. _title = data.get("title","")
  66. _content = data.get("content","")
  67. _page_time = data.get("page_time","")
  68. data_res = ""
  69. web_source_no = data.get("web_source_no","")
  70. original_docchannel = data.get("original_docchannel","")
  71. is_fail = False
  72. try:
  73. if _content!="":
  74. data_res = predict(_doc_id,_content,_title,_page_time,web_source_no,original_docchannel)
  75. else:
  76. data_res = json.dumps({"success":False,"msg":"content not passed"})
  77. # is_fail = True
  78. except Exception as e:
  79. traceback.print_exc()
  80. data_res = json.dumps({"success":False,"msg":str(e)})
  81. is_fail = True
  82. # 以json形式返回结果
  83. #_resp = json.dumps(data_res,cls=MyEncoder)
  84. #log(str(data["flag"])+str(data))
  85. log("done for doc_id:%s with result:%s"%(_doc_id,str(data_res)))
  86. list_result.append(data_res)
  87. if is_fail:
  88. list_result.append(is_fail)
  89. @app.route("/test",methods=['POST'])
  90. def test():
  91. _time = time.time()
  92. a = request.form.get("content")
  93. log("get form takes %.2fs"%(time.time()-_time))
  94. return json.dumps(sys.getsizeof(request.form)),201
  95. @app.route('/content_extract', methods=['POST'])
  96. def text_predict():
  97. _time = time.time()
  98. data = request.json
  99. status_code = 200
  100. list_result = []
  101. _timeout = data.get("timeout",400)
  102. log("get data cost:%.2fs"%((time.time()-_time)))
  103. t = Thread(target=run_thread,args=(data,list_result))
  104. start_time = time.time()
  105. t.start()
  106. t.join(_timeout)
  107. if t.is_alive():
  108. stop_thread(t)
  109. status_code = 302#超时被kill
  110. data_res = json.dumps({"success":False,"msg":"timeout"})
  111. else:
  112. # status_code += int((time.time()-start_time)%10+1)
  113. status_code = 201
  114. data_res = list_result[0]
  115. if len(list_result)>1 and list_result[1] ==True:
  116. status_code = 500
  117. _resp = data_res
  118. # _resp = predict(doc_id=_doc_id,text=_content,title=_title,page_time=_page_time)
  119. return _resp,status_code
  120. def getPort(argv):
  121. port = 15030
  122. for item in argv:
  123. _l = str(item).split("port=")
  124. if len(_l)>1:
  125. port = int(_l[-1])
  126. break
  127. return port
  128. if __name__ == '__main__':
  129. port = getPort(argv=sys.argv)
  130. app.run(host='0.0.0.0', port=port, threaded=True, debug=False)
  131. log("ContentExtractor running")
  132. # app.run()