# punishNo_tf.py

import re  # used below for pattern matching; imported explicitly in case the wildcard Utils import does not provide it

import tensorflow as tf
from tensorflow.contrib.crf import crf_log_likelihood, viterbi_decode  # viterbi_decode is called in decode()
from tensorflow.contrib.layers.python.layers import initializers
import numpy as np
import pandas as pd
from zipfile import ZipFile
import os
import pickle
from BiddingKG.dl.common.Utils import *
from keras.preprocessing.sequence import pad_sequences


# class BiLSTM_CRF_tf(object):
#     def __init__(self):

def BiLSTM_CRF_tfmodel(sess, weights):
    """Build the BiLSTM+CRF graph on sess.graph and return its key tensors and ops."""
    BiRNN_Units = 140
    chunk_tags = {
        'O': 0,
        'PN_B': 1,
        'PN_M': 2,
        'PN_E': 3
    }

    def embedding_layer(input):
        # character embedding, initialized from the pre-trained word2vec matrix
        embedding = tf.get_variable("embedding",
                                    initializer=np.array(weights, dtype=np.float32) if weights is not None else None,
                                    dtype=tf.float32)
        return tf.nn.embedding_lookup(params=embedding, ids=input)

    def BiLSTM_Layer(input, length):
        with tf.variable_scope("BiLSTM"):
            forward_cell = tf.contrib.rnn.BasicLSTMCell(BiRNN_Units // 2, state_is_tuple=True)
            backward_cell = tf.contrib.rnn.BasicLSTMCell(BiRNN_Units // 2, state_is_tuple=True)
            output, _ = tf.nn.bidirectional_dynamic_rnn(forward_cell, backward_cell, input,
                                                        dtype=tf.float32, sequence_length=length)
            output = tf.concat(output, 2)
        return output

    def CRF_layer(input, num_tags, BiRNN_Units, time_step):
        with tf.variable_scope("CRF"):
            with tf.variable_scope("hidden"):
                w_hidden = tf.get_variable(name='w_hidden', shape=(BiRNN_Units, BiRNN_Units // 2), dtype=tf.float32,
                                           initializer=initializers.xavier_initializer(),
                                           regularizer=tf.contrib.layers.l2_regularizer(0.001))
                b_hidden = tf.get_variable(name='b_hidden', shape=(BiRNN_Units // 2), dtype=tf.float32,
                                           initializer=tf.zeros_initializer())
                # print(input)
                input_reshape = tf.reshape(input, shape=(-1, BiRNN_Units))
                hidden = tf.tanh(tf.nn.xw_plus_b(input_reshape, w_hidden, b_hidden))
            with tf.variable_scope("output"):
                w_output = tf.get_variable(name='w_output', shape=(BiRNN_Units // 2, num_tags), dtype=tf.float32,
                                           initializer=initializers.xavier_initializer(),
                                           regularizer=tf.contrib.layers.l2_regularizer(0.001))
                b_output = tf.get_variable(name='b_output', shape=(num_tags), dtype=tf.float32,
                                           initializer=tf.zeros_initializer())
                pred = tf.nn.xw_plus_b(hidden, w_output, b_output)
                logits_ = tf.reshape(pred, shape=(-1, time_step, num_tags), name='logits')
        return logits_

    def layer_loss(input, true_target, num_tags, length):
        with tf.variable_scope("crf_loss"):
            # NOTE: the variable name 'transitons' (sic) is kept so existing checkpoints still restore
            trans = tf.get_variable(name='transitons', shape=(num_tags, num_tags), dtype=tf.float32,
                                    initializer=initializers.xavier_initializer())
            log_likelihood, trans = crf_log_likelihood(inputs=input, tag_indices=true_target,
                                                       transition_params=trans, sequence_lengths=length)
            return tf.reduce_mean(-log_likelihood), trans

    with sess.graph.as_default():
        char_input = tf.placeholder(name='char_input', shape=(None, None), dtype=tf.int32)
        target = tf.placeholder(name='target', shape=(None, None), dtype=tf.int32)
        length = tf.placeholder(name='length', shape=(None,), dtype=tf.int32)
        # keepprob = tf.placeholder(name='keepprob',dtype=tf.float32)
        _embedding = embedding_layer(char_input)
        _shape = tf.shape(char_input)
        batch_size = _shape[0]
        step_size = _shape[-1]
        bilstm = BiLSTM_Layer(_embedding, length)
        _logits = CRF_layer(bilstm, num_tags=len(chunk_tags), BiRNN_Units=BiRNN_Units, time_step=step_size)
        crf_loss, trans = layer_loss(_logits, true_target=target, num_tags=len(chunk_tags), length=length)
        global_step = tf.Variable(0, trainable=False)
        with tf.variable_scope("optimizer"):
            opt = tf.train.AdamOptimizer(0.002)
            grads_vars = opt.compute_gradients(crf_loss)
            capped_grads_vars = [[tf.clip_by_value(g, -5, 5), v] for g, v in grads_vars]
            train_op = opt.apply_gradients(capped_grads_vars, global_step)
        return char_input, _logits, target, length, crf_loss, trans, train_op
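
# The sketch below is illustrative only (not part of the original pipeline): it shows how a
# single labelled span maps onto the PN_B / PN_M / PN_E / O scheme defined in chunk_tags above,
# mirroring the labelling loop used later in train(). The function name and example values are
# hypothetical.
def _demo_span_to_tags(text_len, begin, end):
    """Hypothetical example: tag one [begin, end) span as PN_B ... PN_M ... PN_E."""
    tags = ['O'] * text_len
    tags[begin] = 'PN_B'
    for i in range(begin + 1, end - 1):
        tags[i] = 'PN_M'
    tags[end - 1] = 'PN_E'
    return tags
# e.g. _demo_span_to_tags(8, 2, 6) -> ['O', 'O', 'PN_B', 'PN_M', 'PN_M', 'PN_E', 'O', 'O']
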
def train():
    vocab_model = getModel_word()
    vocab, w2v_matrix = getVocabAndMatrix(vocab_model, Embedding_size=60)
    # print(w2v_matrix)
    punishNo = {
        'O': 0,
        'PN_B': 1,
        'PN_M': 2,
        'PN_E': 3
    }
    punishNo_2 = {
        'O': np.array([1, 0, 0, 0]),
        'PN_B': np.array([0, 1, 0, 0]),
        'PN_M': np.array([0, 0, 1, 0]),
        'PN_E': np.array([0, 0, 0, 1])
    }
    data = pd.read_csv("C:\\Users\\admin\\Desktop\\投诉处罚信息\\db_alldata.csv", index_col=0)
    train_data = pd.read_csv("C:\\Users\\admin\\Desktop\\投诉处罚信息\\punishment_code_new.csv", index_col=0)
    train_data['text'] = [data['text'][data['document_id'] == id] for id in train_data['document_id']]
    data_x = []
    data_y = []
    articles_label = ['' for _ in range(13500)]
    punishNo_in_text = set()
    for textId, begin, end, entity_text, text in zip(train_data['document_id'], train_data['begin_index'],
                                                     train_data['end_index'],
                                                     train_data['entity_text'], train_data['text']):
        punishNo_in_text.add(textId)
        text = list(text)[0]
        l = len(text)
        if not articles_label[textId]:
            articles_label[textId] = ['O' for _ in range(l)]
        articles_label[textId][begin] = 'PN_B'
        articles_label[textId][end - 1] = 'PN_E'
        for i in range(begin + 1, end - 1):
            articles_label[textId][i] = 'PN_M'
    punishNo_in_text = list(punishNo_in_text)
    # negative samples: keep only sentences that contain digit sequences
    data = data.dropna(subset=['text'])
    re_rule1 = re.compile('\[|\]')
    data['sentences'] = [re_rule1.sub('', sentences).split(',') for sentences in data['sentences']]
    data['sentences'] = [[int(s) for s in sentences] for sentences in data['sentences']]
    re_rule2 = re.compile("[\d,.]{4,}")
    for id, article, sentences in zip(data['document_id'], data['text'], data['sentences']):
        if id < 2826 or id in punishNo_in_text:
            # print(id)
            article = str(article)
            l = len(article)
            text_word = list(article)
            text_word_index = [getIndexOfWord(word) for word in text_word]
            sentence_count = len(sentences)
            if articles_label[id]:
                label_list = articles_label[id]
            else:
                label_list = ['O' for _ in range(l)]
            for i in range(sentence_count - 1):
                if re_rule2.search(article[sentences[i]:sentences[i + 1]]):
                    data_x.append(np.array(text_word_index[sentences[i]:sentences[i + 1]]))
                    data_y.append(np.array(label_list[sentences[i]:sentences[i + 1]]))
    data_x = np.array(data_x)
    x_len = [250 if len(x) > 250 else len(x) for x in data_x]
    data_x = pad_sequences(data_x, maxlen=250, padding="post", truncating="post")
    # train_x = train_x.reshape(-1)
    data_y = [np.array([punishNo[_y] for _y in y]) for y in data_y]
    # data_y = np.array(data_y).reshape(-1)
    data_y = np.array(data_y)
    data_y = pad_sequences(data_y, maxlen=250, padding="post", truncating="post")
    # print(data_x[:5])
    # print(data_y[:5])
    # data_x = np.array(list(data_x))
    # data_y = np.array(list(data_y))
    indices = np.random.permutation(data_x.shape[0])
    count = len(data_x)
    test_count = int(0.2 * count)
    test_idx, train_idx = indices[:test_count], indices[test_count:]
    # print(test_idx)
    train_x, test_x = data_x[train_idx, :], data_x[test_idx, :]
    train_y, test_y = data_y[train_idx, :], data_y[test_idx, :]
    train_x_len = np.array([x_len[idx] for idx in train_idx])
    test_x_len = np.array([x_len[idx] for idx in test_idx])
    with tf.Session(graph=tf.Graph()) as sess:
        char_input, logits, target, length, crf_loss, trans, train_op = BiLSTM_CRF_tfmodel(sess, w2v_matrix)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        epochs = 60
        batch_size = 300
        _test_loss = 10000.
        for epoch in range(epochs):
            for x_batch, y_batch, x_len_batch in batch_iter(train_x, train_y, train_x_len, batch_size=batch_size):
                # for x,y,l in zip(x_batch,y_batch,x_len_batch):
                #     print(l,'=>',x)
                #     print(y)
                train_loss, _ = sess.run([crf_loss, train_op],
                                         feed_dict={char_input: x_batch, target: y_batch, length: x_len_batch})
            test_loss, _logits, _trans = sess.run([crf_loss, logits, trans],
                                                  feed_dict={char_input: test_x, target: test_y, length: test_x_len})
            acc = getAcc(test_y, _logits, _trans, test_x_len)
            print("==>epoch:" + str(epoch))
            print("--test --", " acc:", acc, 'test_loss:', test_loss)
            print("--train--", "loss:", train_loss, "have_done")
            if test_loss < _test_loss:
                _test_loss = test_loss
                print("Saving-" + str(epoch) + "-model,test_loss:" + str(test_loss))
                saver.save(sess, "models/" + str(epoch) + "-" + str(acc) + "-" + str(test_loss) + "/model.ckpt")

def batch_iter(x, y, x_len, batch_size=256):
    '''
    :param x: content2id
    :param y: label2id
    :param batch_size: number of sentences fed to the model per batch
    :return:
    '''
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1  # number of batches needed for one epoch
    # indices = np.random.permutation(data_len)  # generate a random permutation
    # x_shuffle = x[indices]
    # y_shuffle = y[indices]
    # x_len_shuffle = x_len[indices]
    for i in range(num_batch):
        start_id = batch_size * i
        end_id = min(batch_size * (i + 1), data_len)
        yield x[start_id:end_id], y[start_id:end_id], x_len[start_id:end_id]
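
# Usage sketch for batch_iter (an assumption about typical use; the argument names are
# hypothetical and mirror the padded arrays built in train()):
def _demo_batch_iter_usage(train_x, train_y, train_x_len):
    """Hypothetical example: walk the training set in mini-batches and report shapes."""
    for x_batch, y_batch, len_batch in batch_iter(train_x, train_y, train_x_len, batch_size=4):
        print(x_batch.shape, y_batch.shape, len(len_batch))
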
from sklearn.metrics import accuracy_score

def getAcc(y_batch, logits, trans, lengths):
    """Token-level accuracy of Viterbi-decoded predictions against the gold tags."""
    index = 0
    small = -1000.0
    start = np.asarray([[small] * 4 + [0]])
    preds = []
    true_tags = []
    for score, length in zip(logits, lengths):
        score = score[:length]
        # pad = small * np.ones([length, 1])
        # logit = np.concatenate([score, pad], axis=1)
        # logit = np.concatenate([start, logit], axis=0)
        # path, _ = tf.contrib.crf.viterbi_decode(logit, trans)
        path, _ = tf.contrib.crf.viterbi_decode(score, trans)
        preds += path[0:]
        # preds += path[1:]
        index += 1
    for y, length in zip(y_batch, lengths):
        y = y.tolist()
        true_tags += y[:length]
    acc = accuracy_score(np.reshape(true_tags, (-1)), np.reshape(preds, (-1)))
    return acc

def predict(articles, model_file):
    """Restore the model from model_file and extract penalty document numbers.

    articles: a list of articles, each article being a list of sentences.
    Returns one list of (entity_text, start, end) tuples per article.
    """
    vocab_model = getModel_word()
    vocab, w2v_matrix = getVocabAndMatrix(vocab_model, Embedding_size=60)
    sess = tf.Session(graph=tf.Graph())
    with sess:
        char_input, logits, target, length, crf_loss, trans, train_op = BiLSTM_CRF_tfmodel(sess, w2v_matrix)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, model_file)
        # tag ids: PN_B=1, PN_M=2, PN_E=3, so "12+?3" matches one B, at least one M, then E
        re_ner = re.compile("12+?3")
        article_ner_list = []
        count = 0
        for sentences in articles:
            count += 1
            print(count)
            sentence_len = [len(sentence) for sentence in sentences]
            maxlen = max(sentence_len)
            sentences_x = []
            for sentence in sentences:
                sentence = list(sentence)
                sentence2id = [getIndexOfWord(word) for word in sentence]
                sentences_x.append(sentence2id)
            sentences_x = pad_sequences(sentences_x, maxlen=maxlen, padding="post", truncating="post")
            sentences_x = [np.array(x) for x in sentences_x]
            _logits, _trans = sess.run([logits, trans],
                                       feed_dict={char_input: np.array(sentences_x), length: sentence_len})
            viterbi_sequence = decode(logits=_logits, trans=_trans, sequence_lengths=sentence_len, tag_num=4)
            ner_list = []
            for _seq, sentence in zip(viterbi_sequence, sentences):
                seq_id = ''.join([str(s) for s in _seq])
                if re_ner.search(seq_id):
                    # print("sentence: ",sentence)
                    for _ner in re_ner.finditer(seq_id):
                        start = _ner.start()
                        end = _ner.end()
                        n = sentence[start:end]
                        # print(n,'<==>',start,end)
                        ner_list.append((n, start, end))
            article_ner_list.append(ner_list)
    return article_ner_list
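
# Illustrative sketch (an assumption about the intent of re_ner in predict() above): the Viterbi
# tag ids of a sentence are joined into a string such as "0012223000", so the pattern "12+?3"
# locates PN_B (1) followed by one or more PN_M (2) and a closing PN_E (3). Note that a
# two-character entity, whose id string would be "13", is not matched by this pattern.
def _demo_tag_regex(seq_id='0012223000'):
    """Hypothetical example: find the character span of a tagged entity in a decoded id string."""
    m = re.search("12+?3", seq_id)
    return (m.start(), m.end()) if m else None
# _demo_tag_regex() -> (2, 7)
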
def decode(logits, trans, sequence_lengths, tag_num):
    viterbi_sequences = []
    for logit, length in zip(logits, sequence_lengths):
        score = logit[:length]
        viterbi_seq, viterbi_score = viterbi_decode(score, trans)
        viterbi_sequences.append(viterbi_seq)
    return viterbi_sequences
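
# Minimal sketch of decode() on synthetic inputs (the numbers are made up, not model output):
# logits is [batch, time, num_tags] and trans is [num_tags, num_tags]; decode() returns one
# Viterbi tag-id sequence per sentence, truncated to that sentence's true length.
def _demo_decode():
    """Hypothetical example: Viterbi-decode random scores for two short sentences."""
    fake_logits = np.random.rand(2, 5, 4).astype(np.float32)  # 2 sentences, 5 steps, 4 tags
    fake_trans = np.random.rand(4, 4).astype(np.float32)      # tag-to-tag transition scores
    return decode(fake_logits, fake_trans, sequence_lengths=[5, 3], tag_num=4)
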
def test2():
    punishNo = {
        'O': 0,
        'PN_B': 1,
        'PN_M': 2,
        'PN_E': 3
    }
    data = pd.read_csv("C:\\Users\\admin\\Desktop\\投诉处罚信息\\db_alldata.csv", index_col=0)
    train_data = pd.read_csv("C:\\Users\\admin\\Desktop\\投诉处罚信息\\punishment_code_new.csv", index_col=0)
    punishNo_in_text = set()
    for textId in train_data['document_id']:
        punishNo_in_text.add(textId)
    for _ in range(1, 2821):
        punishNo_in_text.add(_)
    punishNo_in_text = list(punishNo_in_text)
    data = data[data['document_id'].isin(punishNo_in_text)]
    data = data.dropna(subset=['text'])
    re_rule1 = re.compile('\[|\]')
    data['sentences'] = [re_rule1.sub('', sentences).split(',') for sentences in data['sentences']]
    data['sentences'] = [[int(s) for s in sentences] for sentences in data['sentences']]
    article_sentences = []
    for id, text, sentences in zip(data['document_id'], data['text'], data['sentences']):
        # if id in punishNo_in_text:
        sentences_count = len(sentences)
        sentence_list = []
        for i in range(sentences_count - 1):
            sentence = text[sentences[i]:sentences[i + 1]]
            sentence_list.append(sentence)
        article_sentences.append(sentence_list)
    model_file = "models/21-0.9990081295021194-0.3647936/model.ckpt"
    punishNo_ner = predict(article_sentences, model_file)
    data['punishNo_test'] = punishNo_ner
    punishNo_label = [[] for _ in range(13500)]
    for textId, begin, end, entity_text in zip(train_data['document_id'], train_data['begin_index'],
                                               train_data['end_index'], train_data['entity_text']):
        punishNo_label[textId].append((entity_text, begin, end))
    punishNo_right = []
    for id in data['document_id']:
        punishNo_right.append(punishNo_label[id])
    data['punishNo_right'] = punishNo_right
    test_res = []
    for test, label_list in zip(data['punishNo_test'], data['punishNo_right']):
        if set(test) == set(label_list):
            test_res.append(1)
        else:
            test_res.append(0)
    data['test_res'] = test_res
    data.to_excel("C:\\Users\\admin\\Desktop\\投诉处罚信息\\punishNo_test.xlsx", encoding='utf-8')

def test():
    data = pd.read_csv("data/ALLDATA.csv", index_col=0)[500:600]
    model_file = "models/21-0.9990081295021194-0.3647936/model.ckpt"
    # data = data[35000:45000]
    sentences_list = []
    for sentences in data['sentences']:
        sentences = sentences.split("*#*>")
        sentences_list.append(sentences)
    print(len(sentences_list))
    pn_ner = predict(sentences_list, model_file)
    print('*' * 20)
    print(len(pn_ner), pn_ner)
    data['ner_test'] = pn_ner
    print(data.head(3))
    # data.to_excel("C:\\Users\\admin\\Desktop\\投诉处罚信息\\已分类\\ALLDATA_re2-3.xlsx",encoding='utf-8')

if __name__ == '__main__':
    # train()
    # test()
    model_file = "models/21-0.9990081295021194-0.3647936/model.ckpt"
    sentences_list = '行政处罚厦建招诉决【2019】34号。行政处罚厦建招诉决【2019】34号。行政处罚厦建招诉决【2019】34号。行政处罚厦建招诉决【2019】34号,'.split('。')
    pn_ner = predict([sentences_list], model_file)
    print(pn_ner)
    # test2()
    # data = pd.read_csv("C:\\Users\\admin\\Desktop\\投诉处罚信息\\ALLDATA.csv",index_col=0)
    # sentences = data['sentences'][51313]
    # sentences = sentences.split("*#*>")
    # model_file = "models/21-0.9990081295021194-0.3647936/model.ckpt"
    # predict(sentences,model_file)
    pass