'''
codeClassify.py
Created on 2019-04-12
@author: User
'''
import os
import sys
import codecs
import glob
import re

# The project root must be on sys.path before the BiddingKG imports resolve.
sys.path.append(os.path.abspath("../.."))

import numpy as np
import pandas as pd
from bs4 import BeautifulSoup
from keras import models
from keras import layers
from keras.preprocessing.sequence import pad_sequences
from keras_contrib.layers import CRF

import BiddingKG.dl.interface.Preprocessing as Preprocessing
from BiddingKG.dl.common.Utils import *  # provides load() and other helpers

def getArticle(article):
    # Flatten any tables in the HTML to plain text, then segment the article.
    return Preprocessing.segment(Preprocessing.tableToText(BeautifulSoup(article, "lxml")))
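
# For instance, getArticle() applied to the HTML of a bidding announcement
# flattens its tables to plain text and returns the segmented result.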

# Model for extracting project codes and project names
class CodeNamePredict():
    def __init__(self, EMBED_DIM=None, BiRNN_UNITS=None):
        self.model = None
        self.MAX_LEN = None
        self.EMBED_DIM = 60 if EMBED_DIM is None else EMBED_DIM
        self.BiRNN_UNITS = 40 if BiRNN_UNITS is None else BiRNN_UNITS
        self.filepath = "../projectCode/models/model_project_" + str(self.EMBED_DIM) + "_" + str(self.BiRNN_UNITS) + ".hdf5"
        vocabpath = "../projectCode/models/vocab.pk"
        classlabelspath = "../projectCode/models/classlabels.pk"
        self.vocab = load(vocabpath)
        self.class_labels = load(classlabelspath)
        # Build the regexes that extract codes (PC) and names (PN) from the
        # predicted label-id sequence (Begin, Middle+, End).
        id_PC_B = self.class_labels.index("PC_B")
        id_PC_M = self.class_labels.index("PC_M")
        id_PC_E = self.class_labels.index("PC_E")
        id_PN_B = self.class_labels.index("PN_B")
        id_PN_M = self.class_labels.index("PN_M")
        id_PN_E = self.class_labels.index("PN_E")
        self.PC_pattern = re.compile(str(id_PC_B) + str(id_PC_M) + "+" + str(id_PC_E))
        self.PN_pattern = re.compile(str(id_PN_B) + str(id_PN_M) + "+" + str(id_PN_E))
        self.word2index = dict((w, i) for i, w in enumerate(np.array(self.vocab)))
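        # Illustrative note (assuming, say, PC_B/PC_M/PC_E map to label ids
        # 2/3/4): an argmax prediction joined into the string "0233340" would
        # match PC_pattern on "23334", i.e. the span tagged Begin/Middle+/End.
        # The join in predict() assumes every label id is a single digit.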

    def getModel(self):
        '''
        @summary: get the code-and-name model, building it and loading
                  its trained weights on first use
        '''
        if self.model is None:
            self.model = self.getBiLSTMCRFModel(self.MAX_LEN, self.vocab, self.EMBED_DIM,
                                                self.BiRNN_UNITS, self.class_labels, weights=None)
            self.model.load_weights(self.filepath)
        return self.model

    def getBiLSTMCRFModel(self, MAX_LEN, vocab, EMBED_DIM, BiRNN_UNITS, chunk_tags, weights):
        '''
        Equivalent Sequential formulation, kept for reference:

        model = models.Sequential()
        model.add(layers.Embedding(len(vocab), EMBED_DIM, mask_zero=True))  # random embedding
        model.add(layers.Bidirectional(layers.LSTM(BiRNN_UNITS // 2, return_sequences=True)))
        crf = CRF(len(chunk_tags), sparse_target=True)
        model.add(crf)
        model.summary()
        model.compile('adam', loss=crf.loss_function, metrics=[crf.accuracy])
        return model
        '''
        inputs = layers.Input(shape=(None,))
        if weights is not None:
            embedding = layers.Embedding(len(vocab), EMBED_DIM, mask_zero=True,
                                         weights=[weights], trainable=True)(inputs)
        else:
            embedding = layers.Embedding(len(vocab), EMBED_DIM, mask_zero=True)(inputs)
        bilstm = layers.Bidirectional(layers.LSTM(BiRNN_UNITS // 2, return_sequences=True))(embedding)
        bilstm_dense = layers.TimeDistributed(layers.Dense(len(chunk_tags)))(bilstm)
        crf = CRF(len(chunk_tags), sparse_target=True)
        crf_out = crf(bilstm_dense)
        model = models.Model(inputs=[inputs], outputs=[crf_out])
        model.summary()
        model.compile(optimizer='adam', loss=crf.loss_function, metrics=[crf.accuracy])
        return model
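
    # Minimal usage sketch (hypothetical; x would be a (batch, seq_len) array
    # of word ids built the same way as in predict() below):
    #
    #   predictor = CodeNamePredict()
    #   model = predictor.getModel()            # builds the net, loads weights
    #   tags = np.argmax(model.predict(x), -1)  # per-character label ids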

    # Heuristically balance the brackets around an extracted code or name
    # (the dictionary maps each half-width/full-width bracket to its mate)
    def fitDataByRule(self, data):
        symbol_dict = {"(": ")",
                       "（": "）",
                       "[": "]",
                       "【": "】",
                       ")": "(",
                       "）": "（",
                       "]": "[",
                       "】": "【"}
        leftSymbol_pattern = re.compile("[\\(（\\[【]")
        rightSymbol_pattern = re.compile("[\\)）\\]】]")
        leftfinds = re.findall(leftSymbol_pattern, data)
        rightfinds = re.findall(rightSymbol_pattern, data)
        result = data
        if len(leftfinds) + len(rightfinds) == 0:
            return data
        elif len(leftfinds) == len(rightfinds):
            return data
        elif abs(len(leftfinds) - len(rightfinds)) == 1:
            if len(leftfinds) > len(rightfinds):
                # One unmatched left bracket: drop it if it leads the string,
                # otherwise close it on the right.
                if symbol_dict.get(data[0]) is not None:
                    result = data[1:]
                else:
                    result = data + symbol_dict.get(leftfinds[0])
            else:
                # One unmatched right bracket: drop it if it ends the string,
                # otherwise open it on the left.
                if symbol_dict.get(data[-1]) is not None:
                    result = data[:-1]
                else:
                    result = symbol_dict.get(rightfinds[0]) + data
        return result
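
    # For example: fitDataByRule("编号(ABC-123") -> "编号(ABC-123)" (unmatched
    # bracket closed), fitDataByRule("(ABC-123") -> "ABC-123" (leading bracket
    # dropped); balanced or bracket-free strings pass through unchanged.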

    def predict(self, articles, list_data, set_data, filename, MAX_LEN=None):
        sentences = []
        for article in articles:
            for sentence in article.split("。"):
                sentences.append([sentence, "1"])
        if MAX_LEN is None:
            sent_len = [len(sentence[0]) for sentence in sentences]
            MAX_LEN = max(sent_len)
        result = []
        # If every sentence is empty, return an empty item per article
        if MAX_LEN == 0:
            for article in articles:
                result.append(["1", {"code": [], "name": ""}])
            return result
        index_unk = self.word2index.get("<unk>")
        index_pad = self.word2index.get("<pad>")
        x = [[self.word2index.get(word, index_unk) for word in sentence[0]] for sentence in sentences]
        x = pad_sequences(x, maxlen=MAX_LEN, padding="post", truncating="post")
        predict_y = self.getModel().predict(x)
        last_doc_id = ""
        item = []
        for sentence, predict in zip(sentences, np.argmax(predict_y, -1)):
            pad_sentence = sentence[0][:MAX_LEN]
            doc_id = sentence[1]
            join_predict = "".join([str(s) for s in predict])
            if doc_id != last_doc_id:
                if last_doc_id != "":
                    result.append(item)
                item = [doc_id, {"code": [], "name": ""}]
                code_set = set()
            for match in re.finditer(self.PC_pattern, join_predict):
                # Keep up to 40 characters of context on each side of the code
                get_len = 40
                begin = 0 if match.span()[0] < get_len else match.span()[0] - get_len
                end = match.span()[1] + get_len
                context = (pad_sentence[begin:match.span()[0]],
                           pad_sentence[match.span()[0]:match.span()[1]],
                           pad_sentence[match.span()[1]:end])
                if context in set_data:
                    continue
                set_data.add(context)
                list_data.append([context[0], context[1], context[2], filename])
                the_code = self.fitDataByRule(pad_sentence[match.span()[0]:match.span()[1]])
                if the_code not in code_set:
                    code_set.add(the_code)
                item[1]['code'] = list(code_set)
            if item[1]['name'] == "":
                for match in re.finditer(self.PN_pattern, join_predict):
                    # Keep only the first project name found per document
                    item[1]['name'] = self.fitDataByRule(pad_sentence[match.span()[0]:match.span()[1]])
                    break
            last_doc_id = doc_id
        result.append(item)
        return result
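
    # predict() returns one entry per document id, e.g. (hypothetical values):
    #   [["1", {"code": ["ABC-2019-001"], "name": "某某采购项目"}]]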


def generateData():
    home = ["C:\\Users\\User\\Desktop\\20190306要素\\*.html", "C:\\Users\\User\\Desktop\\20190320要素\\*.html"]
    predictor = CodeNamePredict()
    list_data = []
    set_data = set()
    count = 0
    for h in home:
        for file in glob.glob(h):
            count += 1
            print("count", count)
            with codecs.open(file, "r", encoding="utf8") as f:
                content = f.read()
            filename = file.split("\\")[-1]
            predictor.predict([getArticle(content)], list_data, set_data, filename)
    list_before = []
    list_code = []
    list_after = []
    list_filename = []
    for item in list_data:
        list_filename.append(item[3])
        list_before.append(item[0])
        list_code.append(item[1])
        list_after.append(item[2])
    # selectWithRule() below expects a list_index column in code.xls, so write
    # an explicit running index alongside the extracted fields.
    df = pd.DataFrame({"list_index": list(range(len(list_data))),
                       "list_filename": list_filename, "list_before": list_before,
                       "list_code": list_code, "list_after": list_after})
    columns = ["list_index", "list_filename", "list_before", "list_code", "list_after"]
    df.to_excel("code.xls", index=False, columns=columns)
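
# Each row of code.xls pairs one extracted code with its left/right context
# and source file, e.g. (hypothetical):
#   [0, "0.html", "...招标编号为:", "ABC-2019-001", ",开标时间..."]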


def selectWithRule(source, filters, target):
    assert source != target
    dict_source = pd.read_excel(source)
    # Union of the indexes that already appear in the filter workbooks
    set_filter = set()
    for filt in filters:
        set_filter = set_filter | set(pd.read_excel(filt)["list_index"])
    list_index = []
    list_filename = []
    list_before = []
    list_text = []
    list_after = []
    for id, filename, before, text, after in zip(dict_source["list_index"], dict_source["list_filename"],
                                                 dict_source["list_before"], dict_source["list_code"],
                                                 dict_source["list_after"]):
        if id in set_filter:
            continue
        # Keep only rows whose left context contains an empty bracket pair "（）"
        if re.search("（）.{,3}", str(before)) is not None:
            list_index.append(id)
            list_filename.append(filename)
            list_before.append(before)
            list_text.append(text)
            list_after.append(after)
    data = {"list_index": list_index, "list_filename": list_filename, "list_before": list_before,
            "list_code": list_text, "list_after": list_after}
    columns = ["list_index", "list_filename", "list_before", "list_code", "list_after",
               "list_label", "list_prob"]
    df = pd.DataFrame(data)
    # Empty columns reserved for manual labelling
    df["list_label"] = ""
    df["list_prob"] = ""
    df.to_excel(target, index=False, columns=columns)
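
# Pipeline sketch: run generateData() first to produce code.xls, then
# selectWithRule() to drop rows whose list_index already appears in the
# filter workbooks (e.g. 批量.xls) and export the rest for manual labelling.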

if __name__ == "__main__":
    # generateData()
    selectWithRule("code.xls", ["批量.xls"], "剩余手工标注.xls")