app.py

'''
Created on 2019-12-03
@author: User
'''
import allspark
import sys
import os
# Select the Keras backend before any Keras-dependent module is imported.
os.environ["KERAS_BACKEND"] = "tensorflow"
import json
import re
import time
import uuid
from BiddingKG.dl.common.Utils import log
from BiddingKG.dl.interface.extract import predict
import numpy as np
import ctypes
import inspect
from threading import Thread
import traceback

# The legacy run_thread() path below also needs these pipeline modules. The
# import paths are an assumption inferred from the calls made on them
# (Preprocessing.get_preprocessed, getAttributes.getPREMs,
# entityLink.link_entitys); adjust them to the actual package layout.
try:
    from BiddingKG.dl.interface import Preprocessing, getAttributes
    from BiddingKG.dl.entityLink import entityLink
except ImportError:
    Preprocessing = getAttributes = entityLink = None  # legacy path unavailable

# Force CPU-only execution: hide all CUDA devices from TensorFlow.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
sys.path.append(os.path.abspath("."))
# Custom JSON encoder: makes numpy arrays, numpy floats and bytes serializable.
class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, bytes):
            return str(obj, encoding='utf-8')
        # Note: np.float_ was removed in NumPy 2.0; drop it there.
        elif isinstance(obj, (np.float_, np.float16, np.float32,
                              np.float64)):
            return float(obj)
        return json.JSONEncoder.default(self, obj)
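
# Usage sketch for MyEncoder (illustrative only; not called by the service):
# plain json.dumps would raise TypeError on numpy types.
#
#   payload = {"scores": np.array([0.1, 0.9]), "best": np.float32(0.9)}
#   json.dumps(payload, cls=MyEncoder, ensure_ascii=False)
#   # -> '{"scores": [0.1, 0.9], "best": 0.8999999761581421}'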
def _async_raise(tid, exctype):
    """Raise exctype asynchronously in the thread with id tid."""
    tid = ctypes.c_long(tid)
    if not inspect.isclass(exctype):
        exctype = type(exctype)
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("invalid thread id")
    elif res != 1:
        # More than one thread was affected: undo the call and report failure.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
        raise SystemError("PyThreadState_SetAsyncExc failed")

def stop_thread(thread):
    """Force-terminate a thread by raising SystemExit inside it."""
    _async_raise(thread.ident, SystemExit)
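
# Timeout pattern these helpers support (a sketch; process() below keeps a
# commented-out variant of the same idea):
#
#   t = Thread(target=work, args=(payload, results))
#   t.start()
#   t.join(timeout)
#   if t.is_alive():      # still running after the deadline
#       stop_thread(t)    # raise SystemExit inside the worker thread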
class MyProcessor(allspark.BaseProcessor):
    """ MyProcessor is an example processor.
    You can send a message like this to predict:
    curl -v http://127.0.0.1:8080/api/predict/service_name -d '{"doc_id":"1","title":"","content":"..."}'
    """
    def run_thread(self, data, list_result):
        # Legacy extraction pipeline, kept for the (commented-out) timeout
        # mechanism in process(). It also relies on predictor members
        # (self.codeNamePredict, self.premPredict, self.roleRulePredict,
        # self.epcPredict) that would have to be created in initialize()
        # before this path can run.
        k = str(uuid.uuid4())
        cost_time = dict()
        _doc_id = data.get("doc_id", "")
        _title = data.get("title", "")
        data_res = ""
        try:
            if "content" in data:
                log("get request of doc_id:%s" % _doc_id)
                content = data['content']
                # Split the article into sentences and candidate entities.
                start_time = time.time()
                list_articles, list_sentences, list_entitys, _cost_time = Preprocessing.get_preprocessed([[k, content, "", _doc_id, _title]], useselffool=True)
                log("get preprocessed done of doc_id:%s" % _doc_id)
                cost_time["preprocess"] = time.time() - start_time
                cost_time.update(_cost_time)
                # Project code/name extraction.
                start_time = time.time()
                codeName = self.codeNamePredict.predict(list_sentences, list_entitys=list_entitys)
                log("get codename done of doc_id:%s" % _doc_id)
                cost_time["codename"] = time.time() - start_time
                # Role and money label prediction (PREM).
                start_time = time.time()
                self.premPredict.predict(list_sentences, list_entitys)
                log("get prem done of doc_id:%s" % _doc_id)
                cost_time["prem"] = time.time() - start_time
                # Rule-based role labelling.
                start_time = time.time()
                self.roleRulePredict.predict(list_articles, list_sentences, list_entitys, codeName)
                cost_time["rule"] = time.time() - start_time
                # Contact person extraction.
                start_time = time.time()
                self.epcPredict.predict(list_sentences, list_entitys)
                log("get epc done of doc_id:%s" % _doc_id)
                cost_time["person"] = time.time() - start_time
                # Link duplicate entities and assemble the attribute dictionary.
                start_time = time.time()
                entityLink.link_entitys(list_entitys)
                prem = getAttributes.getPREMs(list_sentences, list_entitys, list_articles)
                log("get attributes done of doc_id:%s" % _doc_id)
                cost_time["attrs"] = time.time() - start_time
                # Merge code/name and attribute results into one response dict.
                data_res = Preprocessing.union_result(codeName, prem)[0][1]
                data_res["cost_time"] = cost_time
                data_res["success"] = True
            else:
                data_res = {"success": False, "msg": "content not passed"}
        except Exception as e:
            traceback.print_exc()
            data_res = {"success": False, "msg": str(e)}
        # Hand the result back through the shared list so a caller running
        # this in a Thread can collect it after join().
        log("done for doc_id:%s with result:%s" % (_doc_id, str(data_res)))
        list_result.append(data_res)
    def initialize(self):
        """ load module, executed once at the start of the service
        do service initialization and load models in this function.
        """
        # Per-request time budget; the commented-out code in process() maps
        # the elapsed fraction of this budget onto response status codes.
        self.timeout = 60
        self.status_types = 5
        self.timeOfType = self.timeout // self.status_types
    def pre_proccess(self, data):
        """ data format pre-process: parse a legacy space-separated body into two ints
        """
        x, y = data.split(b' ')
        return int(x), int(y)
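
    # Sketch: self.pre_proccess(b'2 105') -> (2, 105). Unused by process(),
    # which expects a JSON body instead.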
    def post_process(self, data):
        """ encode the response string to bytes before returning it
        """
        return bytes(data, encoding='utf8')
    def process(self, data):
        """ process the request data
        """
        data = data.decode("utf8")
        data = json.loads(data)
        _doc_id = data.get("doc_id", "")
        _title = data.get("title", "")
        _content = data.get("content", "")
        status_code = 200
        # if "timeout" in data:
        #     _timeout = data["timeout"]
        list_result = []
        # Commented-out variant that runs the legacy pipeline in a killable
        # thread so a slow document cannot block the worker:
        # t = Thread(target=self.run_thread, args=(data, list_result))
        # start_time = time.time()
        # t.start()
        # t.join(_timeout)
        # if t.is_alive():
        #     stop_thread(t)
        #     status_code = 302  # killed due to timeout
        #     data_res = {"success": False, "msg": "timeout"}
        # else:
        #     status_code += int((time.time() - start_time) // self.timeOfType + 1)
        #     data_res = list_result[0]
        # _resp = json.dumps(data_res, cls=MyEncoder)
        _resp = predict(doc_id=_doc_id, text=_content, title=_title)
        return self.post_process(_resp), status_code
if __name__ == '__main__':
    # The worker_threads parameter controls processing concurrency.
    # Run locally:
    allspark.default_properties().put("rpc.keepalive", 60000)
    runner = MyProcessor(worker_threads=5, worker_processes=1, endpoint="0.0.0.0:15030")
    # Run on the PAI platform:
    # runner = MyProcessor()
    runner.run()
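
# Example request against a local run (a sketch: the URL path follows the
# pattern in the class docstring; the JSON fields match what process() reads):
#
#   curl -v http://127.0.0.1:15030/api/predict/service_name \
#        -d '{"doc_id": "1", "title": "", "content": "..."}'
#
# The response body is the JSON string returned by extract.predict, encoded
# to bytes by post_process().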