punishNo_tf.py

import tensorflow as tf
from tensorflow.contrib.crf import crf_log_likelihood, viterbi_decode
from tensorflow.contrib.layers.python.layers import initializers
import numpy as np
import pandas as pd
import re
from zipfile import ZipFile
import os
import pickle
from BiddingKG.dl.common.Utils import *
from keras.preprocessing.sequence import pad_sequences

# class BiLSTM_CRF_tf(object):
#     def __init__(self):
def BiLSTM_CRF_tfmodel(sess, weights):
    BiRNN_Units = 140
    chunk_tags = {
        'O': 0,
        'PN_B': 1,
        'PN_M': 2,
        'PN_E': 3
    }

    def embedding_layer(input):
        embedding = tf.get_variable("embedding",
                                    initializer=np.array(weights, dtype=np.float32) if weights is not None else None,
                                    dtype=tf.float32)
        return tf.nn.embedding_lookup(params=embedding, ids=input)

    def BiLSTM_Layer(input, length):
        with tf.variable_scope("BiLSTM"):
            forward_cell = tf.contrib.rnn.BasicLSTMCell(BiRNN_Units // 2, state_is_tuple=True)
            backward_cell = tf.contrib.rnn.BasicLSTMCell(BiRNN_Units // 2, state_is_tuple=True)
            output, _ = tf.nn.bidirectional_dynamic_rnn(forward_cell, backward_cell, input,
                                                        dtype=tf.float32, sequence_length=length)
            output = tf.concat(output, 2)
        return output

    def CRF_layer(input, num_tags, BiRNN_Units, time_step):
        with tf.variable_scope("CRF"):
            with tf.variable_scope("hidden"):
                w_hidden = tf.get_variable(name='w_hidden', shape=(BiRNN_Units, BiRNN_Units // 2), dtype=tf.float32,
                                           initializer=initializers.xavier_initializer(),
                                           regularizer=tf.contrib.layers.l2_regularizer(0.001))
                b_hidden = tf.get_variable(name='b_hidden', shape=(BiRNN_Units // 2), dtype=tf.float32,
                                           initializer=tf.zeros_initializer())
                # print(input)
                input_reshape = tf.reshape(input, shape=(-1, BiRNN_Units))
                hidden = tf.tanh(tf.nn.xw_plus_b(input_reshape, w_hidden, b_hidden))
            with tf.variable_scope("output"):
                w_output = tf.get_variable(name='w_output', shape=(BiRNN_Units // 2, num_tags), dtype=tf.float32,
                                           initializer=initializers.xavier_initializer(),
                                           regularizer=tf.contrib.layers.l2_regularizer(0.001))
                b_output = tf.get_variable(name='b_output', shape=(num_tags), dtype=tf.float32,
                                           initializer=tf.zeros_initializer())
                pred = tf.nn.xw_plus_b(hidden, w_output, b_output)
                logits_ = tf.reshape(pred, shape=(-1, time_step, num_tags), name='logits')
        return logits_

    def layer_loss(input, true_target, num_tags, length):
        with tf.variable_scope("crf_loss"):
            trans = tf.get_variable(name='transitons', shape=(num_tags, num_tags), dtype=tf.float32,
                                    initializer=initializers.xavier_initializer())
            log_likelihood, trans = crf_log_likelihood(inputs=input, tag_indices=true_target,
                                                       transition_params=trans, sequence_lengths=length)
        return tf.reduce_mean(-log_likelihood), trans

    with sess.graph.as_default():
        char_input = tf.placeholder(name='char_input', shape=(None, None), dtype=tf.int32)
        target = tf.placeholder(name='target', shape=(None, None), dtype=tf.int32)
        length = tf.placeholder(name='length', shape=(None,), dtype=tf.int32)
        # keepprob = tf.placeholder(name='keepprob', dtype=tf.float32)
        _embedding = embedding_layer(char_input)
        _shape = tf.shape(char_input)
        batch_size = _shape[0]
        step_size = _shape[-1]
        bilstm = BiLSTM_Layer(_embedding, length)
        _logits = CRF_layer(bilstm, num_tags=len(chunk_tags), BiRNN_Units=BiRNN_Units, time_step=step_size)
        crf_loss, trans = layer_loss(_logits, true_target=target, num_tags=len(chunk_tags), length=length)
        global_step = tf.Variable(0, trainable=False)
        with tf.variable_scope("optimizer"):
            opt = tf.train.AdamOptimizer(0.002)
            grads_vars = opt.compute_gradients(crf_loss)
            capped_grads_vars = [[tf.clip_by_value(g, -5, 5), v] for g, v in grads_vars]
            train_op = opt.apply_gradients(capped_grads_vars, global_step)
        return char_input, _logits, target, length, crf_loss, trans, train_op

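
# --------------------------------------------------------------------------
# Hedged usage sketch (not part of the original pipeline): a minimal example
# of driving the tensors returned by BiLSTM_CRF_tfmodel for a single training
# step on toy data. `demo_weights` is assumed to be a (vocab_size, 60)
# embedding matrix such as the one returned by getVocabAndMatrix(); all
# demo_* names are hypothetical and exist only in this sketch.
def demo_single_train_step(demo_weights):
    with tf.Session(graph=tf.Graph()) as demo_sess:
        char_input, logits, target, length, crf_loss, trans, train_op = \
            BiLSTM_CRF_tfmodel(demo_sess, demo_weights)
        demo_sess.run(tf.global_variables_initializer())
        # one toy batch: 2 sequences padded to 6 tokens, all labelled 'O' (0)
        x = np.zeros((2, 6), dtype=np.int32)
        y = np.zeros((2, 6), dtype=np.int32)
        lens = np.array([6, 4], dtype=np.int32)
        loss, _ = demo_sess.run([crf_loss, train_op],
                                feed_dict={char_input: x, target: y, length: lens})
        return loss
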
def train():
    vocab_model = getModel_word()
    vocab, w2v_matrix = getVocabAndMatrix(vocab_model, Embedding_size=60)
    # print(w2v_matrix)
    punishNo = {
        'O': 0,
        'PN_B': 1,
        'PN_M': 2,
        'PN_E': 3
    }
    punishNo_2 = {
        'O': np.array([1, 0, 0, 0]),
        'PN_B': np.array([0, 1, 0, 0]),
        'PN_M': np.array([0, 0, 1, 0]),
        'PN_E': np.array([0, 0, 0, 1])
    }
    data = pd.read_csv("C:\\Users\\admin\\Desktop\\投诉处罚信息\\db_alldata.csv", index_col=0)
    train_data = pd.read_csv("C:\\Users\\admin\\Desktop\\投诉处罚信息\\punishment_code_new.csv", index_col=0)
    train_data['text'] = [data['text'][data['document_id'] == id] for id in train_data['document_id']]
    data_x = []
    data_y = []
    articles_label = ['' for _ in range(13500)]
    punishNo_in_text = set()
    for textId, begin, end, entity_text, text in zip(train_data['document_id'], train_data['begin_index'],
                                                     train_data['end_index'],
                                                     train_data['entity_text'], train_data['text']):
        punishNo_in_text.add(textId)
        text = list(text)[0]
        l = len(text)
        if not articles_label[textId]:
            articles_label[textId] = ['O' for _ in range(l)]
        articles_label[textId][begin] = 'PN_B'
        articles_label[textId][end - 1] = 'PN_E'
        for i in range(begin + 1, end - 1):
            articles_label[textId][i] = 'PN_M'
    punishNo_in_text = list(punishNo_in_text)
    # negative samples: keep only sentences that contain digit sequences
    data = data.dropna(subset=['text'])
    re_rule1 = re.compile(r'\[|\]')
    data['sentences'] = [re_rule1.sub('', sentences).split(',') for sentences in data['sentences']]
    data['sentences'] = [[int(s) for s in sentences] for sentences in data['sentences']]
    re_rule2 = re.compile(r"[\d,.]{4,}")
    for id, article, sentences in zip(data['document_id'], data['text'], data['sentences']):
        if id < 2826 or id in punishNo_in_text:
            # print(id)
            article = str(article)
            l = len(article)
            text_word = list(article)
            text_word_index = [getIndexOfWord(word) for word in text_word]
            sentence_count = len(sentences)
            if articles_label[id]:
                label_list = articles_label[id]
            else:
                label_list = ['O' for _ in range(l)]
            for i in range(sentence_count - 1):
                if re_rule2.search(article[sentences[i]:sentences[i + 1]]):
                    data_x.append(np.array(text_word_index[sentences[i]:sentences[i + 1]]))
                    data_y.append(np.array(label_list[sentences[i]:sentences[i + 1]]))
    data_x = np.array(data_x)
    x_len = [250 if len(x) > 250 else len(x) for x in data_x]
    data_x = pad_sequences(data_x, maxlen=250, padding="post", truncating="post")
    # train_x = train_x.reshape(-1)
    data_y = [np.array([punishNo[_y] for _y in y]) for y in data_y]
    # data_y = np.array(data_y).reshape(-1)
    data_y = np.array(data_y)
    data_y = pad_sequences(data_y, maxlen=250, padding="post", truncating="post")
    # print(data_x[:5])
    # print(data_y[:5])
    # data_x = np.array(list(data_x))
    # data_y = np.array(list(data_y))
    indices = np.random.permutation(data_x.shape[0])
    count = len(data_x)
    test_count = int(0.2 * count)
    test_idx, train_idx = indices[:test_count], indices[test_count:]
    # print(test_idx)
    train_x, test_x = data_x[train_idx, :], data_x[test_idx, :]
    train_y, test_y = data_y[train_idx, :], data_y[test_idx, :]
    train_x_len = np.array([x_len[idx] for idx in train_idx])
    test_x_len = np.array([x_len[idx] for idx in test_idx])
    with tf.Session(graph=tf.Graph()) as sess:
        char_input, logits, target, length, crf_loss, trans, train_op = BiLSTM_CRF_tfmodel(sess, w2v_matrix)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        epochs = 60
        batch_size = 300
        _test_loss = 10000.
        for epoch in range(epochs):
            for x_batch, y_batch, x_len_batch in batch_iter(train_x, train_y, train_x_len, batch_size=batch_size):
                # for x, y, l in zip(x_batch, y_batch, x_len_batch):
                #     print(l, '=>', x)
                #     print(y)
                train_loss, _ = sess.run([crf_loss, train_op],
                                         feed_dict={char_input: x_batch, target: y_batch, length: x_len_batch})
            test_loss, _logits, _trans = sess.run([crf_loss, logits, trans],
                                                  feed_dict={char_input: test_x, target: test_y, length: test_x_len})
            acc = getAcc(test_y, _logits, _trans, test_x_len)
            print("==>epoch:" + str(epoch))
            print("--test --", " acc:", acc, 'test_loss:', test_loss)
            print("--train--", "loss:", train_loss, "have_done")
            if test_loss < _test_loss:
                _test_loss = test_loss
                print("Saving-" + str(epoch) + "-model,test_loss:" + str(test_loss))
                saver.save(sess, "models/" + str(epoch) + "-" + str(acc) + "-" + str(test_loss) + "/model.ckpt")

def batch_iter(x, y, x_len, batch_size=256):
    '''
    :param x: content2id
    :param y: label2id
    :param batch_size: number of sentences fed to the model per batch
    :return:
    '''
    data_len = len(x)
    num_batch = int((data_len - 1) / batch_size) + 1  # number of batches needed for one epoch
    # indices = np.random.permutation(data_len)  # generate a random permutation
    # x_shuffle = x[indices]
    # y_shuffle = y[indices]
    # x_len_shuffle = x_len[indices]
    for i in range(num_batch):
        start_id = batch_size * i
        end_id = min(batch_size * (i + 1), data_len)
        yield x[start_id:end_id], y[start_id:end_id], x_len[start_id:end_id]

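
# Hedged example (hypothetical helper, not called anywhere in this module):
# batch_iter over a toy dataset of 5 padded rows with batch_size=2 yields
# batches of sizes 2, 2 and 1.
def demo_batch_iter_sizes():
    x = np.zeros((5, 3), dtype=np.int32)
    y = np.zeros((5, 3), dtype=np.int32)
    lens = np.full(5, 3, dtype=np.int32)
    return [bx.shape[0] for bx, by, bl in batch_iter(x, y, lens, batch_size=2)]  # -> [2, 2, 1]
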
from sklearn.metrics import accuracy_score

def getAcc(y_batch, logits, trans, lengths):
    index = 0
    small = -1000.0
    start = np.asarray([[small] * 4 + [0]])
    preds = []
    true_tags = []
    for score, length in zip(logits, lengths):
        score = score[:length]
        # pad = small * np.ones([length, 1])
        # logit = np.concatenate([score, pad], axis=1)
        # logit = np.concatenate([start, logit], axis=0)
        # path, _ = tf.contrib.crf.viterbi_decode(logit, trans)
        path, _ = viterbi_decode(score, trans)
        preds += path[0:]
        # preds += path[1:]
        index += 1
    for y, length in zip(y_batch, lengths):
        y = y.tolist()
        true_tags += y[:length]
    acc = accuracy_score(np.reshape(true_tags, (-1)), np.reshape(preds, (-1)))
    return acc

def predict(articles, model_file):
    vocab_model = getModel_word()
    vocab, w2v_matrix = getVocabAndMatrix(vocab_model, Embedding_size=60)
    sess = tf.Session(graph=tf.Graph())
    with sess:
        char_input, logits, target, length, crf_loss, trans, train_op = BiLSTM_CRF_tfmodel(sess, w2v_matrix)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, model_file)
        # a penalty-number entity is PN_B (1) followed by one or more PN_M (2) and a PN_E (3)
        re_ner = re.compile("12+?3")
        article_ner_list = []
        count = 0
        for sentences in articles:
            count += 1
            print(count)
            sentence_len = [len(sentence) for sentence in sentences]
            maxlen = max(sentence_len)
            sentences_x = []
            for sentence in sentences:
                sentence = list(sentence)
                sentence2id = [getIndexOfWord(word) for word in sentence]
                sentences_x.append(sentence2id)
            sentences_x = pad_sequences(sentences_x, maxlen=maxlen, padding="post", truncating="post")
            sentences_x = [np.array(x) for x in sentences_x]
            _logits, _trans = sess.run([logits, trans],
                                       feed_dict={char_input: np.array(sentences_x), length: sentence_len})
            viterbi_sequence = decode(logits=_logits, trans=_trans, sequence_lengths=sentence_len, tag_num=4)
            ner_list = []
            for _seq, sentence in zip(viterbi_sequence, sentences):
                seq_id = ''.join([str(s) for s in _seq])
                if re_ner.search(seq_id):
                    # print("sentence: ", sentence)
                    for _ner in re_ner.finditer(seq_id):
                        start = _ner.start()
                        end = _ner.end()
                        n = sentence[start:end]
                        # print(n, '<==>', start, end)
                        ner_list.append((n, start, end))
            article_ner_list.append(ner_list)
        return article_ner_list

def decode(logits, trans, sequence_lengths, tag_num):
    viterbi_sequences = []
    for logit, length in zip(logits, sequence_lengths):
        score = logit[:length]
        viterbi_seq, viterbi_score = viterbi_decode(score, trans)
        viterbi_sequences.append(viterbi_seq)
    return viterbi_sequences

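
# Hedged sketch (hypothetical helper, not used by the pipeline): mapping one
# decoded tag path back to character spans, using the same PN_B/PN_M/PN_E ->
# '1'/'2'/'3' convention that predict() matches with re.compile("12+?3").
def demo_tags_to_spans(sentence, tag_path):
    seq_id = ''.join(str(t) for t in tag_path)
    return [(sentence[m.start():m.end()], m.start(), m.end())
            for m in re.finditer("12+?3", seq_id)]
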
def test2():
    punishNo = {
        'O': 0,
        'PN_B': 1,
        'PN_M': 2,
        'PN_E': 3
    }
    data = pd.read_csv("C:\\Users\\admin\\Desktop\\投诉处罚信息\\db_alldata.csv", index_col=0)
    train_data = pd.read_csv("C:\\Users\\admin\\Desktop\\投诉处罚信息\\punishment_code_new.csv", index_col=0)
    punishNo_in_text = set()
    for textId in train_data['document_id']:
        punishNo_in_text.add(textId)
    for _ in range(1, 2821):
        punishNo_in_text.add(_)
    punishNo_in_text = list(punishNo_in_text)
    data = data[data['document_id'].isin(punishNo_in_text)]
    data = data.dropna(subset=['text'])
    re_rule1 = re.compile(r'\[|\]')
    data['sentences'] = [re_rule1.sub('', sentences).split(',') for sentences in data['sentences']]
    data['sentences'] = [[int(s) for s in sentences] for sentences in data['sentences']]
    article_sentences = []
    for id, text, sentences in zip(data['document_id'], data['text'], data['sentences']):
        # if id in punishNo_in_text:
        sentences_count = len(sentences)
        sentence_list = []
        for i in range(sentences_count - 1):
            sentence = text[sentences[i]:sentences[i + 1]]
            sentence_list.append(sentence)
        article_sentences.append(sentence_list)
    model_file = "models/21-0.9990081295021194-0.3647936/model.ckpt"
    punishNo_ner = predict(article_sentences, model_file)
    data['punishNo_test'] = punishNo_ner
    punishNo_label = [[] for _ in range(13500)]
    for textId, begin, end, entity_text in zip(train_data['document_id'], train_data['begin_index'],
                                               train_data['end_index'], train_data['entity_text']):
        punishNo_label[textId].append((entity_text, begin, end))
    punishNo_right = []
    for id in data['document_id']:
        punishNo_right.append(punishNo_label[id])
    data['punishNo_right'] = punishNo_right
    test_res = []
    for test, label_list in zip(data['punishNo_test'], data['punishNo_right']):
        if set(test) == set(label_list):
            test_res.append(1)
        else:
            test_res.append(0)
    data['test_res'] = test_res
    data.to_excel("C:\\Users\\admin\\Desktop\\投诉处罚信息\\punishNo_test.xlsx", encoding='utf-8')

def test():
    data = pd.read_csv("data/ALLDATA.csv", index_col=0)[500:600]
    model_file = "models/21-0.9990081295021194-0.3647936/model.ckpt"
    # data = data[35000:45000]
    sentences_list = []
    for sentences in data['sentences']:
        sentences = sentences.split("*#*>")
        sentences_list.append(sentences)
    print(len(sentences_list))
    pn_ner = predict(sentences_list, model_file)
    print('*' * 20)
    print(len(pn_ner), pn_ner)
    data['ner_test'] = pn_ner
    print(data.head(3))
    # data.to_excel("C:\\Users\\admin\\Desktop\\投诉处罚信息\\已分类\\ALLDATA_re2-3.xlsx", encoding='utf-8')

if __name__ == '__main__':
    # train()
    # test()
    model_file = "models/21-0.9990081295021194-0.3647936/model.ckpt"
    # sample sentences containing a penalty decision number (厦建招诉决【2019】34号)
    sentences_list = '行政处罚厦建招诉决【2019】34号。行政处罚厦建招诉决【2019】34号。行政处罚厦建招诉决【2019】34号。行政处罚厦建招诉决【2019】34号,'.split('。')
    pn_ner = predict([sentences_list], model_file)
    print(pn_ner)
    # test2()
    # data = pd.read_csv("C:\\Users\\admin\\Desktop\\投诉处罚信息\\ALLDATA.csv", index_col=0)
    # sentences = data['sentences'][51313]
    # sentences = sentences.split("*#*>")
    # model_file = "models/21-0.9990081295021194-0.3647936/model.ckpt"
    # predict(sentences, model_file)
    pass