#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2020/4/24 0024 15:20
# "evaluate" is the entry method of this UDF and must use exactly that name
from odps.udf import annotate
from odps.distcache import get_cache_archive
from odps.distcache import get_cache_file
import time


def recall(y_true, y_pred):
    '''
    Compute recall.
    @Args:
        y_true: ground-truth labels
        y_pred: labels predicted by the model
    @Return:
        recall
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))  # true positives
    c3 = K.sum(K.round(K.clip(y_true, 0, 1)))           # actual positives
    if c3 == 0:
        return 0
    recall = c1 / c3
    return recall


def f1_score(y_true, y_pred):
    '''
    Compute the F1 score.
    @Args:
        y_true: ground-truth labels
        y_pred: labels predicted by the model
    @Return:
        F1 score
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))  # true positives
    c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))           # predicted positives
    c3 = K.sum(K.round(K.clip(y_true, 0, 1)))           # actual positives
    precision = c1 / (c2 + K.epsilon())  # epsilon avoids division by zero
    if c3 == 0:
        recall = 0
    else:
        recall = c1 / c3
    f1_score = 2 * (precision * recall) / (precision + recall + K.epsilon())
    return f1_score


def precision(y_true, y_pred):
    '''
    Compute precision.
    @Args:
        y_true: ground-truth labels
        y_pred: labels predicted by the model
    @Return:
        precision
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))  # true positives
    c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))           # predicted positives
    precision = c1 / (c2 + K.epsilon())  # epsilon avoids division by zero
    return precision
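

# Sanity check of the metric math on a toy batch (a sketch using plain numbers
# in place of keras.backend tensors; inside this UDF the functions are only
# needed as custom_objects when the saved model is loaded):
#   y_true = [1, 0, 1, 1], y_pred = [1, 0, 0, 1]
#   c1 (true positives) = 2, c2 (predicted positives) = 2, c3 (actual positives) = 3
#   precision = 2/2 = 1.0, recall = 2/3 ~= 0.667, F1 = 2*(1*2/3)/(1+2/3) = 0.8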


# Set up a bundled dependency package (e.g. pandas) from a cached archive resource
def include_package_path(res_name):
    import os, sys
    archive_files = get_cache_archive(res_name)
    # the shortest directory name is the archive root
    dir_names = sorted([os.path.dirname(os.path.normpath(f.name)) for f in archive_files
                        if '.dist_info' not in f.name], key=lambda v: len(v))
    sys.path.append(dir_names[0])
    return os.path.dirname(dir_names[0])
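

# Usage sketch: with a zipped site-packages tree uploaded as the archive
# resource "envs_py37.env.zip" (the name used below in Extractor.__init__),
#   include_package_path("envs_py37.env.zip")
# puts the archive root on sys.path so that "import numpy", "import keras",
# etc. resolve from the bundle.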


# Restore a business data package by hand: upload size limits, Python-version
# differences and inconsistent archive extraction make the automatic
# mechanisms unreliable
def init_env(list_files, package_name):
    import os, sys
    if len(list_files) == 1:
        so_file = get_cache_file(list_files[0])
        cmd_line = os.path.abspath(so_file.name)
        os.system("unzip %s -d %s" % (cmd_line, package_name))
    elif len(list_files) > 1:
        # multi-part archive: concatenate the split parts, then unzip
        cmd_line = "cat"
        for _file in list_files:
            so_file = get_cache_file(_file)
            cmd_line += " " + os.path.abspath(so_file.name)
        cmd_line += " > temp.zip"
        os.system(cmd_line)
        os.system("unzip temp.zip -d %s" % (package_name))
    sys.path.append(os.path.abspath(package_name))
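

# Usage sketch (hypothetical resource names): a package too large for a single
# upload is stored as split-zip parts and restored with
#   init_env(['mypkg.z01', 'mypkg.z02'], 'mypkg')
# which cats the parts back into temp.zip and unzips it into ./mypkg; the
# single-file form unzips one cached resource directly.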


# Main UDF program
@annotate("string->string")
class Extractor(object):
    def __init__(self):
        # Declare the globals up front (a "global" statement must precede any
        # assignment to the name, otherwise Python 3 raises a SyntaxError);
        # the heavyweight imports happen below, after the cached archives
        # holding them have been unpacked.
        global log, json, re, np, tf, K
        import logging as log
        import os
        log.basicConfig(level=log.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        logger = log.getLogger(__name__)
        model_path = os.path.abspath(get_cache_file('model_changemedium_acc90.model').name)  # attentiongruacc0.932.model changed to New_attentionGUR_embed100_newlabel_20201020.h5 20201023
        log.info('model_path:%s' % model_path)
        log.info(os.path.exists(model_path))
        # init_env(['pyhanlp.z01', 'pyhanlp.z02', 'pyhanlp.z03', 'pyhanlp.z04'], 'pyhanlp')
        start_time = time.time()
        init_env(['pyhanlp.z01', 'pyhanlp.z02'], 'pyhanlp')
        log.info("init pyhanlp takes:%d" % (time.time() - start_time))
        start_time = time.time()
        # init_env(['envs_py37.zip.env'], 'envs_py37')
        include_package_path("envs_py37.env.zip")
        log.info("init envs_py37 takes:%d" % (time.time() - start_time))
        start_time = time.time()
        init_env(['so.env'], '.')
        init_env(['pkl_csv.z01'], '.')
        log.info("init pkl_csv takes:%d" % (time.time() - start_time))
        start_time = time.time()
        import pickle
        import csv
        import re
        import tensorflow as tf
        import numpy as np
        import keras.backend as K
        from keras import models
        from keras.engine.topology import Layer
        import json
        log.info('import package done------------------')
        # dirpath = os.path.abspath('pyhanlp')
        # path = dirpath + '/pyhanlp/static/__init__.py'  # return dirpath
        # dirpath = os.path.dirname(os.path.abspath(get_cache_file('pyhanlp.z01').name))
        # return '; '.join([a for a in os.listdir(os.listdir(dirpath)[0])])
        # path2 = os.path.abspath(get_cache_file('hanlpinit.txt').name)
        # content = []
        # with open(path2, encoding='utf-8') as f:
        #     for line in f:
        #         content.append(line)
        # # return '; '.join(content)
        # with open(path, 'w', encoding='utf-8') as f:
        #     f.writelines(content)
        # log.info('rewrite hanlp path done--------------------')
        # archive_files = get_cache_archive('token_stopwds.zip')
        # names = [os.path.dirname(os.path.normpath(f.name)) for f in archive_files]
        # with open(names[0] + '/bidi_classify_stop_words.csv', 'r', encoding='utf-8') as f:
        #     self.stopwords = [row[0] for row in csv.reader(f)]
        # with open(names[0] + '/word_index_955871.pk', 'rb') as f:
        #     self.word_index = pickle.load(f)
        from pyhanlp import HanLP, JClass
        HanLP.Config = JClass('com.hankcs.hanlp.HanLP$Config')
        HanLP.Config.ShowTermNature = False  # segment() should return plain words, no part-of-speech tags
        self.hanlp = HanLP
        log.info('import hanlp done---------------------')

        class Attention(Layer):
            log.info('******attention****************')
            print('-------attention------------------')

            def __init__(self, **kwargs):
                super(Attention, self).__init__(**kwargs)

            def build(self, input_shape):
                # W: (EMBED_SIZE, 1)
                # b: (MAX_TIMESTEPS, 1)
                # u: (MAX_TIMESTEPS, MAX_TIMESTEPS)
                self.W = self.add_weight(name="W_{:s}".format(self.name),
                                         shape=(input_shape[-1], 1),
                                         initializer="normal")
                self.b = self.add_weight(name="b_{:s}".format(self.name),
                                         shape=(input_shape[1], 1),
                                         initializer="zeros")
                self.u = self.add_weight(name="u_{:s}".format(self.name),
                                         shape=(input_shape[1], input_shape[1]),
                                         initializer="normal")
                super(Attention, self).build(input_shape)

            def call(self, x, mask=None):
                # input: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
                # et: (BATCH_SIZE, MAX_TIMESTEPS)
                et = K.squeeze(K.tanh(K.dot(x, self.W) + self.b), axis=-1)
                # at: (BATCH_SIZE, MAX_TIMESTEPS)
                at = K.dot(et, self.u)
                at = K.exp(at)
                if mask is not None:
                    at *= K.cast(mask, K.floatx())
                # ot: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
                at /= K.cast(K.sum(at, axis=1, keepdims=True) + K.epsilon(), K.floatx())
                atx = K.expand_dims(at, axis=-1)
                ot = atx * x
                # output: (BATCH_SIZE, EMBED_SIZE)
                return K.sum(ot, axis=1)

            def compute_mask(self, input, input_mask=None):
                # do not pass the mask to the next layers
                return None

            def compute_output_shape(self, input_shape):
                # output shape: (BATCH_SIZE, EMBED_SIZE)
                return (input_shape[0], input_shape[-1])

            def get_config(self):
                return super(Attention, self).get_config()
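
        # What the layer above computes per batch element:
        #   et = tanh(x . W + b)           one score per timestep
        #   at = softmax(et . u)           (masked) attention weights
        #   output = sum_t at[t] * x[t]    attention-weighted sum of the inputs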

        # load_model must be given every custom metric/layer the model was
        # saved with, keyed by the names used at save time
        self.model = models.load_model(model_path,
                                       custom_objects={'precision': precision,
                                                       'recall': recall,
                                                       'f1_score': f1_score,
                                                       'Attention': Attention})
        log.info('init model end --')
        pk_path = os.path.abspath('pkl_csv')
        with open(pk_path + '/id2label.pkl', 'rb') as f:  # '/label_mapping210.pkl' changed to id2label.pkl 20201023
            self.label_map = pickle.load(f)
        print('load label_map done')
        with open(pk_path + '/bidi_classify_stop_words.csv', 'r', encoding='utf-8') as f:
            self.stopwords = [row[0] for row in csv.reader(f)]
        with open(pk_path + '/word_index_955871.pk', 'rb') as f:
            self.word_index = pickle.load(f)
        with open(pk_path + '/class2dalei_menlei.pkl', 'rb') as f:  # class_subclass_dic211.pk changed to class2dalei_menlei.pkl 20201023
            self.class_dic = pickle.load(f)
        log.info('class init done ----')

    def evaluate(self, text):
        # Strip HTML: collapse whitespace, drop <script>/<style> blocks, then remove any remaining tags
        text = re.sub(r'\s', '', str(text))
        text = re.sub(r'<\s*script[^>]*>.*?<\s*/\s*script\s*>', '', text)
        text = re.sub(r'<\s*style[^>]*>.*?<\s*/\s*style\s*>', '', text)
        text = re.sub(r'</?\w+[^>]*>', '', text)
        # Drop noise characters (Latin letters, dates, digits, punctuation); keep the first 500 characters
        text = re.sub(r'\{.*font.*\}|\{.*Font.*\}|[^\u4e00-\u9fa5]', '', text)[:500]
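        # [^\u4e00-\u9fa5] matches anything outside the CJK Unified Ideographs
        # block, so only Chinese characters survive into segmentation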
        # Segment with HanLP
        result = self.hanlp.segment(text)
        text_list = [str(result.get(i)) for i in range(result.size())]
        # Filter stopwords
        # text_list = [word for word in text_list if word not in self.stopwords and len(word) > 1]  # stopword filtering disabled 20201023
        # Order-preserving deduplication
        # l2 = []
        # [l2.append(i) for i in text_list if i not in l2]  # dedup disabled 20201023
        # Map words to vocabulary ids (unknown words map to 0)
        text_list = [str(self.word_index.get(word, 0)) for word in text_list]  # l2 changed to text_list 20201023
        # Pad/truncate to a fixed length and convert to an array
        text_list = text_list[:150] if len(text_list) > 150 else text_list + ['0'] * (150 - len(text_list))  # length changed from 100 to 150 words 20201023
        features = np.array([text_list])  # exactly 150 ids after the line above
        log.info('numericalization done-------------------')
        # features = np.array([s.split(',')[:100] if len(s.split(',')) > 100 else s.split(',') + [0] * (100 - len(s.split(',')))])
        with tf.get_default_graph().as_default():
            log.info('about to predict-------------------')
            logits = self.model.predict(features)
            # return ','.join(logits[0])
            # result = self.label_map(np.argmax(logits[0]))
            # return result
            log.info('prediction done-------------------')
            # Indices of the three highest-scoring classes
            top3 = np.argsort(-logits[0], axis=-1)[:3]
            prob = ['%.4f' % (logits[0][i]) for i in top3]
            pre = [self.label_map[i] for i in top3]
            rd = {}
            i = 1
            for a in pre:
                sub, father = self.class_dic[a].split(',')
                rd['top' + str(i)] = {'subclass': sub, 'class_name': a, 'class': father}
                i += 1
            log.info('about to return the result string')
            return json.dumps(rd, ensure_ascii=False)
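

# Minimal local smoke test (a sketch: Extractor() expects the model, pyhanlp,
# envs_py37, so.env and pkl_csv resources in the ODPS distributed cache, which
# normally exists only on the MaxCompute worker, so this is illustrative
# rather than runnable off-cluster; the sample HTML snippet is made up):
# if __name__ == '__main__':
#     extractor = Extractor()
#     print(extractor.evaluate('<p>某某单位办公设备采购项目招标公告</p>'))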