run_extract_server.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Fri Jun 1 18:03:03 2018
  4. @author: DONG
  5. """
  6. import sys
  7. import os
  8. from flask import Flask, jsonify
  9. from flask import abort
  10. from flask import request
  11. sys.path.append(os.path.dirname(__file__)+"/..")
  12. os.environ["KERAS_BACKEND"] = "tensorflow"
  13. app = Flask(__name__)
  14. app.config['JSON_AS_ASCII'] = False
  15. os.environ["OMP_NUM_THREADS"] = "4" # 1为一个核,设置为5的时候,系统显示用了10个核,不太清楚之间的具体数量关系
  16. import time
  17. import uuid
  18. from BiddingKG.dl.common.Utils import log
  19. from BiddingKG.dl.interface.extract import predict
  20. import numpy as np
  21. import ctypes
  22. import inspect
  23. from threading import Thread
  24. import traceback
  25. import json
  26. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  27. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  28. sys.path.append(os.path.abspath("."))
  29. #自定义jsonEncoder
  30. class MyEncoder(json.JSONEncoder):
  31. def default(self, obj):
  32. if isinstance(obj, np.ndarray):
  33. return obj.tolist()
  34. elif isinstance(obj, bytes):
  35. return str(obj, encoding='utf-8')
  36. elif isinstance(obj, (np.float_, np.float16, np.float32,
  37. np.float64)):
  38. return float(obj)
  39. return json.JSONEncoder.default(self, obj)
  40. def _async_raise(tid, exctype):
  41. """raises the exception, performs cleanup if needed"""
  42. tid = ctypes.c_long(tid)
  43. if not inspect.isclass(exctype):
  44. exctype = type(exctype)
  45. res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
  46. if res == 0:
  47. raise ValueError("invalid thread id")
  48. elif res != 1:
  49. ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
  50. raise SystemError("PyThreadState_SetAsyncExc failed")
  51. def stop_thread(thread):
  52. _async_raise(thread.ident, SystemExit)
  53. def run_thread(data,list_result):
  54. # data = data.decode("utf8")
  55. # data = json.loads(data,encoding="utf8")
  56. k = str(uuid.uuid4())
  57. cost_time = dict()
  58. _doc_id = data.get("doc_id","")
  59. _title = data.get("title","")
  60. _content = data.get("content","")
  61. _page_time = data.get("page_time","")
  62. data_res = ""
  63. web_source_no = data.get("web_source_no","")
  64. original_docchannel = data.get("original_docchannel","")
  65. is_fail = False
  66. try:
  67. if _content!="":
  68. data_res = predict(_doc_id,_content,_title,_page_time,web_source_no,original_docchannel)
  69. else:
  70. data_res = json.dumps({"success":False,"msg":"content not passed"})
  71. # is_fail = True
  72. except Exception as e:
  73. traceback.print_exc()
  74. data_res = json.dumps({"success":False,"msg":str(e)})
  75. is_fail = True
  76. # 以json形式返回结果
  77. #_resp = json.dumps(data_res,cls=MyEncoder)
  78. #log(str(data["flag"])+str(data))
  79. log("done for doc_id:%s with result:%s"%(_doc_id,str(data_res)))
  80. list_result.append(data_res)
  81. if is_fail:
  82. list_result.append(is_fail)
  83. @app.route("/test",methods=['POST'])
  84. def test():
  85. _time = time.time()
  86. a = request.form.get("content")
  87. log("get form takes %.2fs"%(time.time()-_time))
  88. return json.dumps(sys.getsizeof(request.form)),201
  89. @app.route('/content_extract', methods=['POST'])
  90. def text_predict():
  91. _time = time.time()
  92. data = request.json
  93. status_code = 200
  94. list_result = []
  95. _timeout = data.get("timeout",400)
  96. log("get data cost:%.2fs"%((time.time()-_time)))
  97. t = Thread(target=run_thread,args=(data,list_result))
  98. start_time = time.time()
  99. t.start()
  100. t.join(_timeout)
  101. if t.is_alive():
  102. stop_thread(t)
  103. status_code = 302#超时被kill
  104. data_res = json.dumps({"success":False,"msg":"timeout"})
  105. else:
  106. # status_code += int((time.time()-start_time)%10+1)
  107. status_code = 201
  108. data_res = list_result[0]
  109. if len(list_result)>1 and list_result[1] ==True:
  110. status_code = 500
  111. _resp = data_res
  112. # _resp = predict(doc_id=_doc_id,text=_content,title=_title,page_time=_page_time)
  113. return _resp,status_code
  114. def getPort(argv):
  115. port = 15030
  116. for item in argv:
  117. _l = str(item).split("port=")
  118. if len(_l)>1:
  119. port = int(_l[-1])
  120. break
  121. return port
  122. if __name__ == '__main__':
  123. port = getPort(argv=sys.argv)
  124. app.run(host='0.0.0.0', port=port, threaded=True, debug=False)
  125. log("ContentExtractor running")
  126. # app.run()