'''
Created on 2019-12-03

@author: User
'''
import allspark
import sys
import os
import json
import logging
import time
import uuid
sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
import tensorflow as tf
from text_classifier_pai.main import Text_Classifier
import numpy as np
import ctypes
import inspect
from threading import Thread
import traceback

# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
sys.path.append(os.path.abspath("."))


# Custom JSON encoder: makes numpy arrays, bytes and numpy floats serializable.
class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, bytes):
            return str(obj, encoding='utf-8')
        elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
            return float(obj)
        return json.JSONEncoder.default(self, obj)


def _async_raise(tid, exctype):
    """Raises the exception in the thread with id `tid`; performs cleanup if needed."""
    tid = ctypes.c_long(tid)
    if not inspect.isclass(exctype):
        exctype = type(exctype)
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("invalid thread id")
    elif res != 1:
        # The call affected more than one thread: undo it and report the failure.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
        raise SystemError("PyThreadState_SetAsyncExc failed")


def stop_thread(thread):
    _async_raise(thread.ident, SystemExit)
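
# Illustrative sketch (a hypothetical helper, not called anywhere in this
# service): how stop_thread() is meant to be used. Note that
# PyThreadState_SetAsyncExc only delivers the exception while the target
# thread is executing Python bytecode; a thread blocked inside a long-running
# C call (e.g. a TensorFlow op) will not be interrupted until that call returns.
def _stop_thread_demo():
    def _worker():
        try:
            while True:
                time.sleep(0.1)
        except SystemExit:
            logging.info("worker received SystemExit and is exiting")

    t = Thread(target=_worker)
    t.start()
    time.sleep(0.5)
    stop_thread(t)  # asynchronously raises SystemExit inside _worker
    t.join(1)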
class MyProcessor(allspark.BaseProcessor):
    """ MyProcessor is an example.
        The stock allspark example sends a message like:
        curl -v http://127.0.0.1:8080/api/predict/service_name -d '2 105'
        This service, however, expects a JSON body (see process()).
    """

    def run_thread(self, data, list_result):
        k = str(uuid.uuid4())
        cost_time = dict()
        if "doc_id" in data:
            _doc_id = data['doc_id']
        else:
            _doc_id = ""
        if "title" in data:
            _title = data["title"]
        else:
            _title = ""
        data_res = ""
        try:
            if "content" in data:
                content = data['content']
                start_time = time.time()
                # Preprocess the article content.
                process, ids = self.classifier.process([content])
                cost_time["preprocess"] = time.time() - start_time
                start_time = time.time()
                # Run the prediction.
                # with self.classifier.sess.graph.as_default():
                logits, ids = self.classifier.predict(process, ids)
                cost_time["predict"] = time.time() - start_time
                start_time = time.time()
                # Extract the result.
                result = self.classifier.get_results(logits, ids)
                class_name = result[0][1]  # the predicted class name
                # Look up the top-level and sub-level class names from the predicted name.
                subclass, topclass = self.classifier.dic_label[class_name].split(',')
                cost_time["result"] = time.time() - start_time
                data_res = {"class": topclass, "class_name": class_name, "subclass": subclass}
                data_res["success"] = True
                data_res["cost_time"] = cost_time
            else:
                data_res = {"success": False, "msg": "content not passed"}
        except Exception as e:
            traceback.print_exc()
            data_res = {"success": False, "msg": str(e)}
            logging.error('Exception:%s' % str(e))
        logging.info("done for doc_id:%s with result:%s" % (_doc_id, str(data_res)))
        # Hand the result back to process() through the shared list.
        list_result.append(data_res)

    def initialize(self):
        """ Load module; executed once at the start of the service.
            Do service initialization and load models in this function.
        """
        self.classifier = Text_Classifier()
        self.timeout = 60
        self.status_types = 5
        # Width of one latency bucket; process() adds the bucket index to the
        # 200 status code, so successful responses carry 201..205.
        self.timeOfType = self.timeout // self.status_types
        logging.info('Initialization done, serving on port 15000')
        print('Initialization done, serving on port 15000')

    def pre_proccess(self, data):
        """ data format pre-process """
        x, y = data.split(b' ')
        return int(x), int(y)

    def post_process(self, data):
        """ encode the response after processing """
        return bytes(data, encoding='utf8')

    def process(self, data):
        """ process the request data """
        try:
            data = data.decode("utf8")
            # json.loads() no longer accepts an encoding argument (removed in Python 3.9).
            data = json.loads(data)
            _timeout = self.timeout
            status_code = 200
            if "timeout" in data:
                _timeout = data["timeout"]
            list_result = []
            # Run the classification in a worker thread so it can be killed on timeout.
            t = Thread(target=self.run_thread, args=(data, list_result))
            start_time = time.time()
            t.start()
            t.join(_timeout)
            if t.is_alive():
                stop_thread(t)
                status_code = 302  # the worker was killed because of a timeout
                data_res = {"success": False, "msg": "timeout"}
            else:
                # Encode the latency bucket into the status code (201..205).
                status_code += int((time.time() - start_time) // self.timeOfType + 1)
                data_res = list_result[0]
            _resp = json.dumps(data_res, cls=MyEncoder)
            return self.post_process(_resp), status_code
        except Exception as e:
            logging.error('Exception:%s' % str(e))
        return self.post_process(json.dumps({}, cls=MyEncoder)), 200


def main():
    # Create the root logger.
    logger = logging.getLogger()
    # File handler.
    fh = logging.FileHandler('log_dir/esa_classifier_pai.log', encoding='utf-8')
    # Console handler.
    sh = logging.StreamHandler()
    # Configure the output format and bind it to both handlers.
    formatter = logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s')
    fh.setFormatter(formatter)
    sh.setFormatter(formatter)
    # Attach both handlers to the logger.
    logger.addHandler(fh)
    logger.addHandler(sh)
    logger.setLevel(logging.DEBUG)
    fh.setLevel(logging.DEBUG)
    sh.setLevel(logging.WARNING)
    allspark.default_properties().put("rpc.keepalive", 250000)
    allspark.default_properties().put("rpc.max_queue_size", 100)
    # For local runs use the line below; on the Alibaba Cloud PAI platform,
    # construct MyProcessor without arguments. The worker_threads parameter
    # indicates the concurrency of processing.
    runner = MyProcessor(worker_threads=20, worker_processes=1, endpoint="0.0.0.0:15000")
    # Running on the PAI platform:
    # runner = MyProcessor()
    runner.run()


if __name__ == '__main__':
    main()
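
# Illustrative client sketch (hypothetical, never executed by this module):
# what a request to this service could look like. The URL path follows the
# curl example in the MyProcessor docstring and the local endpoint configured
# in main(); the payload keys ("content", and optionally "doc_id", "title",
# "timeout") are the ones process()/run_thread() actually read. Adjust the
# host, port and service name to your deployment.
def _example_request():
    import requests  # assumed to be available in the client environment
    payload = {
        "doc_id": "demo-001",                        # optional, echoed in the logs
        "title": "demo title",                       # optional, currently unused
        "content": "Article text to classify ...",
        "timeout": 30,                               # optional, overrides the 60 s default
    }
    resp = requests.post(
        "http://127.0.0.1:15000/api/predict/service_name",
        data=json.dumps(payload).encode("utf8"),
    )
    # On success the body looks like:
    # {"class": ..., "class_name": ..., "subclass": ..., "success": true, "cost_time": {...}}
    print(resp.status_code, resp.content)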