'''
app.py
Created on 2019-12-03
@author: User
'''
import allspark
import sys
import os
import json
import logging
import time
import uuid
sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
import tensorflow as tf
from text_classifier_pai.main import Text_Classifier
import numpy as np
import ctypes
import inspect
from threading import Thread
import traceback
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
sys.path.append(os.path.abspath("."))


# custom JSON encoder: makes numpy and bytes values JSON-serializable
class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, bytes):
            return str(obj, encoding='utf-8')
        elif isinstance(obj, (np.float_, np.float16, np.float32,
                              np.float64)):
            return float(obj)
        return json.JSONEncoder.default(self, obj)
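

# A minimal sketch (not part of the service) of what MyEncoder buys us:
# numpy scalars/arrays and bytes all pass through json.dumps cleanly.
def _encoder_demo():
    sample = {"score": np.float32(0.5), "logits": np.array([1, 2]), "tag": b"news"}
    return json.dumps(sample, cls=MyEncoder, ensure_ascii=False)
    # returns '{"score": 0.5, "logits": [1, 2], "tag": "news"}'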


def _async_raise(tid, exctype):
    """raises the exception, performs cleanup if needed"""
    tid = ctypes.c_long(tid)
    if not inspect.isclass(exctype):
        exctype = type(exctype)
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("invalid thread id")
    elif res != 1:
        # the call affected more than one thread: undo it and bail out
        ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
        raise SystemError("PyThreadState_SetAsyncExc failed")


def stop_thread(thread):
    _async_raise(thread.ident, SystemExit)
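
# Caveat (CPython-specific): PyThreadState_SetAsyncExc only schedules the
# exception; it is raised the next time the target thread executes Python
# bytecode, so stop_thread() cannot interrupt a thread blocked inside a
# long-running C call (e.g. a native inference op). The timeout handling in
# MyProcessor.process() below is therefore best-effort.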


class MyProcessor(allspark.BaseProcessor):
    """ MyProcessor is an example processor.
    You can send a JSON message like this to get a prediction:
    curl -v http://127.0.0.1:15000/api/predict/service_name -d '{"content": "..."}'
    """
    def run_thread(self, data, list_result):
        # data = data.decode("utf8")
        # data = json.loads(data,encoding="utf8")
        # print('about to handle the request:')
        k = str(uuid.uuid4())  # request id (currently unused)
        cost_time = dict()
        if "doc_id" in data:
            _doc_id = data['doc_id']
        else:
            _doc_id = ""
        if "title" in data:
            _title = data["title"]
        else:
            _title = ""
        data_res = ""
        try:
            if "content" in data:
                # logging.info("get request of doc_id:%s"%(_doc_id))
                content = data['content']
                start_time = time.time()
                # print('about to preprocess, article content:', content[:20])
                process, ids = self.classifier.process([content])
                # logging.info("get preprocessed done of doc_id%s"%(_doc_id))
                # print('preprocessing done')
                cost_time["preprocess"] = time.time() - start_time
                # cost_time.update(_cost_time)
                start_time = time.time()
                # print('starting prediction')
                # with self.classifier.sess.graph.as_default():
                logits, ids = self.classifier.predict(process, ids)
                # print('prediction done')
                # logging.info("get predict done of doc_id%s"%(_doc_id))
                cost_time["predict"] = time.time() - start_time
                start_time = time.time()
                # print('about to extract the results')
                result = self.classifier.get_results(logits, ids)
                class_name = result[0][1]  # the predicted class name
                subclass, topclass = self.classifier.dic_label[class_name].split(',')  # look up the top-level and sub-level class names for this label
                # print('class returned successfully')
                # logging.info("get result done of doc_id%s"%(_doc_id))
                cost_time["result"] = time.time() - start_time
                data_res = {"class": topclass, "class_name": class_name, "subclass": subclass}
                data_res["success"] = True
                data_res["cost_time"] = cost_time
                # print(prem)
                # data_res = {'predict':result[0][1]}
                # data_res["cost_time"] = cost_time
                # data_res["success"] = True
                # return json.dumps(Preprocessing.union_result(codeName, prem)[0][1],cls=MyEncoder,sort_keys=True,indent=4,ensure_ascii=False)
            else:
                data_res = {"success": False, "msg": "content not passed"}
        except Exception as e:
            traceback.print_exc()
            data_res = {"success": False, "msg": str(e)}
            logging.error('Exception:%s' % str(e))
        # return the result as JSON
        # _resp = json.dumps(data_res,cls=MyEncoder)
        # print(str(data["flag"])+str(data))
        logging.info("done for doc_id:%s with result:%s" % (_doc_id, str(data_res)))
        list_result.append(data_res)
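
    # Illustrative shape of a successful entry appended to list_result
    # (timing values are examples, not real measurements):
    # {"class": "<topclass>", "class_name": "<label>", "subclass": "<subclass>",
    #  "success": true, "cost_time": {"preprocess": 0.01, "predict": 0.20, "result": 0.001}}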

    def initialize(self):
        """ load module, executed once at the start of the service
        do service initialization and load models in this function.
        """
        self.classifier = Text_Classifier()
        self.timeout = 60
        self.status_types = 5
        self.timeOfType = self.timeout // self.status_types
        logging.info('initialization finished, service port 15000')
        print('initialization finished, service port 15000')
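
    # With timeout=60 and status_types=5, timeOfType is 12, so process() below
    # maps the elapsed time to a status bucket: under 12s -> 201, 12-24s -> 202,
    # ..., 48-60s -> 205, and a request killed at the timeout -> 302.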

    def pre_proccess(self, data):
        """ data format pre-processing
        (template leftover; not used by process() below)
        """
        x, y = data.split(b' ')
        return int(x), int(y)

    def post_process(self, data):
        """ processing applied after process
        """
        return bytes(data, encoding='utf8')

    def process(self, data):
        """ process the request data
        """
        try:
            data = data.decode("utf8")
            data = json.loads(data)
            _timeout = self.timeout
            status_code = 200
            if "timeout" in data:
                _timeout = data["timeout"]
            list_result = []
            t = Thread(target=self.run_thread, args=(data, list_result))
            start_time = time.time()
            t.start()
            t.join(_timeout)
            if t.is_alive():
                stop_thread(t)
                status_code = 302  # killed because of timeout
                data_res = {"success": False, "msg": "timeout"}
            else:
                status_code += int((time.time() - start_time) // self.timeOfType + 1)
                data_res = list_result[0]
            _resp = json.dumps(data_res, cls=MyEncoder)
            return self.post_process(_resp), status_code
        except Exception as e:
            logging.error('Exception:%s' % str(e))
        return self.post_process(json.dumps({}, cls=MyEncoder)), 200
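

# Illustrative call, assuming a local instance on the endpoint configured in
# main() below (the predict path follows the convention shown in the class
# docstring above):
# curl http://127.0.0.1:15000/api/predict/service_name \
#      -d '{"content": "article text ...", "doc_id": "42", "timeout": 30}'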


def main():
    # create the root logger
    logger = logging.getLogger()
    # file handler (the log directory must exist before FileHandler opens the file)
    os.makedirs('log_dir', exist_ok=True)
    fh = logging.FileHandler('log_dir/esa_classifier_pai.log', encoding='utf-8')
    # console handler
    sh = logging.StreamHandler()
    # configure the output format
    formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s')
    fh.setFormatter(formatter)  # bind the format to both handlers
    sh.setFormatter(formatter)
    logger.addHandler(fh)  # register both handlers on the logger
    logger.addHandler(sh)
    logger.setLevel(logging.DEBUG)
    fh.setLevel(logging.DEBUG)
    sh.setLevel(logging.WARNING)
    allspark.default_properties().put("rpc.keepalive", 250000)
    allspark.default_properties().put("rpc.max_queue_size", 100)
    # for a local run use the arguments below; on Alibaba Cloud (PAI) construct without arguments
    runner = MyProcessor(worker_threads=20, worker_processes=1, endpoint="0.0.0.0:15000")
    # running on the PAI platform
    # runner = MyProcessor()
    runner.run()


if __name__ == '__main__':
    main()
# the parameter worker_threads indicates the processing concurrency
# local run
# tf.app.run()
# allspark.default_properties().put("rpc.keepalive", 60000)
# runner = MyProcessor(worker_threads=5,worker_processes=1,endpoint="0.0.0.0:15011")
#
#
# # running on the PAI platform
# # runner = MyProcessor()
# runner.run()
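
# Illustrative Python client (not part of the service); assumes the third-party
# `requests` package and a local instance started via main() above:
# import requests
# resp = requests.post("http://127.0.0.1:15000/api/predict/service_name",
#                      data=json.dumps({"content": "some article text"}))
# print(resp.status_code, resp.json())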