run_extract_server.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Fri Jun 1 18:03:03 2018
  4. @author: DONG
  5. """
  6. import sys
  7. import os
  8. from flask import Flask, jsonify
  9. from flask import abort
  10. from flask import request
  11. sys.path.append(os.path.dirname(__file__)+"/..")
  12. os.environ["KERAS_BACKEND"] = "tensorflow"
  13. app = Flask(__name__)
  14. app.config['JSON_AS_ASCII'] = False
  15. limit_num = "4"
  16. os.environ["OMP_NUM_THREADS"] = limit_num # 1为一个核,设置为5的时候,系统显示用了10个核,不太清楚之间的具体数量关系
  17. os.environ["OMP_NUM_THREADS"] = limit_num # export OMP_NUM_THREADS=1
  18. os.environ["OPENBLAS_NUM_THREADS"] = limit_num # export OPENBLAS_NUM_THREADS=1
  19. os.environ["MKL_NUM_THREADS"] = limit_num # export MKL_NUM_THREADS=1
  20. os.environ["VECLIB_MAXIMUM_THREADS"] = limit_num # export VECLIB_MAXIMUM_THREADS=1
  21. os.environ["NUMEXPR_NUM_THREADS"] = limit_num # export NUMEXPR_NUM_THREADS=1
  22. import time
  23. import uuid
  24. import numpy as np
  25. import ctypes
  26. import inspect
  27. from threading import Thread
  28. import traceback
  29. import json
  30. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  31. os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  32. sys.path.append(os.path.abspath("."))
  33. #自定义jsonEncoder
  34. class MyEncoder(json.JSONEncoder):
  35. def default(self, obj):
  36. if isinstance(obj, np.ndarray):
  37. return obj.tolist()
  38. elif isinstance(obj, bytes):
  39. return str(obj, encoding='utf-8')
  40. elif isinstance(obj, (np.float_, np.float16, np.float32,
  41. np.float64)):
  42. return float(obj)
  43. return json.JSONEncoder.default(self, obj)
  44. def _async_raise(tid, exctype):
  45. """raises the exception, performs cleanup if needed"""
  46. tid = ctypes.c_long(tid)
  47. if not inspect.isclass(exctype):
  48. exctype = type(exctype)
  49. res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
  50. if res == 0:
  51. raise ValueError("invalid thread id")
  52. elif res != 1:
  53. ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
  54. raise SystemError("PyThreadState_SetAsyncExc failed")
  55. def stop_thread(thread):
  56. _async_raise(thread.ident, SystemExit)
  57. def run_thread(data,list_result):
  58. # data = data.decode("utf8")
  59. # data = json.loads(data,encoding="utf8")
  60. k = str(uuid.uuid4())
  61. cost_time = dict()
  62. _doc_id = data.get("doc_id","")
  63. _title = data.get("title","")
  64. _content = data.get("content","")
  65. _page_time = data.get("page_time","")
  66. data_res = ""
  67. web_source_no = data.get("web_source_no","")
  68. web_source_name = data.get("web_source_name","")
  69. original_docchannel = data.get("original_docchannel","")
  70. page_attachments = data.get("page_attachments","")
  71. # print("web_source_name:",web_source_name)
  72. is_fail = False
  73. try:
  74. if _content!="":
  75. data_res = predict(_doc_id,_content,_title,_page_time,web_source_no=web_source_no,web_source_name=web_source_name,original_docchannel=original_docchannel,page_attachments=page_attachments)
  76. else:
  77. data_res = json.dumps({"success":False,"msg":"content not passed"})
  78. # is_fail = True
  79. except Exception as e:
  80. traceback.print_exc()
  81. data_res = json.dumps({"success":False,"msg":str(e)})
  82. is_fail = True
  83. # 以json形式返回结果
  84. #_resp = json.dumps(data_res,cls=MyEncoder)
  85. #log(str(data["flag"])+str(data))
  86. # log("done for doc_id:%s with result:%s"%(_doc_id,str(data_res)))
  87. list_result.append(data_res)
  88. if is_fail:
  89. list_result.append(is_fail)
  90. @app.route("/test",methods=['POST'])
  91. def test():
  92. from BiddingKG.dl.common.Utils import log
  93. from BiddingKG.dl.interface.extract import predict
  94. global predict,log
  95. _time = time.time()
  96. a = request.form.get("content")
  97. log("get form takes %.2fs"%(time.time()-_time))
  98. return json.dumps(sys.getsizeof(request.form)),201
  99. @app.route('/content_extract', methods=['POST'])
  100. def text_predict():
  101. from BiddingKG.dl.common.Utils import log
  102. from BiddingKG.dl.interface.extract import predict
  103. global predict,log
  104. _time = time.time()
  105. data = request.json
  106. status_code = 200
  107. list_result = []
  108. _timeout = data.get("timeout",400)
  109. log("get data cost:%.2fs"%((time.time()-_time)))
  110. t = Thread(target=run_thread,args=(data,list_result))
  111. start_time = time.time()
  112. t.start()
  113. t.join(_timeout)
  114. if t.is_alive():
  115. stop_thread(t)
  116. status_code = 302#超时被kill
  117. data_res = json.dumps({"success":False,"msg":"timeout"})
  118. else:
  119. # status_code += int((time.time()-start_time)%10+1)
  120. status_code = 201
  121. data_res = list_result[0]
  122. if len(list_result)>1 and list_result[1] ==True:
  123. status_code = 500
  124. _resp = data_res
  125. # _resp = predict(doc_id=_doc_id,text=_content,title=_title,page_time=_page_time)
  126. return _resp,status_code
  127. def getPort(argv):
  128. port = 15030
  129. print(argv)
  130. for item in argv:
  131. _l = str(item).split("port=")
  132. if len(_l)>1:
  133. port = int(_l[-1])
  134. break
  135. return port
  136. def getWorkers(argv):
  137. worker = 15
  138. for item in argv:
  139. _l = str(item).split("worker=")
  140. if len(_l)>1:
  141. worker = int(_l[-1])
  142. break
  143. return worker
  144. def start_with_tornado(port,process_num):
  145. from tornado.wsgi import WSGIContainer
  146. from tornado.httpserver import HTTPServer
  147. from tornado.ioloop import IOLoop
  148. print("import ")
  149. http_server = HTTPServer(WSGIContainer(app))
  150. # http_server.listen(port) #shortcut for bind and start
  151. http_server.bind(port)
  152. http_server.start(process_num)
  153. IOLoop.instance().start()
  154. def start_with_flask():
  155. port = getPort(argv=sys.argv)
  156. app.run(host='0.0.0.0', port=port, threaded=True, debug=False)
  157. log("ContentExtractor running")
  158. # app.run()
  159. if __name__ == '__main__':
  160. port = getPort(argv=sys.argv)
  161. workers = getWorkers(argv=sys.argv)
  162. start_with_tornado(port,workers)
  163. pass