run_industry_server.py

# -*- coding: utf-8 -*-
"""
Created on Fri Jun 1 18:03:03 2018
@author: DONG
"""
import sys
import os

# Make the project root importable and configure the backend / visible devices
# before the classifier (and TensorFlow) is imported.
sys.path.append(os.path.dirname(__file__) + "/..")
sys.path.append(os.path.abspath("."))
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force CPU-only execution

import logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

import ctypes
import inspect
import json
import re
import time
import traceback
import uuid
from threading import Thread

import numpy as np
from bs4 import BeautifulSoup
from flask import Flask, jsonify, abort, request

from text_classifier_pai.main import Text_Classifier

app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False

classifier = Text_Classifier()

# Custom JSON encoder so numpy arrays, numpy floats and bytes in the
# response dicts can be serialized.
class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, bytes):
            return str(obj, encoding='utf-8')
        elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
            return float(obj)
        return json.JSONEncoder.default(self, obj)
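
# Note: MyEncoder is passed to json.dumps() in text_predict() below so that
# response dicts containing numpy arrays, numpy floats or bytes serialize to
# plain JSON lists, numbers and strings instead of raising TypeError, e.g.
# (illustrative only):
#   json.dumps({"score": np.float32(0.9), "ids": np.array([1, 2])}, cls=MyEncoder)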

def _async_raise(tid, exctype):
    """Raise `exctype` asynchronously in the thread with id `tid`; performs cleanup if needed."""
    tid = ctypes.c_long(tid)
    if not inspect.isclass(exctype):
        exctype = type(exctype)
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("invalid thread id")
    elif res != 1:
        # More than one thread state was affected: undo the call and fail.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
        raise SystemError("PyThreadState_SetAsyncExc failed")


def stop_thread(thread):
    """Force-stop a worker thread by raising SystemExit inside it (used on timeout)."""
    _async_raise(thread.ident, SystemExit)

def article_limit(soup, limit_words=30000):
    """Truncate a parsed HTML document so its visible text stays within `limit_words` characters."""
    sub_space = re.compile(r"\s+")

    def soup_limit(_soup, _count, max_count=30000, max_gap=500):
        """
        :param _soup: soup
        :param _count: current character count
        :param max_count: maximum allowed number of characters
        :param max_gap: maximum overshoot tolerated once the limit is exceeded
        :return:
        """
        _gap = _count - max_count
        _is_skip = False
        next_soup = None
        # Descend through wrapper tags that contain a single child with the same text.
        while len(_soup.find_all(recursive=False)) == 1 and \
                _soup.get_text(strip=True) == _soup.find_all(recursive=False)[0].get_text(strip=True):
            _soup = _soup.find_all(recursive=False)[0]
        if len(_soup.find_all(recursive=False)) == 0:
            # Leaf node: cut its text to the remaining budget.
            _soup.string = str(_soup.get_text())[:max_count - _count]
            _count += len(re.sub(sub_space, "", _soup.string))
            _gap = _count - max_count
            next_soup = None
        else:
            for _soup_part in _soup.find_all(recursive=False):
                if not _is_skip:
                    _count += len(re.sub(sub_space, "", _soup_part.get_text()))
                    if _count >= max_count:
                        _gap = _count - max_count
                        if _gap <= max_gap:
                            _is_skip = True
                        else:
                            # Overshoot too large: descend into this child on the next pass.
                            _is_skip = True
                            next_soup = _soup_part
                            _count -= len(re.sub(sub_space, "", _soup_part.get_text()))
                            continue
                else:
                    # Everything after the limit is dropped.
                    _soup_part.decompose()
        return _count, _gap, next_soup

    text_count = 0
    have_attachment = False
    attachment_part = None
    for child in soup.find_all(recursive=True):
        if child.name == 'div' and 'class' in child.attrs:
            if "richTextFetch" in child['class']:
                child.insert_before("##attachment##")
                attachment_part = child
                have_attachment = True
                break
    if not have_attachment:
        # No attachment
        if len(re.sub(sub_space, "", soup.get_text())) > limit_words:
            text_count, gap, n_soup = soup_limit(soup, text_count, max_count=limit_words, max_gap=500)
            while n_soup:
                text_count, gap, n_soup = soup_limit(n_soup, text_count, max_count=limit_words, max_gap=500)
    else:
        # Attachment present: limit the main body and the attachment separately.
        _text = re.sub(sub_space, "", soup.get_text())
        _text_split = _text.split("##attachment##")
        if len(_text_split[0]) > limit_words:
            main_soup = attachment_part.parent
            main_text = main_soup.find_all(recursive=False)[0]
            text_count, gap, n_soup = soup_limit(main_text, text_count, max_count=limit_words, max_gap=500)
            while n_soup:
                text_count, gap, n_soup = soup_limit(n_soup, text_count, max_count=limit_words, max_gap=500)
        if len(_text_split[1]) > limit_words:
            # The attachment HTML is plain text with no child structure.
            if len(attachment_part.find_all(recursive=False)) == 0:
                attachment_part.string = str(attachment_part.get_text())[:limit_words]
            else:
                attachment_text_nums = 0
                attachment_skip = False
                for part in attachment_part.find_all(recursive=False):
                    if not attachment_skip:
                        last_attachment_text_nums = attachment_text_nums
                        attachment_text_nums = attachment_text_nums + len(re.sub(sub_space, "", part.get_text()))
                        if attachment_text_nums >= limit_words:
                            part.string = str(part.get_text())[:limit_words - last_attachment_text_nums]
                            attachment_skip = True
                    else:
                        part.decompose()
    return soup
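
# Usage sketch for article_limit (illustrative only; run_thread() below applies
# the same pattern with a 50000-character limit):
#   soup = BeautifulSoup("<html><body>" + "<p>text</p>" * 20000 + "</body></html>", "lxml")
#   soup = article_limit(soup, limit_words=30000)
#   content = str(soup)  # HTML whose visible text is cut to roughly 30000 non-space characters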

def run_thread(data, list_result):
    """Worker run in a separate thread; appends the classification result to `list_result`."""
    k = str(uuid.uuid4())
    cost_time = dict()
    if "doc_id" in data:
        _doc_id = data['doc_id']
    else:
        _doc_id = ""
    if "title" in data:
        _title = data["title"]
    else:
        _title = ""
    data_res = ""
    try:
        if "content" in data:
            content = data['content']
            if len(content) > 50000:
                # Cut very long articles down before preprocessing.
                _soup = BeautifulSoup(content, "lxml")
                _soup = article_limit(_soup, 50000)
                content = str(_soup)
            start_time = time.time()
            process, ids = classifier.process([content])
            cost_time["preprocess"] = time.time() - start_time
            start_time = time.time()
            logits, ids = classifier.predict(process, ids)
            cost_time["predict"] = time.time() - start_time
            start_time = time.time()
            result = classifier.get_results(logits, ids)
            class_name = result[0][1]  # predicted class name
            # Look up the top-level and sub-level category names for this class.
            subclass, topclass = classifier.dic_label[class_name].split(',')
            cost_time["result"] = time.time() - start_time
            data_res = {"class": topclass, "class_name": class_name, "subclass": subclass}
            data_res["success"] = True
            data_res["cost_time"] = cost_time
        else:
            data_res = {"success": False, "msg": "content not passed"}
    except Exception as e:
        traceback.print_exc()
        data_res = {"success": False, "msg": str(e)}
        logging.error('Exception:%s' % str(e))
    # The caller serializes data_res to JSON and returns it.
    logging.info("done for doc_id:%s with result:%s" % (_doc_id, str(data_res)))
    list_result.append(data_res)

@app.route('/industry_extract', methods=['POST'])
def text_predict():
    try:
        data = request.json
        status_code = 200
        _timeout = None  # join() blocks until the worker finishes if no timeout is given
        if "timeout" in data:
            _timeout = data["timeout"]
        list_result = []
        t = Thread(target=run_thread, args=(data, list_result))
        start_time = time.time()
        t.start()
        t.join(_timeout)
        if t.is_alive():
            # The worker exceeded the timeout: kill it and report failure.
            stop_thread(t)
            status_code = 302  # killed because of timeout
            data_res = {"success": False, "msg": "timeout"}
        else:
            data_res = list_result[0]
        _resp = json.dumps(data_res, cls=MyEncoder, ensure_ascii=False)
        return _resp, 201
    except Exception as e:
        traceback.print_exc()
        data_res = {"success": False, "msg": "error:%s" % (str(e))}
        _resp = json.dumps(data_res)
        return _resp, 500
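
# Example request against this endpoint (illustrative only; host and port assume
# the getPort() default of 15000 below, and the field names match what
# run_thread() reads from the JSON body):
#   import requests
#   resp = requests.post("http://127.0.0.1:15000/industry_extract",
#                        json={"doc_id": "doc-1",
#                              "title": "example",
#                              "content": "<p>article html ...</p>",
#                              "timeout": 20})
#   resp.json()  # {"class": ..., "class_name": ..., "subclass": ..., "success": True, "cost_time": {...}}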

def getPort(argv):
    """Read a 'port=NNNN' command-line argument; default to 15000."""
    port = 15000
    for item in argv:
        _l = str(item).split("port=")
        if len(_l) > 1:
            port = int(_l[-1])
            break
    return port


if __name__ == '__main__':
    port = getPort(argv=sys.argv)
    print("ContentExtractor running")
    app.run(host='0.0.0.0', port=port, threaded=True, debug=False)