run_extract_server.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Fri Jun 1 18:03:03 2018
  4. @author: DONG
  5. """
  6. import sys
  7. import os
  8. from flask import Flask, jsonify
  9. from flask import abort
  10. from flask import request
  11. sys.path.append(os.path.dirname(__file__)+"/..")
  12. os.environ["KERAS_BACKEND"] = "tensorflow"
  13. app = Flask(__name__)
  14. app.config['JSON_AS_ASCII'] = False
  15. import time
  16. import uuid
  17. from BiddingKG.dl.common.Utils import log
  18. from BiddingKG.dl.interface.extract import predict
  19. import numpy as np
  20. import ctypes
  21. import inspect
  22. from threading import Thread
  23. import traceback
  24. import json
  25. # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  26. # os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  27. sys.path.append(os.path.abspath("."))
  28. #自定义jsonEncoder
  29. class MyEncoder(json.JSONEncoder):
  30. def default(self, obj):
  31. if isinstance(obj, np.ndarray):
  32. return obj.tolist()
  33. elif isinstance(obj, bytes):
  34. return str(obj, encoding='utf-8')
  35. elif isinstance(obj, (np.float_, np.float16, np.float32,
  36. np.float64)):
  37. return float(obj)
  38. return json.JSONEncoder.default(self, obj)
  39. def _async_raise(tid, exctype):
  40. """raises the exception, performs cleanup if needed"""
  41. tid = ctypes.c_long(tid)
  42. if not inspect.isclass(exctype):
  43. exctype = type(exctype)
  44. res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
  45. if res == 0:
  46. raise ValueError("invalid thread id")
  47. elif res != 1:
  48. ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
  49. raise SystemError("PyThreadState_SetAsyncExc failed")
  50. def stop_thread(thread):
  51. _async_raise(thread.ident, SystemExit)
  52. def run_thread(data,list_result):
  53. # data = data.decode("utf8")
  54. # data = json.loads(data,encoding="utf8")
  55. k = str(uuid.uuid4())
  56. cost_time = dict()
  57. _doc_id = data.get("doc_id","")
  58. _title = data.get("title","")
  59. _content = data.get("content","")
  60. _page_time = data.get("page_time","")
  61. data_res = ""
  62. web_source_no = data.get("web_source_no","")
  63. original_docchannel = data.get("original_docchannel","")
  64. try:
  65. if "content" in data:
  66. data_res = predict(_doc_id,_content,_title,_page_time,web_source_no,original_docchannel)
  67. else:
  68. data_res = json.dumps({"success":False,"msg":"content not passed"})
  69. except Exception as e:
  70. traceback.print_exc()
  71. data_res = json.dumps({"success":False,"msg":str(e)})
  72. # 以json形式返回结果
  73. #_resp = json.dumps(data_res,cls=MyEncoder)
  74. #log(str(data["flag"])+str(data))
  75. log("done for doc_id:%s with result:%s"%(_doc_id,str(data_res)))
  76. list_result.append(data_res)
  77. @app.route('/content_extract', methods=['POST'])
  78. def text_predict():
  79. data = request.json
  80. status_code = 200
  81. list_result = []
  82. _timeout = data.get("timeout",400)
  83. t = Thread(target=run_thread,args=(data,list_result))
  84. start_time = time.time()
  85. t.start()
  86. t.join(_timeout)
  87. if t.is_alive():
  88. stop_thread(t)
  89. status_code = 302#超时被kill
  90. data_res = json.dumps({"success":False,"msg":"timeout"})
  91. else:
  92. # status_code += int((time.time()-start_time)%10+1)
  93. status_code = 201
  94. data_res = list_result[0]
  95. _resp = data_res
  96. # _resp = predict(doc_id=_doc_id,text=_content,title=_title,page_time=_page_time)
  97. return _resp,status_code
  98. def getPort(argv):
  99. port = 15030
  100. for item in argv:
  101. _l = str(item).split("port=")
  102. if len(_l)>1:
  103. port = int(_l[-1])
  104. break
  105. return port
  106. if __name__ == '__main__':
  107. port = getPort(argv=sys.argv)
  108. app.run(host='0.0.0.0', port=port, threaded=True, debug=False)
  109. log("ContentExtractor running")
  110. # app.run()