@@ -0,0 +1,285 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+# @Author : bidikeji
+# @Time : 2020/4/24 15:20
+
+# evaluate is the entry-point method of this UDF and must keep exactly this name
+from odps.udf import annotate
+from odps.distcache import get_cache_archive
+from odps.distcache import get_cache_file
+import time
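+
+# NOTE: tensorflow / keras / numpy are not imported here; they ship in a cached
+# archive that is only unpacked inside Extractor.__init__ below, which binds K
+# (keras.backend) as a global for the metric functions that follow.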
+
+
+def recall(y_true, y_pred):
+    '''
+    Compute recall.
+
+    @Args:
+        y_true: ground-truth labels
+        y_pred: labels predicted by the model
+
+    @Return
+        recall
+    '''
+    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
+    c3 = K.sum(K.round(K.clip(y_true, 0, 1)))
+    # y_true/y_pred are tensors here, so a Python `if c3 == 0:` check cannot
+    # work in graph mode; guard the division with K.epsilon() instead
+    recall = c1 / (c3 + K.epsilon())
+    return recall
+
+
+def f1_score(y_true, y_pred):
+    '''
+    Compute the F1 score.
+
+    @Args:
+        y_true: ground-truth labels
+        y_pred: labels predicted by the model
+
+    @Return
+        F1 score
+    '''
+    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
+    c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))
+    c3 = K.sum(K.round(K.clip(y_true, 0, 1)))
+    # K.epsilon() guards every division against empty prediction/label sets
+    precision = c1 / (c2 + K.epsilon())
+    recall = c1 / (c3 + K.epsilon())
+    f1_score = 2 * (precision * recall) / (precision + recall + K.epsilon())
+    return f1_score
+
+
+def precision(y_true, y_pred):
+    '''
+    Compute precision.
+
+    @Args:
+        y_true: ground-truth labels
+        y_pred: labels predicted by the model
+
+    @Return
+        precision
+    '''
+    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
+    c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))
+    precision = c1 / (c2 + K.epsilon())
+    return precision
+
+
+# Put the cached Python dependency archive (pandas etc.) on sys.path
+def include_package_path(res_name):
+    import os, sys
+    archive_files = get_cache_archive(res_name)
+    dir_names = sorted([os.path.dirname(os.path.normpath(f.name)) for f in archive_files
+                        if '.dist_info' not in f.name], key=lambda v: len(v))
+    sys.path.append(dir_names[0])
+    return os.path.dirname(dir_names[0])
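+# Example (mirrors the call in Extractor.__init__ below):
+#     include_package_path("envs_py37.env.zip")
+# puts the shortest extracted directory on sys.path so the bundled
+# site-packages (tensorflow, keras, numpy, ...) become importable.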
+
+
+# Initialize the business data packages. Because of upload size limits and
+# inconsistencies in Python versions and archive extraction, they have to be
+# imported manually.
+def init_env(list_files, package_name):
+    import os, sys
+
+    if len(list_files) == 1:
+        so_file = get_cache_file(list_files[0])
+        cmd_line = os.path.abspath(so_file.name)
+        os.system("unzip %s -d %s" % (cmd_line, package_name))
+    elif len(list_files) > 1:
+        # concatenate the split archive parts back into one zip, then extract
+        cmd_line = "cat"
+        for _file in list_files:
+            so_file = get_cache_file(_file)
+            cmd_line += " " + os.path.abspath(so_file.name)
+        cmd_line += " > temp.zip"
+        os.system(cmd_line)
+        os.system("unzip temp.zip -d %s" % (package_name))
+    sys.path.append(os.path.abspath(package_name))
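+# Example (mirrors the calls in Extractor.__init__ below):
+#     init_env(['pyhanlp.z01', 'pyhanlp.z02'], 'pyhanlp')
+# concatenates the two cached parts into temp.zip, extracts them into
+# ./pyhanlp, and appends that directory to sys.path.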
+
+
+# Main UDF class; ODPS calls Extractor.evaluate() on each input row
+@annotate("string->string")
+class Extractor(object):
+    def __init__(self):
+        # declare log global *before* assigning it (assigning first and
+        # declaring global afterwards is a SyntaxError in Python 3)
+        global log
+        import logging as log
+        import os
+        log.basicConfig(level=log.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+        logger = log.getLogger(__name__)
+
+        model_path = os.path.abspath(get_cache_file('model_changemedium_acc90.model').name)  # attentiongruacc0.932.model replaced by New_attentionGUR_embed100_newlabel_20201020.h5, 2020-10-23
+        log.info('model_path:%s' % model_path)
+        log.info('model file exists: %s' % os.path.exists(model_path))
+
+        # init_env(['pyhanlp.z01', 'pyhanlp.z02', 'pyhanlp.z03', 'pyhanlp.z04'], 'pyhanlp')
+        start_time = time.time()
+        init_env(['pyhanlp.z01', 'pyhanlp.z02'], 'pyhanlp')
+        log.info("init pyhanlp takes:%d" % (time.time() - start_time))
+        start_time = time.time()
+        # init_env(['envs_py37.zip.env'], 'envs_py37')
+        include_package_path("envs_py37.env.zip")
+        log.info("init envs_py37 takes:%d" % (time.time() - start_time))
+        start_time = time.time()
+        init_env(['so.env'], '.')
+        init_env(['pkl_csv.z01'], '.')
+        log.info("init pkl_csv takes:%d" % (time.time() - start_time))
+        start_time = time.time()
+        # declare the shared names global before importing them (see the note
+        # on log above); evaluate() and the module-level metric functions rely
+        # on these bindings
+        global json, re, np, tf, K
+        import pickle
+        import csv
+        import re
+        import tensorflow as tf
+        import numpy as np
+        import keras.backend as K
+        from keras import models
+        from keras.engine.topology import Layer
+        import json
+
+        log.info('import package done------------------')
+        # dirpath = os.path.abspath('pyhanlp')
+        # path = dirpath+'/pyhanlp/static/__init__.py' # return dirpath
+        # dirpath = os.path.dirname(os.path.abspath(get_cache_file('pyhanlp.z01').name))
+        # return '; '.join([a for a in os.listdir(os.listdir(dirpath)[0])])
+        # path2 = os.path.abspath(get_cache_file('hanlpinit.txt').name)
+        # content = []
+        # with open(path2, encoding='utf-8') as f:
+        #     for line in f:
+        #         content.append(line)
+        # # return '; '.join(content)
+        # with open(path, 'w', encoding='utf-8') as f:
+        #     f.writelines(content)
+        # log.info('rewrite hanlp path done--------------------')
+        # archive_files = get_cache_archive('token_stopwds.zip')
+        # names = [os.path.dirname(os.path.normpath(f.name)) for f in archive_files]
+        # with open(names[0]+'/bidi_classify_stop_words.csv', 'r', encoding='utf-8') as f:
+        #     self.stopwords = [row[0] for row in csv.reader(f)]
+        # with open(names[0]+'/word_index_955871.pk', 'rb') as f:
+        #     self.word_index = pickle.load(f)
+
+        from pyhanlp import HanLP, JClass
+        HanLP.Config = JClass('com.hankcs.hanlp.HanLP$Config')
+        HanLP.Config.ShowTermNature = False  # segment() returns plain terms without part-of-speech tags
+        self.hanlp = HanLP
+        log.info('import hanlp done---------------------')
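+
+        # Attention is defined inside __init__ because keras (and with it
+        # keras.engine.topology.Layer) only becomes importable after the cached
+        # environment has been put on sys.path earlier in this constructor.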
+        class Attention(Layer):
+            log.info('******attention****************')
+            print('-------attention------------------')
+
+            def __init__(self, **kwargs):
+                super(Attention, self).__init__(**kwargs)
+
+            def build(self, input_shape):
+                # W: (EMBED_SIZE, 1)
+                # b: (MAX_TIMESTEPS, 1)
+                # u: (MAX_TIMESTEPS, MAX_TIMESTEPS)
+                self.W = self.add_weight(name="W_{:s}".format(self.name),
+                                         shape=(input_shape[-1], 1),
+                                         initializer="normal")
+                self.b = self.add_weight(name="b_{:s}".format(self.name),
+                                         shape=(input_shape[1], 1),
+                                         initializer="zeros")
+                self.u = self.add_weight(name="u_{:s}".format(self.name),
+                                         shape=(input_shape[1], input_shape[1]),
+                                         initializer="normal")
+                super(Attention, self).build(input_shape)
+
+            def call(self, x, mask=None):
+                # input: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
+                # et: (BATCH_SIZE, MAX_TIMESTEPS)
+                et = K.squeeze(K.tanh(K.dot(x, self.W) + self.b), axis=-1)
+                # at: (BATCH_SIZE, MAX_TIMESTEPS)
+                at = K.dot(et, self.u)
+                at = K.exp(at)
+                if mask is not None:
+                    at *= K.cast(mask, K.floatx())
+                at /= K.cast(K.sum(at, axis=1, keepdims=True) + K.epsilon(), K.floatx())
+                # ot: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
+                atx = K.expand_dims(at, axis=-1)
+                ot = atx * x
+                # output: (BATCH_SIZE, EMBED_SIZE)
+                return K.sum(ot, axis=1)
+
+            def compute_mask(self, input, input_mask=None):
+                # do not pass the mask to the next layers
+                return None
+
+            def compute_output_shape(self, input_shape):
+                # output shape: (BATCH_SIZE, EMBED_SIZE)
+                return (input_shape[0], input_shape[-1])
+
+            def get_config(self):
+                return super(Attention, self).get_config()
+
+        # every custom metric and the custom layer used at training time must
+        # be registered via custom_objects, otherwise deserialization fails
+        self.model = models.load_model(model_path,
+                                       custom_objects={'precision': precision,
+                                                       'recall': recall,
+                                                       'f1_score': f1_score,
+                                                       'Attention': Attention})
+        log.info('init model end --')
+
+        pk_path = os.path.abspath('pkl_csv')
+        with open(pk_path + '/id2label.pkl', 'rb') as f:  # '/label_mapping210.pkl' replaced by id2label.pkl, 2020-10-23
+            self.label_map = pickle.load(f)
+        print('load label_map done')
+        with open(pk_path + '/bidi_classify_stop_words.csv', 'r', encoding='utf-8') as f:
+            self.stopwords = [row[0] for row in csv.reader(f)]
+        with open(pk_path + '/word_index_955871.pk', 'rb') as f:
+            self.word_index = pickle.load(f)
+        with open(pk_path + '/class2dalei_menlei.pkl', 'rb') as f:  # class_subclass_dic211.pk replaced by class2dalei_menlei.pkl, 2020-10-23
+            self.class_dic = pickle.load(f)
+        log.info('class init done ----')
+
+    def evaluate(self, text):
+        # strip whitespace and HTML tags
+        text = re.sub(r'\s', '', str(text))
+        text = re.sub(r'<\s*script[^>]*>.*?<\s*/\s*script\s*>', '', text)
+        text = re.sub(r'<\s*style[^>]*>.*?<\s*/\s*style\s*>', '', text)
+        text = re.sub(r'</?\w+[^>]*>', '', text)
+        # drop noise characters (Latin letters, dates, digits, punctuation) and keep the first 500 characters
+        text = re.sub(r'\{.*font.*\}|\{.*Font.*\}|[^\u4e00-\u9fa5]', '', text)[:500]
+        # HanLP word segmentation
+        result = self.hanlp.segment(text)
+        text_list = [str(result.get(i)) for i in range(result.size())]
+        # stop-word filtering, disabled 2020-10-23
+        # text_list = [word for word in text_list if word not in self.stopwords and len(word) > 1]
+        # order-preserving deduplication, disabled 2020-10-23
+        # l2 = []
+        # [l2.append(i) for i in text_list if i not in l2]
+        # map words to vocabulary ids
+        text_list = [str(self.word_index.get(word, 0)) for word in text_list]  # l2 replaced by text_list, 2020-10-23
+        # pad or truncate to 150 tokens (changed from 100, 2020-10-23) and wrap as a batch of one
+        text_list = text_list[:150] if len(text_list) > 150 else text_list + ['0'] * (150 - len(text_list))
+        features = np.array([text_list])
+        log.info('vectorization done-------------------')
+        # features = np.array([s.split(',')[:100] if len(s.split(','))>100 else s.split(',')+[0]*(100-len(s.split(',')))])
+        with tf.get_default_graph().as_default():
+            log.info('about to predict-------------------')
+            logits = self.model.predict(features)
+            # return ','.join(logits[0])
+            # result = self.label_map(np.argmax(logits[0]))
+            # return result
+            log.info('prediction done-------------------')
+            # indices of the three highest-scoring classes
+            top3 = np.argsort(-logits[0], axis=-1)[:3]
+            prob = ['%.4f' % (logits[0][i]) for i in top3]  # currently unused in the returned JSON
+            pre = [self.label_map[i] for i in top3]
+            rd = {}
+            for i, a in enumerate(pre, start=1):
+                sub, father = self.class_dic[a].split(',')
+                rd['top' + str(i)] = {'subclass': sub, 'class_name': a, 'class': father}
+
+            log.info('about to return the result string')
+            return json.dumps(rd, ensure_ascii=False)
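+
+# Minimal local smoke test (hypothetical; assumes the cached model, the pyhanlp
+# archives and the pkl_csv resources are all resolvable via odps.distcache):
+#     ex = Extractor()
+#     print(ex.evaluate('<p>某单位设备采购招标公告……</p>'))
+# which returns a JSON string shaped like
+#     {"top1": {"subclass": "...", "class_name": "...", "class": "..."}, ...}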