data_util.py

# encoding=utf-8
import re
import pickle
import gensim
import numpy as np
import pandas as pd
from pyhanlp import *
import keras.backend as K
from keras.preprocessing.sequence import pad_sequences


def load(path):
    '''
    Load a pickled (.pk) file.
    '''
    with open(path, 'rb') as f:
        return pickle.load(f)


def get_remove_word():
    '''
    Load the stop words and other unimportant words.
    '''
    stopwords_path = 'pickle_1/bidi_classify_stop_words.csv'  # stop-word file
    df_stopwords = pd.read_csv(stopwords_path)
    remove_word = df_stopwords['stopword'].values.tolist()
    return remove_word


def get_embedding():
    '''
    Load the supporting files and return the word index, the trained Keras tokenizer
    and the word-embedding matrix.
    '''
    word_index = load('pickle_1/word_index_955871.pk')  # word-index dictionary (word: id)
    tokenizer = load('pickle_1/tokenizer_955871.pk')  # trained Keras tokenizer object
    w2v_model_path = 'model/thr_100_model.vector'  # word2vec vector file
    w2v_model = gensim.models.KeyedVectors.load_word2vec_format(w2v_model_path, binary=True)
    embedding_matrix = np.random.random((len(word_index) + 1, 100))
    count_not_in_model = 0
    count_in_model = 0
    for word, i in word_index.items():
        if word in w2v_model:
            count_in_model += 1
            embedding_matrix[i] = np.asarray(w2v_model[word], dtype='float32')
        else:
            count_not_in_model += 1
    return word_index, tokenizer, embedding_matrix
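

# A minimal sketch (not taken from this repo) of how the embedding_matrix returned above
# could back a Keras Embedding layer; the helper name below is hypothetical, and
# input_length=100 simply mirrors the maxlen used by pad_sequences further down.
def _example_embedding_layer():
    from keras.layers import Embedding
    word_index, tokenizer, embedding_matrix = get_embedding()
    # the pretrained word2vec rows are used as fixed weights
    return Embedding(input_dim=len(word_index) + 1,
                     output_dim=100,
                     weights=[embedding_matrix],
                     input_length=100,
                     trainable=False)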


def get_label():
    '''
    Load the label dictionary. Returns label_mapping, e.g. {0: '安防系统', 1: '安全保护服务',
    2: '安全保护设备', ...}, and labels10, the Chinese names of all categories.
    '''
    label_mapping = load('pickle_1/label_mapping_f.pk')
    labels10 = list(label_mapping.values())
    return label_mapping, labels10


def get_dic():
    '''
    Load the category dictionary, presumably mapping old categories to new ones, e.g.
    '豆类、油料和薯类种植': '农业,农、林、牧、渔业', '蔬菜、食用菌及园艺作物种植': '农业,农、林、牧、渔业'.
    '''
    dic_label_path = 'pickle_1/class_subclass_dic211.pk'
    dic_label = load(dic_label_path)
    return dic_label


def model_in(r1, label_mapping, id):
    '''
    Get the Chinese category name for each article.
    @Args: r1: np.array of prediction results; label_mapping: category dictionary,
           e.g. {0: '安防系统', ...}; id: list of article ids
    @Return: list of [article id, Chinese category name] pairs
    '''
    all_end = r1
    aa2 = []
    for i in range(all_end.shape[0]):
        c1 = label_mapping[np.argmax(all_end[i])]
        aa2.append(c1)
    union = []
    for x in range(len(id)):
        union.append([id[x], aa2[x]])
    return union


def convertJlistToPlist(jList):
    '''
    Convert a Java List (as returned by pyhanlp) to a Python list of strings.
    '''
    ret = []
    if jList is None:
        return ret
    for i in range(jList.size()):
        ret.append(str(jList.get(i)))
    return ret


def clean_RmWord(text, remove_word):
    '''
    Remove unwanted words (stop words etc.) from a token list.
    '''
    text_copy = text.copy()
    for i in text:
        if i in remove_word:
            text_copy.remove(i)
    text_copy = " ".join(text_copy)
    return text_copy


def handle_doc1(article_set10_1, remove_word):
    '''
    Segment the sentences and drop single characters, duplicates and irrelevant words.
    @Args: article_set10_1: Series of strings to process
    @Return: processed Series
    '''
    HanLP.Config = JClass('com.hankcs.hanlp.HanLP$Config')
    HanLP.Config.ShowTermNature = False
    article_set10_seg_1 = article_set10_1.map(lambda x: convertJlistToPlist(HanLP.segment(x)))
    article_set10_seg_1 = article_set10_seg_1.map(lambda x: ' '.join(word for word in x if len(word) > 1))  # drop single characters
    article_set10_seg_rm = article_set10_seg_1.map(lambda x: clean_RmWord(x.split(), remove_word))  # drop useless / duplicate words
    article_set10_seg_rm = article_set10_seg_rm.map(lambda x: x.split())
    return article_set10_seg_rm


def cleanSeg(text):
    '''
    Remove noise characters (English letters, dates, digits, punctuation).
    '''
    text = re.sub('[a-zA-Z]', '', text)  # English letters
    text = text.replace('\n', ' ')
    text = re.sub(r"-", " ", text)
    text = re.sub(r"\d+/\d/\d+", "", text)  # dates
    text = re.sub(r"[0-2]?[0-9]:[0-6][0-9]", "", text)  # times
    text = re.sub(r"[\w]+@[\.\w]+", "", text)  # e-mail addresses
    text = re.sub(r"(https?://|www\.)[A-Za-z0-9\./%&=\?\-_]+", "", text, flags=re.I)  # URLs
    pure_text = ''
    for letter in text:
        if letter.isalpha() or letter == ' ':  # keep only letters (CJK included) and spaces
            pure_text += letter
    text = ' '.join(word for word in pure_text.split() if len(word) > 1)  # drop single characters
    text = text.replace(' ', '')
    return text


def fetch_sub_data_1(data, num):
    '''
    Return the first num characters of the text.
    '''
    return data[:num]


def data_set(text):
    '''
    Deduplicate words while preserving their original order.
    '''
    l2 = []
    for i in text:
        if i not in l2:
            l2.append(i)
    return l2
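
# For example (illustrative input): data_set(['招标', '公告', '招标', '项目']) returns
# ['招标', '公告', '项目'] -- duplicates are dropped, and each word keeps its first position.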


def clean_word(article_set10, remove_word):
    """
    Clean the data: strip symbols, letters and digits, truncate articles to a uniform
    length, segment the sentences, and drop single characters, duplicates, irrelevant
    words and stop words.
    :param article_set10: raw data, list
    :param remove_word: stop-word list
    :return: Series
    """
    article_set10_1 = pd.Series(article_set10)
    article_set10_1 = article_set10_1.map(lambda x: cleanSeg(x))  # remove noise characters (English letters, dates, digits, punctuation)
    article_set10_1 = article_set10_1.map(lambda x: fetch_sub_data_1(x, 500))  # keep the first 500 characters
    article_set10_seg_rm = handle_doc1(article_set10_1, remove_word)  # segment and drop single characters, duplicates, irrelevant words
    x_train_df_10 = article_set10_seg_rm.copy()
    x_train_df_10 = x_train_df_10.map(lambda x: data_set(x))  # deduplicate words, keeping order
    return x_train_df_10


def clean_word_with_tokenizer(article_set10, remove_word, tokenizer):
    """
    Clean the data (strip symbols, letters, digits and stop words), segment, and turn
    the articles into padded index sequences with the given tokenizer.
    :param article_set10: raw data, list of (id, text) pairs
    :param remove_word: stop-word list
    :return: padded index sequences and the list of article ids
    """
    id = [i[0] for i in article_set10]
    article_set10 = [i[1] for i in article_set10]
    article_set10_1 = pd.Series(article_set10)
    article_set10_1 = article_set10_1.map(lambda x: cleanSeg(x))
    article_set10_1 = article_set10_1.map(lambda x: fetch_sub_data_1(x, 500))
    article_set10_seg_rm = handle_doc1(article_set10_1, remove_word)
    x_train_df_10 = article_set10_seg_rm.copy()
    sequences = tokenizer.texts_to_sequences(x_train_df_10)
    padded_sequences = pad_sequences(sequences, maxlen=100, padding='post', truncating='post', value=0.0)
    # left_word = [x[:-1] for x in padded_sequences]
    # right_word = [x[1:] for x in padded_sequences]
    # left_pad = pad_sequences(left_word, maxlen=100, value=0.0)
    # right_pad = pad_sequences(right_word, maxlen=100, padding='post', truncating='post', value=0.0)
    return padded_sequences, id
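

# A minimal end-to-end inference sketch showing how the helpers above fit together.
# The model path 'model/text_cnn.h5' and the _example_predict name are illustrative
# assumptions, not part of this file; custom_objects registers the metric functions
# defined below so a model saved with them can be deserialised.
def _example_predict(articles):
    '''articles: list of (id, text) pairs'''
    from keras.models import load_model
    remove_word = get_remove_word()  # stop words
    word_index, tokenizer, embedding_matrix = get_embedding()  # dictionary, tokenizer, vectors
    label_mapping, labels10 = get_label()  # id -> Chinese category name
    padded_sequences, ids = clean_word_with_tokenizer(articles, remove_word, tokenizer)
    model = load_model('model/text_cnn.h5',  # hypothetical model file
                       custom_objects={'precision': precision,
                                       'recall': recall,
                                       'f1_score': f1_score})
    predictions = model.predict(padded_sequences)
    return model_in(predictions, label_mapping, ids)  # [[id, category name], ...]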


def recall(y_true, y_pred):
    '''
    Compute recall.
    @Args:
        y_true: ground-truth labels
        y_pred: labels predicted by the model
    @Return:
        recall
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))  # true positives
    c3 = K.sum(K.round(K.clip(y_true, 0, 1)))  # actual positives
    recall = c1 / (c3 + K.epsilon())  # epsilon avoids division by zero when there are no positives
    return recall


def f1_score(y_true, y_pred):
    '''
    Compute the F1 score.
    @Args:
        y_true: ground-truth labels
        y_pred: labels predicted by the model
    @Return:
        F1 score
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))  # true positives
    c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))  # predicted positives
    c3 = K.sum(K.round(K.clip(y_true, 0, 1)))  # actual positives
    precision = c1 / (c2 + K.epsilon())
    recall = c1 / (c3 + K.epsilon())
    f1_score = 2 * (precision * recall) / (precision + recall + K.epsilon())
    return f1_score


def precision(y_true, y_pred):
    '''
    Compute precision.
    @Args:
        y_true: ground-truth labels
        y_pred: labels predicted by the model
    @Return:
        precision
    '''
    c1 = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))  # true positives
    c2 = K.sum(K.round(K.clip(y_pred, 0, 1)))  # predicted positives
    precision = c1 / (c2 + K.epsilon())  # epsilon avoids division by zero
    return precision
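

# A minimal sketch of how these metrics could be attached to a model at compile time
# (the optimizer and loss below are illustrative assumptions, not taken from this file).
def _example_compile(model):
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=[precision, recall, f1_score])
    return model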


if __name__ == '__main__':
    dic_label = get_dic()
    print(dic_label)