# app.py — allspark prediction service entry point
'''
Created on 2019-12-03
@author: User
'''
  5. import allspark
  6. import sys
  7. import os
  8. sys.path.append(os.path.dirname(__file__)+"/..")
  9. os.environ["KERAS_BACKEND"] = "tensorflow"
  10. import json
  11. import re
  12. import time
  13. import uuid
  14. from BiddingKG.dl.common.Utils import log
  15. from BiddingKG.dl.interface.extract import predict
  16. import numpy as np
  17. import ctypes
  18. import inspect
  19. from threading import Thread
  20. import traceback
  21. # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  22. # os.environ["CUDA_VISIBLE_DEVICES"] = ""
  23. sys.path.append(os.path.abspath("."))
  24. #自定义jsonEncoder
  25. class MyEncoder(json.JSONEncoder):
  26. def default(self, obj):
  27. if isinstance(obj, np.ndarray):
  28. return obj.tolist()
  29. elif isinstance(obj, bytes):
  30. return str(obj, encoding='utf-8')
  31. elif isinstance(obj, (np.float_, np.float16, np.float32,
  32. np.float64)):
  33. return float(obj)
  34. return json.JSONEncoder.default(self, obj)
  35. def _async_raise(tid, exctype):
  36. """raises the exception, performs cleanup if needed"""
  37. tid = ctypes.c_long(tid)
  38. if not inspect.isclass(exctype):
  39. exctype = type(exctype)
  40. res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
  41. if res == 0:
  42. raise ValueError("invalid thread id")
  43. elif res != 1:
  44. ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
  45. raise SystemError("PyThreadState_SetAsyncExc failed")
  46. def stop_thread(thread):
  47. _async_raise(thread.ident, SystemExit)
  48. class MyProcessor(allspark.BaseProcessor):
  49. """ MyProcessor is a example
  50. you can send mesage like this to predict
  51. curl -v http://127.0.0.1:8080/api/predict/service_name -d '2 105'
  52. """
  53. def run_thread(self,data,list_result):
  54. # data = data.decode("utf8")
  55. # data = json.loads(data,encoding="utf8")
  56. k = str(uuid.uuid4())
  57. cost_time = dict()
  58. _doc_id = data.get("doc_id","")
  59. _title = data.get("title","")
  60. _content = data.get("content","")
  61. _page_time = data.get("page_time","")
  62. data_res = ""
  63. web_source_no = data.get("web_source_no","")
  64. original_docchannel = data.get("original_docchannel","")
  65. try:
  66. if "content" in data:
  67. data_res = predict(_doc_id,_content,_title,_page_time,web_source_no,original_docchannel)
  68. else:
  69. data_res = json.dumps({"success":False,"msg":"content not passed"})
  70. except Exception as e:
  71. traceback.print_exc()
  72. data_res = json.dumps({"success":False,"msg":str(e)})
  73. # 以json形式返回结果
  74. #_resp = json.dumps(data_res,cls=MyEncoder)
  75. #log(str(data["flag"])+str(data))
  76. log("done for doc_id:%s with result:%s"%(_doc_id,str(data_res)))
  77. list_result.append(data_res)
  78. def initialize(self):
  79. """ load module, executed once at the start of the service
  80. do service intialization and load models in this function.
  81. """'''
  82. '''
  83. self.timeout = 60
  84. self.status_types = 5
  85. self.timeOfType = self.timeout//self.status_types
  86. def pre_proccess(self, data):
  87. """ data format pre process
  88. """
  89. x, y = data.split(b' ')
  90. return int(x), int(y)
  91. def post_process(self, data):
  92. """ proccess after process
  93. """
  94. return bytes(data, encoding='utf8')
  95. def process(self, data):
  96. """ process the request data
  97. """
  98. data = data.decode("utf8")
  99. data = json.loads(data,encoding="utf8")
  100. status_code = 200
  101. list_result = []
  102. _timeout = data.get("timeout",self.timeout)
  103. t = Thread(target=self.run_thread,args=(data,list_result))
  104. start_time = time.time()
  105. t.start()
  106. t.join(_timeout)
  107. if t.is_alive():
  108. stop_thread(t)
  109. status_code = 302#超时被kill
  110. data_res = json.dumps({"success":False,"msg":"timeout"})
  111. else:
  112. status_code += int((time.time()-start_time)%self.timeOfType+1)
  113. data_res = list_result[0]
  114. _resp = data_res
  115. # _resp = predict(doc_id=_doc_id,text=_content,title=_title,page_time=_page_time)
  116. return self.post_process(_resp),status_code
  117. def getPort(argv):
  118. port = 15030
  119. for item in argv:
  120. _l = str(item).split("port=")
  121. if len(_l)>1:
  122. port = int(_l[-1])
  123. break
  124. return port
if __name__ == '__main__':
    # parameter worker_threads indicates concurrency of processing
    # Local run: bind an explicit endpoint on the port parsed from argv.
    port = getPort(argv=sys.argv)
    allspark.default_properties().put("rpc.keepalive", 250000)
    allspark.default_properties().put("rpc.max_queue_size", 100)
    log("port==%d"%(port))
    runner = MyProcessor(worker_threads=5,worker_processes=1,endpoint="0.0.0.0:%d"%(port))
    # PAI-platform run: construct with no endpoint instead.
    # runner = MyProcessor()
    runner.run()