article_extract.py

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2020/4/24 15:20
# evaluate is the entry method of this UDF and must use exactly this name.
from odps.udf import annotate
from odps.distcache import get_cache_archive
from odps.distcache import get_cache_file


def recall(y_true, y_pred):
    '''
    Compute recall.
    @Args:
        y_true: ground-truth labels
        y_pred: labels predicted by the model
    @Return:
        recall
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))  # true positives
    c3 = K.sum(K.round(K.clip(y_true, 0, 1)))           # all actual positives
    if c3 == 0:
        return 0
    recall = c1 / c3
    return recall


def f1_score(y_true, y_pred):
    '''
    Compute F1.
    @Args:
        y_true: ground-truth labels
        y_pred: labels predicted by the model
    @Return:
        F1 value
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))  # true positives
    c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))           # all predicted positives
    c3 = K.sum(K.round(K.clip(y_true, 0, 1)))           # all actual positives
    precision = c1 / c2
    if c3 == 0:
        recall = 0
    else:
        recall = c1 / c3
    f1_score = 2 * (precision * recall) / (precision + recall)
    return f1_score


def precision(y_true, y_pred):
    '''
    Compute precision.
    @Args:
        y_true: ground-truth labels
        y_pred: labels predicted by the model
    @Return:
        precision
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))  # true positives
    c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))           # all predicted positives
    precision = c1 / c2
    return precision
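
# Note: these metric functions rely on the module-level name K (bound to
# keras.backend inside Extractor.__init__ below). They exist chiefly so that
# models.load_model() can resolve them through custom_objects when the
# trained model is deserialized.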


# Add a cached archive of dependency packages (e.g. pandas) to sys.path
def include_package_path(res_name):
    import os, sys
    archive_files = get_cache_archive(res_name)
    dir_names = sorted([os.path.dirname(os.path.normpath(f.name)) for f in archive_files
                        if '.dist_info' not in f.name], key=lambda v: len(v))
    sys.path.append(dir_names[0])
    return os.path.dirname(dir_names[0])
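
# Usage sketch: the shortest extracted directory is the archive root, so
# appending it to sys.path makes the bundled packages importable, e.g.
#   include_package_path("envs_py37.env.zip")  # done in Extractor.__init__ below
#   import pandas                              # illustrative, per the comment above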


# Initialize business data packages. Because of upload size limits, Python
# version mismatches, inconsistent archive extraction, and similar issues,
# they have to be unpacked and added to sys.path manually.
def init_env(list_files, package_name):
    import os, sys
    if len(list_files) == 1:
        so_file = get_cache_file(list_files[0])
        cmd_line = os.path.abspath(so_file.name)
        os.system("unzip %s -d %s" % (cmd_line, package_name))
    elif len(list_files) > 1:
        cmd_line = "cat"
        for _file in list_files:
            so_file = get_cache_file(_file)
            cmd_line += " " + os.path.abspath(so_file.name)
        cmd_line += " > temp.zip"
        os.system(cmd_line)
        os.system("unzip temp.zip -d %s" % (package_name))
    sys.path.append(os.path.abspath(package_name))
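
# Usage sketch: multi-part resources such as pyhanlp.z01/pyhanlp.z02 are
# byte-split pieces of a single zip (split to stay under the per-resource
# upload limit), so `cat` simply re-joins them before unzipping:
#   init_env(['pyhanlp.z01', 'pyhanlp.z02'], 'pyhanlp')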


# Main UDF program
@annotate("string->string")
class Extractor(object):
    def __init__(self):
        # Globals must be declared before the names are bound, otherwise
        # Python raises "name assigned to before global declaration".
        # K is global so the module-level metric functions above can use it.
        global log, json, re, np, tf, K
        import logging as log
        import os
        log.basicConfig(level=log.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        model_path = os.path.abspath(get_cache_file('attentiongruacc0.932.model').name)
        log.info('model_path: %s', model_path)
        log.info('model file exists: %s', os.path.exists(model_path))
        # init_env(['pyhanlp.z01', 'pyhanlp.z02','pyhanlp.z03','pyhanlp.z04'], 'pyhanlp')
        init_env(['pyhanlp.z01', 'pyhanlp.z02'], 'pyhanlp')
        # init_env(['envs_py37.zip.env'], 'envs_py37')
        include_package_path("envs_py37.env.zip")
        init_env(['so.env'], '.')
        init_env(['pkl_csv.z01'], '.')
        import pickle
        import csv
        import re
        import tensorflow as tf
        import numpy as np
        import keras.backend as K
        from keras import models
        from keras.engine.topology import Layer
        import json
        log.info('import package done------------------')
        # dirpath = os.path.abspath('pyhanlp')
        # path = dirpath + '/pyhanlp/static/__init__.py'  # return dirpath
        # dirpath = os.path.dirname(os.path.abspath(get_cache_file('pyhanlp.z01').name))
        # return '; '.join([a for a in os.listdir(os.listdir(dirpath)[0])])
        # path2 = os.path.abspath(get_cache_file('hanlpinit.txt').name)
        # content = []
        # with open(path2, encoding='utf-8') as f:
        #     for line in f:
        #         content.append(line)
        # # return '; '.join(content)
        # with open(path, 'w', encoding='utf-8') as f:
        #     f.writelines(content)
        # log.info('rewrite hanlp path done--------------------')
        # archive_files = get_cache_archive('token_stopwds.zip')
        # names = [os.path.dirname(os.path.normpath(f.name)) for f in archive_files]
        # with open(names[0] + '/bidi_classify_stop_words.csv', 'r', encoding='utf-8') as f:
        #     self.stopwords = [row[0] for row in csv.reader(f)]
        # with open(names[0] + '/word_index_955871.pk', 'rb') as f:
        #     self.word_index = pickle.load(f)
        from pyhanlp import HanLP, JClass
        HanLP.Config = JClass('com.hankcs.hanlp.HanLP$Config')
        HanLP.Config.ShowTermNature = False  # hide part-of-speech tags in segmentation output
        self.hanlp = HanLP
        log.info('import hanlp done---------------------')

        class Attention(Layer):
            log.info('******attention****************')

            def __init__(self, **kwargs):
                super(Attention, self).__init__(**kwargs)

            def build(self, input_shape):
                # W: (EMBED_SIZE, 1)
                # b: (MAX_TIMESTEPS, 1)
                # u: (MAX_TIMESTEPS, MAX_TIMESTEPS)
                self.W = self.add_weight(name="W_{:s}".format(self.name),
                                         shape=(input_shape[-1], 1),
                                         initializer="normal")
                self.b = self.add_weight(name="b_{:s}".format(self.name),
                                         shape=(input_shape[1], 1),
                                         initializer="zeros")
                self.u = self.add_weight(name="u_{:s}".format(self.name),
                                         shape=(input_shape[1], input_shape[1]),
                                         initializer="normal")
                super(Attention, self).build(input_shape)

            def call(self, x, mask=None):
                # input: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
                # et: (BATCH_SIZE, MAX_TIMESTEPS)
                et = K.squeeze(K.tanh(K.dot(x, self.W) + self.b), axis=-1)
                # at: (BATCH_SIZE, MAX_TIMESTEPS)
                at = K.dot(et, self.u)
                at = K.exp(at)
                if mask is not None:
                    at *= K.cast(mask, K.floatx())
                # normalize to attention weights that sum to 1 over timesteps
                at /= K.cast(K.sum(at, axis=1, keepdims=True) + K.epsilon(), K.floatx())
                # ot: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
                atx = K.expand_dims(at, axis=-1)
                ot = atx * x
                # output: (BATCH_SIZE, EMBED_SIZE)
                return K.sum(ot, axis=1)

            def compute_mask(self, input, input_mask=None):
                # do not pass the mask to the next layers
                return None

            def compute_output_shape(self, input_shape):
                # output shape: (BATCH_SIZE, EMBED_SIZE)
                return (input_shape[0], input_shape[-1])

            def get_config(self):
                return super(Attention, self).get_config()
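
        # In short: the layer scores each timestep with et = tanh(x.W + b),
        # mixes the scores across timesteps through u, applies a (masked)
        # softmax, and returns the attention-weighted sum over timesteps.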
        self.model = models.load_model(model_path,
                                       custom_objects={'precision': precision,
                                                       'recall': recall,
                                                       'f1_score': f1_score,
                                                       'Attention': Attention})
        log.info('init model done --')
        pk_path = os.path.abspath('pkl_csv')
        with open(pk_path + '/label_mapping210.pkl', 'rb') as f:
            self.label_map = pickle.load(f)   # class index -> label name
        log.info('load label_map done')
        with open(pk_path + '/bidi_classify_stop_words.csv', 'r', encoding='utf-8') as f:
            self.stopwords = [row[0] for row in csv.reader(f)]
        with open(pk_path + '/word_index_955871.pk', 'rb') as f:
            self.word_index = pickle.load(f)  # word -> vocabulary index
        with open(pk_path + '/class_subclass_dic211.pk', 'rb') as f:
            self.class_dic = pickle.load(f)   # label name -> 'subclass,class'
        log.info('class init done ----')

    def evaluate(self, text):
        # Strip whitespace and HTML tags (script/style blocks first)
        text = re.sub(r'\s', '', str(text))
        text = re.sub(r'<\s*script[^>]*>.*?<\s*/\s*script\s*>', '', text)
        text = re.sub(r'<\s*style[^>]*>.*?<\s*/\s*style\s*>', '', text)
        text = re.sub(r'</?\w+[^>]*>', '', text)
        # Remove noise characters (Latin letters, dates, digits, punctuation:
        # anything outside the CJK range) and keep the first 500 characters
        text = re.sub(r'\{.*font.*\}|\{.*Font.*\}|[^\u4e00-\u9fa5]', '', text)[:500]
        # Tokenize with HanLP
        result = self.hanlp.segment(text)
        text_list = [str(result.get(i)) for i in range(result.size())]
        # Filter stopwords and single-character tokens
        text_list = [word for word in text_list if word not in self.stopwords and len(word) > 1]
        # Deduplicate while preserving order
        l2 = list(dict.fromkeys(text_list))
        # Map words to vocabulary indices (0 for out-of-vocabulary words)
        text_list = [str(self.word_index.get(word, 0)) for word in l2]
        # Pad or truncate to a fixed length of 100, then convert to an array
        text_list = text_list[:100] if len(text_list) > 100 else text_list + ['0'] * (100 - len(text_list))
        features = np.array([text_list])
        log.info('numericalization done-------------------')
        with tf.get_default_graph().as_default():
            log.info('ready to predict-------------------')
            logits = self.model.predict(features)
            # return ','.join(logits[0])
            # result = self.label_map(np.argmax(logits[0]))
            # return result
            log.info('prediction done-------------------')
            # Top-3 class indices, their probabilities and label names
            top3 = np.argsort(-logits[0], axis=-1)[:3]
            prob = ['%.4f' % (logits[0][i]) for i in top3]
            pre = [self.label_map[i] for i in top3]
            rd = {}
            for i, a in enumerate(pre, 1):
                sub, father = self.class_dic[a].split(',')
                rd['top' + str(i)] = {'subclass': sub, 'class_name': a, 'class': father}
            log.info('ready to return the result string')
            return json.dumps(rd, ensure_ascii=False)
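

# Minimal local smoke test; a sketch only. Extractor expects the cached
# MaxCompute resources registered above, so this runs only where
# get_cache_file / get_cache_archive can resolve them. The sample HTML below
# is made up for illustration (the model classifies Chinese text, so the
# sample must contain CJK characters).
if __name__ == '__main__':
    extractor = Extractor()
    sample_html = '<p>招标公告:某单位办公楼施工项目公开招标</p>'
    print(extractor.evaluate(sample_html))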