123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- '''
- Created on 2019年12月3日
- @author: User
- '''
- import allspark
- import sys
- import os
- import json
- import logging
- import time
- import uuid
- sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
- import tensorflow as tf
- from text_classifier_pai.main import Text_Classifier
- import numpy as np
- import ctypes
- import inspect
- from threading import Thread
- import traceback
- # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- # os.environ["CUDA_VISIBLE_DEVICES"] = ""
- sys.path.append(os.path.abspath("."))
# Custom JSON encoder
class MyEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy values and UTF-8 encoded bytes."""

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Conversions performed:
            * ``np.ndarray``  -> nested Python list (``tolist()``)
            * ``bytes``       -> ``str`` decoded as UTF-8
            * NumPy floating  -> ``float``
            * NumPy integer   -> ``int``
            * ``np.bool_``    -> ``bool``

        Any other type is delegated to ``json.JSONEncoder.default``, which
        raises ``TypeError``.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, bytes):
            return str(obj, encoding='utf-8')
        # np.floating is the common base of float16/32/64; the former
        # ``np.float_`` alias no longer exists in NumPy 2.0, so referencing
        # it here would raise AttributeError.
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super().default(obj)
def _async_raise(tid, exctype):
    """Request that ``exctype`` be raised inside the thread identified by ``tid``.

    Args:
        tid: thread identifier (e.g. ``Thread.ident``).
        exctype: exception class to raise; if an instance is given, its
            class is used instead.

    Raises:
        ValueError: the interpreter call reported 0, i.e. no thread matched.
        SystemError: the interpreter call reported something other than 0
            or 1; the pending request is cancelled before raising.
    """
    thread_id = ctypes.c_long(tid)
    exc_class = exctype if inspect.isclass(exctype) else type(exctype)
    set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc

    modified = set_async_exc(thread_id, ctypes.py_object(exc_class))
    if modified == 1:
        return
    if modified == 0:
        raise ValueError("invalid thread id")
    # Unexpected result: passing None withdraws the pending exception.
    set_async_exc(thread_id, None)
    raise SystemError("PyThreadState_SetAsyncExc failed")
def stop_thread(thread):
    """Ask ``thread`` to terminate by raising ``SystemExit`` inside it.

    Args:
        thread: the thread object whose ``ident`` is used as the target.
    """
    target_ident = thread.ident
    _async_raise(target_ident, SystemExit)
class MyProcessor(allspark.BaseProcessor):
    """Allspark processor that classifies the text sent in each request.

    The request body is a UTF-8 JSON object. ``content`` holds the text to
    classify; ``doc_id`` (used for logging) and ``timeout`` (seconds) are
    optional. Example::

        curl -v http://127.0.0.1:8080/api/predict/service_name -d '{"doc_id": "1", "content": "..."}'
    """

    def run_thread(self, data, list_result):
        """Classify one request and append the response dict to ``list_result``.

        Runs in the worker thread started by ``process``; the result is handed
        back through ``list_result`` instead of a return value.

        Args:
            data: decoded JSON request. Anything that is not a dict is
                treated as a request without ``content``.
            list_result: list that receives exactly one response dict.
        """
        request = data if isinstance(data, dict) else {}
        _doc_id = request.get("doc_id", "")
        try:
            if "content" in request:
                cost_time = {}

                start_time = time.time()
                process, ids = self.classifier.process([request["content"]])
                cost_time["preprocess"] = time.time() - start_time

                start_time = time.time()
                logits, ids = self.classifier.predict(process, ids)
                cost_time["predict"] = time.time() - start_time

                start_time = time.time()
                result = self.classifier.get_results(logits, ids)
                # Predicted class name of the first (only) document.
                class_name = result[0][1]
                # dic_label maps the class name to "subclass,topclass".
                subclass, topclass = self.classifier.dic_label[class_name].split(',')
                cost_time["result"] = time.time() - start_time

                data_res = {
                    "class": topclass,
                    "class_name": class_name,
                    "subclass": subclass,
                    "success": True,
                    "cost_time": cost_time,
                }
            else:
                data_res = {"success": False, "msg": "content not passed"}
        except Exception as e:
            # SystemExit injected by stop_thread() is a BaseException, so it
            # is deliberately not caught here and ends the thread.
            traceback.print_exc()
            data_res = {"success": False, "msg": str(e)}
            logging.error('Exception:%s' % str(e))
        logging.info("done for doc_id:%s with result:%s" % (_doc_id, str(data_res)))
        list_result.append(data_res)

    def initialize(self):
        """Load the classifier and set the timing defaults.

        ``timeout`` is the default per-request limit in seconds and
        ``timeOfType`` is the width of each latency bucket that ``process``
        encodes into the status code.
        """
        self.classifier = Text_Classifier()
        self.timeout = 60
        self.status_types = 5
        self.timeOfType = self.timeout // self.status_types
        logging.info('初始化完成, 服务端口15000')
        print('初始化完成, 服务端口15000')

    def pre_proccess(self, data):
        """Parse two space-separated integers from a bytes payload.

        ``process`` in this class does not call this method.
        """
        x, y = data.split(b' ')
        return int(x), int(y)

    def post_process(self, data):
        """Encode the response text ``data`` as UTF-8 bytes."""
        return bytes(data, encoding='utf8')

    def process(self, data):
        """Handle one request and return ``(response_bytes, status_code)``.

        The classification runs in a separate thread so it can be abandoned
        after ``timeout`` seconds (request field ``timeout``, defaulting to
        ``self.timeout``).

        Status codes:
            * ``200 + n`` - finished in time; ``n`` is the 1-based latency
              bucket, ``elapsed // self.timeOfType + 1``.
            * ``302`` - the worker exceeded the timeout and was stopped.
            * ``200`` with an empty JSON object - the request could not be
              handled at all (e.g. invalid JSON); the error is logged.
        """
        try:
            # json.loads() no longer accepts ``encoding`` (removed in
            # Python 3.9); the payload is already decoded to str here.
            data = json.loads(data.decode("utf8"))
            _timeout = self.timeout
            if isinstance(data, dict) and "timeout" in data:
                _timeout = data["timeout"]

            list_result = []
            t = Thread(target=self.run_thread, args=(data, list_result))
            start_time = time.time()
            t.start()
            t.join(_timeout)
            if t.is_alive():
                stop_thread(t)
                status_code = 302  # killed on timeout
                data_res = {"success": False, "msg": "timeout"}
            else:
                status_code = 200 + int((time.time() - start_time) // self.timeOfType + 1)
                data_res = list_result[0]
            _resp = json.dumps(data_res, cls=MyEncoder)
            return self.post_process(_resp), status_code
        except Exception:
            # Keep the historical fallback response, but record the failure
            # instead of discarding it silently.
            logging.exception("failed to process request")
        return self.post_process(json.dumps({}, cls=MyEncoder)), 200
def main():
    """Configure logging and start the allspark service on 0.0.0.0:15000.

    Log records of every level go to ``log_dir/esa_classifier_pai.log``;
    only WARNING and above are echoed to the console.
    """
    log_path = 'log_dir/esa_classifier_pai.log'
    # FileHandler opens the file immediately and fails with
    # FileNotFoundError if the directory is missing, so create it first.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    logger = logging.getLogger()
    fh = logging.FileHandler(log_path, encoding='utf-8')
    sh = logging.StreamHandler()

    # Both handlers share one record format.
    formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s')
    fh.setFormatter(formatter)
    sh.setFormatter(formatter)
    logger.addHandler(fh)
    logger.addHandler(sh)

    logger.setLevel(logging.DEBUG)
    fh.setLevel(logging.DEBUG)
    sh.setLevel(logging.WARNING)

    allspark.default_properties().put("rpc.keepalive", 250000)
    allspark.default_properties().put("rpc.max_queue_size", 100)

    # Local run uses the arguments below; on the PAI platform construct
    # MyProcessor() without arguments instead.
    runner = MyProcessor(worker_threads=20, worker_processes=1, endpoint="0.0.0.0:15000")
    runner.run()
-
if __name__ == '__main__':
    main()
# The worker_threads parameter indicates the concurrency of processing.
# Local run:
# tf.app.run()
# allspark.default_properties().put("rpc.keepalive", 60000)
# runner = MyProcessor(worker_threads=5,worker_processes=1,endpoint="0.0.0.0:15011")
#
#
# # Run on the PAI platform:
# # runner = MyProcessor()
# runner.run()
|