ProjectCodeNameRecognition.py

import sys
import os
sys.path.append(os.path.abspath("../../.."))
import codecs
import gensim
import numpy as np
from keras import models
from keras import layers
from keras.callbacks import ModelCheckpoint
import pickle
from projectLabel import *
import re
from BiddingKG.dl.common.Connection import getConnection
from models import *
import tensorflow as tf

# Do not use GPU acceleration
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""


def save(object_to_save, path):
    '''
    Pickle an object to disk.
    @Args:
        object_to_save: the object to save
        path: file path to write to
    '''
    with open(path, 'wb') as f:
        pickle.dump(object_to_save, f)


def load(path):
    '''
    Load a pickled object from disk.
    @Args:
        path: file path to read from
    @Return:
        the loaded object
    '''
    with open(path, 'rb') as f:
        loaded_object = pickle.load(f)
        return loaded_object
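

# A minimal usage sketch (not part of the original pipeline) of the save()/load()
# helpers above; the file name "demo_vocab.pk" is made up for illustration.
def _demo_save_load():
    demo_vocab = ["<pad>", "<unk>", "项", "目"]
    save(demo_vocab, "demo_vocab.pk")             # pickle the list to disk
    assert load("demo_vocab.pk") == demo_vocab    # read it back unchanged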


def EmbeddingVocabAndMatrix(Embedding_size=60):
    '''
    @summary: build the character vocabulary and embedding matrix from the pretrained
              character-level word2vec model
    '''
    singlew2v_file = "../../dl/singlew2v_model.vector"
    singlew2v_model = gensim.models.KeyedVectors.load_word2vec_format(singlew2v_file, binary=True)
    vocab = ['<pad>', '<unk>'] + singlew2v_model.index2word
    embedding_matrix = np.zeros((len(vocab), Embedding_size))
    # rows 0 and 1 (<pad>, <unk>) stay zero; the rest are copied from the word2vec model
    for i in range(2, len(vocab)):
        embedding_matrix[i] = singlew2v_model[vocab[i]]
    return vocab, embedding_matrix
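

# Hedged sketch (an assumption, not the project's actual model code) of how the
# (vocab, embedding_matrix) pair is typically consumed: the matrix initializes a Keras
# Embedding layer whose inputs are indices into vocab. getBiLSTMCRFModel, imported from
# models, presumably does something equivalent when weights is not None.
def _demo_embedding_layer(vocab, embedding_matrix, max_len=300):
    return layers.Embedding(input_dim=len(vocab),
                            output_dim=embedding_matrix.shape[1],
                            weights=[embedding_matrix],   # pretrained character vectors
                            input_length=max_len,
                            mask_zero=True,               # index 0 is <pad>
                            trainable=True)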


from sklearn.metrics import accuracy_score


def getAcc(y_batch, logits, trans, lengths):
    '''
    Token-level accuracy of the CRF Viterbi decoding against the gold tag sequences.
    '''
    index = 0
    small = -1000.0
    start = np.asarray([[small] * 7 + [0]])
    preds = []
    true_tags = []
    for score, length in zip(logits, lengths):
        score = score[:length]
        # pad = small * np.ones([length, 1])
        # logit = np.concatenate([score, pad], axis=1)
        # logit = np.concatenate([start, logit], axis=0)
        # path, _ = tf.contrib.crf.viterbi_decode(logit, trans)
        path, _ = viterbi_decode(score, trans)
        preds += path[0:]
        # preds += path[1:]
        index += 1
    for y, length in zip(y_batch, lengths):
        y = y.tolist()
        true_tags += y[:length]
    acc = accuracy_score(np.reshape(true_tags, (-1)), np.reshape(preds, (-1)))
    return acc


from BiddingKG.dl.common.Utils import viterbi_decode


def decode(logits, trans, sequence_lengths, tag_num):
    '''
    Viterbi-decode a batch of emission scores with the learned CRF transition matrix.
    '''
    viterbi_sequences = []
    small = -1000.0
    start = np.asarray([[small] * tag_num + [0]])
    for logit, length in zip(logits, sequence_lengths):
        score = logit[:length]
        # pad = small * np.ones([length, 1])
        # score = np.concatenate([score, pad], axis=1)
        # score = np.concatenate([start, score], axis=0)
        viterbi_seq, viterbi_score = viterbi_decode(score, trans)
        # the start-token padding above is commented out, so keep the whole path
        # (mirrors preds += path[0:] in getAcc) instead of dropping the first tag
        viterbi_sequences.append(viterbi_seq)
    return viterbi_sequences
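

# Hedged reference sketch of what viterbi_decode (imported above from
# BiddingKG.dl.common.Utils) is assumed to compute, following the classic
# tf.contrib.crf.viterbi_decode recurrence: the highest-scoring tag path given per-token
# emission scores of shape [seq_len, num_tags] and a [num_tags, num_tags] transition
# matrix. Illustration only; it is not called by this module.
def _viterbi_decode_sketch(score, trans):
    trellis = np.zeros_like(score)
    backpointers = np.zeros_like(score, dtype=np.int32)
    trellis[0] = score[0]
    for t in range(1, score.shape[0]):
        v = np.expand_dims(trellis[t - 1], 1) + trans     # score of every tag -> tag move
        trellis[t] = score[t] + np.max(v, 0)
        backpointers[t] = np.argmax(v, 0)
    best_path = [int(np.argmax(trellis[-1]))]
    for bp in reversed(backpointers[1:]):                 # follow backpointers to recover the path
        best_path.append(int(bp[best_path[-1]]))
    best_path.reverse()
    return best_path, float(np.max(trellis[-1]))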


def training():
    '''
    Train/evaluate the BiLSTM-CRF model for project code and project name recognition.
    '''
    MAX_LEN = 300
    train = True
    EMBED_DIM = 60
    BiRNN_UNITS = 128
    filepath = "model_project_" + str(EMBED_DIM) + "_" + str(BiRNN_UNITS) + ".hdf5"
    vocabpath = "vocab.pk"
    classlabelspath = "classlabels.pk"
    usersinglew2v = True
    if usersinglew2v:
        singlew2v_vocab, singlew2v_matrix = EmbeddingVocabAndMatrix()
    else:
        singlew2v_vocab, singlew2v_matrix = None, None
    if train:
        print("training:")
        (train_x, train_y, train_len), (test_x, test_y, test_len), (vocab, class_labels), test = generateDatas(
            MAX_LEN=MAX_LEN, vocab_set=singlew2v_vocab)
        save(vocab, vocabpath)
        save(class_labels, classlabelspath)
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options), graph=tf.Graph())
        with sess:
            char_input, logits, target, length, keepprob, crf_loss, trans, train_op = getBilstmCRF_tf(
                sess, MAX_LEN, vocab, EMBED_DIM, BiRNN_UNITS, class_labels, weights=singlew2v_matrix)
            sess.run(tf.global_variables_initializer())
            epochs = 300
            saver = tf.train.Saver(max_to_keep=epochs)
            batch_size = 400
            saver.restore(sess, "model/113-0.705722005308-6.94006/model.ckpt")
            print(sess.run(trans))
            # _logits, _trans = sess.run([logits, trans], feed_dict={char_input: test_x, target: test_y, length: test_len, keepprob: 1.0})
            # viterbi_sequence = decode(logits=_logits, trans=_trans, sequence_lengths=test_len, tag_num=7)
            # for _seq in viterbi_sequence:
            #     print(_seq)
            # TensorFlow training loop (kept commented out):
            # for _i in range(epochs):
            #     size_train_x = np.shape(train_x)[0]
            #     _batch = 0
            #     while _batch < size_train_x:
            #         _x = train_x[_batch:_batch + batch_size]
            #         _y = train_y[_batch:_batch + batch_size]
            #         _length = np.array(train_len[_batch:_batch + batch_size])
            #         _batch = _batch + batch_size
            #         _loss, _ = sess.run([crf_loss, train_op], feed_dict={char_input: _x, target: _y, length: _length, keepprob: 0.85})
            #         # print(_i, _loss)
            #     _logits, _trans = sess.run([logits, trans], feed_dict={char_input: test_x, target: test_y, length: test_len, keepprob: 1.0})
            #     print(np.shape(_logits), np.shape(test_y))
            #     acc = getAcc(test_y, _logits, _trans, test_len)
            #     print("---test", "epochs:", _i, acc)
            #     print("============", "epochs:", str(_i), "loss:", str(_loss) + "done")
            #     saver.save(sess, 'model/' + str(_i) + "-" + str(acc) + "-" + str(_loss) + '/model.ckpt')
            # load the Keras BiLSTM-CRF weights and print a few test predictions
            model = getBiLSTMCRFModel(MAX_LEN, vocab, EMBED_DIM, BiRNN_UNITS, class_labels, weights=singlew2v_matrix)
            filepath = "log/ep300-acc0.923-loss-0.443-val_loss-0.404-f1_score0.916.h5"
            model.load_weights(filepath)
            print(tf.trainable_variables())
            print(sess.run(sess.graph.get_tensor_by_name("crf_1/chain_kernel:0")))
            for item in np.argmax(model.predict(test_x), -1)[:20]:
                print(item)
            return
            # The Keras training path below is unreachable while the early return above is in place.
            checkpoint = ModelCheckpoint(
                filepath="log/" + "ep{epoch:03d}-acc{acc:.3f}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1_score{val_acc:.3f}.h5",
                monitor="val_loss", verbose=1, save_best_only=False, save_weights_only=True)
            model.fit(train_x, np.expand_dims(train_y, 2), epochs=300, batch_size=400,
                      validation_data=[test_x, np.expand_dims(test_y, 2)],
                      callbacks=[checkpoint])
    else:
        print("predicting")
        singlew2v_vocab = load(vocabpath)
        class_labels = load(classlabelspath)
        singlew2v_matrix = None
        (train_x, train_y), (test_x, test_y), (vocab, class_labels), test = generateDatas(
            MAX_LEN=MAX_LEN, vocab_set=singlew2v_vocab)
        model = getBiLSTMCRFModel(MAX_LEN, vocab, EMBED_DIM, BiRNN_UNITS, class_labels, weights=singlew2v_matrix)
        model.load_weights(filepath)
        # # print(train_x[0], train_y[0])
        # predict_y = model.predict(test_x)
        # predict_y1 = np.argmax(predict_y, -1)
        # print(predict_y1[0])
        # with codecs.open("predict_test.txt", "w", encoding="utf8") as f:
        #     for t in range(len(test)):
        #         content = test[t]
        #         if len(content) > MAX_LEN:
        #             range_len = MAX_LEN
        #         else:
        #             range_len = len(content)
        #         for h in range(range_len):
        #             f.write(str(test[t][h][0]) + " " + str(test[t][h][1]) + " " + str(class_labels[predict_y1[t][h]]))
        #             f.write("\n")
        #         f.write("\n")
        #         f.flush()


def fitDataByRule(data):
    '''
    @summary: complete (or strip) unbalanced brackets around a predicted code or name by rule
    '''
    symbol_dict = {"（": "）",
                   "(": ")",
                   "[": "]",
                   "【": "】",
                   "）": "（",
                   ")": "(",
                   "]": "[",
                   "】": "【"}
    leftSymbol_pattern = re.compile("[\(（\[【]")
    rightSymbol_pattern = re.compile("[\)）\]】]")
    leftfinds = re.findall(leftSymbol_pattern, data)
    rightfinds = re.findall(rightSymbol_pattern, data)
    result = data
    if len(leftfinds) + len(rightfinds) == 0:
        return data
    elif len(leftfinds) == len(rightfinds):
        return data
    elif abs(len(leftfinds) - len(rightfinds)) == 1:
        if len(leftfinds) > len(rightfinds):
            # one extra opening bracket: strip it if it leads the string, otherwise append its closer
            if symbol_dict.get(data[0]) is not None:
                result = data[1:]
            else:
                print(symbol_dict.get(leftfinds[0]))
                result = data + symbol_dict.get(leftfinds[0])
        else:
            # one extra closing bracket: strip it if it ends the string, otherwise prepend its opener
            if symbol_dict.get(data[-1]) is not None:
                result = data[:-1]
            else:
                result = symbol_dict.get(rightfinds[0]) + data
    return result
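

# Hedged usage examples (made-up strings, not from the dataset) of the bracket rule above;
# not called by the pipeline.
def _demo_fitDataByRule():
    assert fitDataByRule("(第一标段") == "第一标段"         # a leading unmatched bracket is stripped
    assert fitDataByRule("第一标段(甲") == "第一标段(甲)"   # otherwise the matching closer is appended
    assert fitDataByRule("[2018]12号") == "[2018]12号"      # balanced brackets are returned unchanged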


def predicting(articles, MAX_LEN=None):
    '''
    @summary: predict the project code and project name contained in each article
    @param:
        articles: list of articles
    '''
    print("predicting")
    # ckpt_file = "codename_savedmodel"
    ckpt_file = "codename_savedmodel_bilstmcrf"
    sess = tf.Session(graph=tf.Graph())
    with sess.as_default():
        # load the exported SavedModel and resolve its input/output tensors
        meta_graph = tf.saved_model.loader.load(sess, tags=["serve"], export_dir=ckpt_file)
        signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        signature_def = meta_graph.signature_def
        inputs = sess.graph.get_tensor_by_name(signature_def[signature_key].inputs["inputs"].name)
        outputs = sess.graph.get_tensor_by_name(signature_def[signature_key].outputs["outputs"].name)
        vocabpath = "models/vocab.pk"
        classlabelspath = "models/classlabels.pk"
        vocab = load(vocabpath)
        class_labels = load(classlabelspath)
        # split every article into sentences; each entry keeps (sentence, doc_id, code, name)
        sentences = []
        for article in articles:
            # debug filter: only process this single article
            if article[0] != "33ee0f51-7a03-11e8-a4b1-44a84246dbba":
                continue
            for sentence in re.split("[，;。!\n]", article[1]):
                print(sentence)
                sentences.append([sentence, article[0], article[2] if article[2] is not None else "",
                                  article[3] if article[3] is not None else ""])
        if MAX_LEN is None:
            sent_len = [len(sentence[0]) for sentence in sentences]
            MAX_LEN = max(sent_len)
        print(MAX_LEN)
        # map characters to vocabulary indices and pad to MAX_LEN
        word2index = dict((w, i) for i, w in enumerate(np.array(vocab)))
        index_unk = word2index.get("<unk>")
        index_pad = word2index.get("<pad>")
        x = [[word2index.get(word, index_unk) for word in sentence[0]] for sentence in sentences]
        x = pad_sequences(x, maxlen=MAX_LEN, padding="post", truncating="post")
        # MAX_LEN is reset here to the training-time value expected by the Keras model below
        MAX_LEN = 300
        train = True
        EMBED_DIM = 60
        BiRNN_UNITS = 128
        model = getBiLSTMCRFModel(MAX_LEN, vocab, EMBED_DIM, BiRNN_UNITS, class_labels, weights=None)
        model.load_weights("log/ep024-acc0.994-loss0.016-val_loss0.022-f1_score0.992.h5")
        # predict_y = limitRun(sess, [outputs], feed_dict={inputs: x}, MAX_BATCH=1)[0]
        predict_y = model.predict(x)
        # label ids for the B/M/E tags of project code (PC) and project name (PN)
        id_PC_B = class_labels.index("PC_B")
        id_PC_M = class_labels.index("PC_M")
        id_PC_E = class_labels.index("PC_E")
        id_PN_B = class_labels.index("PN_B")
        id_PN_M = class_labels.index("PN_M")
        id_PN_E = class_labels.index("PN_E")
        # a span is recognized where the predicted label-id string matches B M+ E
        PC_pattern = re.compile(str(id_PC_B) + str(id_PC_M) + "+" + str(id_PC_E))
        PN_pattern = re.compile(str(id_PN_B) + str(id_PN_M) + "+" + str(id_PN_E))
        result = []
        last_doc_id = ""
        item = []
        for sentence, predict in zip(sentences, np.argmax(predict_y, -1)):
            pad_sentence = sentence[0][:MAX_LEN]
            doc_id = sentence[1]
            join_predict = "".join([str(s) for s in predict])
            if doc_id != last_doc_id:
                if last_doc_id != "":
                    result.append(item)
                item = [doc_id, set(), set(), sentence[2], sentence[3]]
            for iter in re.finditer(PC_pattern, join_predict):
                item[1].add(fitDataByRule(pad_sentence[iter.span()[0]:iter.span()[1]]))
            for iter in re.finditer(PN_pattern, join_predict):
                # item[2] = item[2] + ";" + pad_sentence[iter.span()[0]:iter.span()[1]]
                item[2].add(fitDataByRule(pad_sentence[iter.span()[0]:iter.span()[1]]))
            last_doc_id = doc_id
        result.append(item)
        print(result[0])
        # dump the per-document predictions next to the labeled code/name as an HTML table
        with codecs.open("predictPCPN.html", "w", encoding="utf8") as f:
            f.write('<html><head>'
                    '<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">'
                    '</head>'
                    '<body bgcolor="#FFFFFF">'
                    '<table border="1">'
                    '<tr>'
                    '<td>doc_id</td>'
                    '<td>predicted code</td>'
                    '<td>predicted name</td>'
                    '<td>labeled code</td>'
                    '<td>labeled name</td>'
                    '</tr>')
            for item in result:
                f.write("<tr>" + "<td>" + item[0] + "</td>" + "<td>" + str(item[1]) + "</td>" + "<td>" + str(
                    item[2]) + "</td>" + "<td>" + item[3] + "</td>" + "<td>" + item[4] + "</td>" + "</tr>")
            f.write("</table></body>")


def getPredictArticles():
    conn = getConnection()
    cursor = conn.cursor()
    # sql = " select id,content,project_compare.projectcode,project_compare.projectname from articles_processed left join project_compare on id=doc_id where id in(select distinct A.doc_id from entity_mention A,test_predict_role B where A.entity_id=B.entity_id limit 200) order by id"
    sql = " select id,content,code,name from articles_processed A,articles_validation B where A.id=B.doc_id "
    cursor.execute(sql)
    rows = cursor.fetchall()
    return rows


def get_savedmodel():
    '''
    Load trained Keras weights, inspect the CRF transition parameters, and (commented out
    below) export the model as a TensorFlow SavedModel.
    '''
    MAX_LEN = 300
    EMBED_DIM = 60
    BiRNN_UNITS = 128
    vocabpath = "models/vocab.pk"
    classlabelspath = "models/classlabels.pk"
    vocab = load(vocabpath)
    class_labels = load(classlabelspath)
    with tf.Session(graph=tf.Graph()).as_default() as sess:
        with sess.graph.as_default():
            model = getBiLSTMCRFModel(MAX_LEN, vocab, EMBED_DIM, BiRNN_UNITS, class_labels, weights=None)
            filepath = "log/ep034-acc0.956-loss0.112-val_loss0.112-f1_score0.956.h5"
            model.load_weights(filepath)
            print(tf.trainable_variables())
            print(sess.run(sess.graph.get_tensor_by_name("crf_1/chain_kernel:0"),
                           feed_dict={model.input: np.array([[1, 2, 3, 4, 5, 6, 0, 0, 0],
                                                             [2, 3, 4, 5, 0, 0, 0, 0, 0]])}))
            # tf.saved_model.simple_save(sess,
            #                            "./codename_savedmodel/",
            #                            inputs={"inputs": model.input},
            #                            outputs={"outputs": model.output})
            # tf.summary.FileWriter(logdir="log1", graph=sess.graph)


if __name__ == "__main__":
    training()
    # predicting(getPredictArticles())
    # get_savedmodel()