retrain.py 9.0 KB

'''
Created on Jan 15, 2019
@author: User
'''
import os
import sys
sys.path.append(os.path.abspath("../../"))
import tensorflow as tf
# crf_log_likelihood / viterbi_decode are used by loss_layer and getAcc below
from tensorflow.contrib.crf import crf_log_likelihood, viterbi_decode
import json
import numpy as np
from BiddingKG.dl.foolnltk.label import labelEntity, readlabeldata, getContext
from sklearn.metrics import accuracy_score
from bi_lstm_crf import *
from BiddingKG.dl.BertNer.BertCRF import *
from zipfile import ZipFile
import math
import random
import codecs

OOV_STR = "<OOV>"


def _load_map_file(path, char_map_name, id_map_name):
    # all_map.json inside the zip stores both the char->id map and the id->tag map
    with ZipFile(path) as myzip:
        with myzip.open('all_map.json') as myfile:
            content = myfile.readline()
            content = content.decode()
            data = json.loads(content)
            return data.get(char_map_name), data.get(id_map_name)
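# A minimal sketch of how the two maps are consumed elsewhere in this file,
# assuming all_map.json looks like {"char_map": {"招": 17, ...}, "ner_map": {"0": "O", ...}}
# (the exact keys and values inside map.zip are an assumption):
#
#   char_to_id, id_to_seg = _load_map_file("data/map.zip", "char_map", "ner_map")
#   id_to_tag = {int(k): v for k, v in id_to_seg.items()}   # e.g. {0: "O", ...}
#   tag_to_id = {v: int(k) for k, v in id_to_seg.items()}   # inverse lookup used by encode()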
def loss_layer(project_logits, y_target, trans, lengths, batch_size, num_steps, num_tags=25):
    # CRF log-likelihood loss: an extra start tag (index num_tags) is prepended to
    # every sequence, so logits and targets get one extra tag column and one extra step.
    with tf.variable_scope("crf_loss1"):
        small = -1000.0
        start_logits = tf.concat(
            [small * tf.ones(shape=[batch_size, 1, num_tags]), tf.zeros(shape=[batch_size, 1, 1])],
            axis=-1)
        pad_logits = tf.cast(small * tf.ones([batch_size, num_steps, 1]), tf.float32)
        logits = tf.concat([project_logits, pad_logits], axis=-1)
        logits = tf.concat([start_logits, logits], axis=1)
        targets = tf.concat(
            [tf.cast(num_tags * tf.ones([batch_size, 1]), tf.int32), y_target], axis=-1)
        log_likelihood, trans = crf_log_likelihood(
            inputs=logits,
            tag_indices=targets,
            transition_params=trans,
            sequence_lengths=lengths + 1)
        return tf.reduce_mean(-log_likelihood)
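# Shape sketch for loss_layer (illustrative values, assuming batch_size=2,
# num_steps=4, num_tags=25; the extra tag index 25 acts as a start tag):
#   project_logits [2, 4, 25] + pad_logits [2, 4, 1]    -> logits  [2, 4, 26]
#   start_logits   [2, 1, 26] prepended on the time axis -> logits  [2, 5, 26]
#   targets        [2, 4] gold tags prefixed with tag 25 -> targets [2, 5]
#   crf_log_likelihood then scores each sequence over lengths + 1 steps.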
def getAcc(y_batch, logits, trans, lengths):
    # Viterbi-decode each sequence (with the start tag prepended, mirroring loss_layer)
    # and compare the predicted tag path against the gold tags token by token.
    index = 0
    small = -1000.0
    start = np.asarray([[small] * 25 + [0]])
    preds = []
    true_tags = []
    for score, length in zip(logits, lengths):
        score = score[:length]
        pad = small * np.ones([length, 1])
        logit = np.concatenate([score, pad], axis=1)
        logit = np.concatenate([start, logit], axis=0)
        path, _ = viterbi_decode(logit, trans)
        preds += path[1:]
        index += 1
    for y, length in zip(y_batch, lengths):
        y = y.tolist()
        true_tags += y[:length]
    acc = accuracy_score(np.reshape(true_tags, (-1)), np.reshape(preds, (-1)))
    return acc
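# Note: this is per-character tag accuracy over the unpadded positions, not an
# entity-level score; e.g. gold [B, I, O] vs. predicted [B, O, O] counts as 2/3.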
def encode(list_sent_label):
    # Split the labelled character stream into sentences on "。" and map
    # characters and tags to their integer ids.
    data = []
    char_to_id, id_to_seg = _load_map_file("data/map.zip", "char_map", "ner_map")
    id_to_tag = {int(k): v for k, v in id_to_seg.items()}
    tag_to_id = {v: int(k) for k, v in id_to_seg.items()}
    list_sent = []
    list_label = []
    for sent_label in list_sent_label:
        _sent = []
        _label = []
        for item in sent_label:
            if item[0] == "。":
                _sent.append(item[0])
                _label.append(item[1])
                list_sent.append(_sent)
                list_label.append(_label)
                _sent = []
                _label = []
            else:
                _sent.append(item[0])
                _label.append(item[1])
        # keep the trailing fragment when the text does not end with "。"
        if sent_label[-1][0] != "。":
            list_sent.append(_sent)
            list_label.append(_label)
    for _sent, _label in zip(list_sent, list_label):
        sent_ids = [char_to_id.get(w) if w in char_to_id else char_to_id.get("<OOV>") for w in _sent]
        sent_tags = [tag_to_id.get(t) for t in _label]
        data.append([sent_ids, sent_tags, len(_sent)])
    return data
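# Illustrative input/output of encode(); the characters and tag names below are
# hypothetical, the real tag set comes from ner_map:
#   input : [[("招", "B-org"), ("标", "I-org"), ("。", "O"), ("其", "O")]]
#   output: [[char ids of "招标。"], [tag ids of B-org I-org O], 3]
#           [[char id of "其"],      [tag id of O],              1]
# i.e. one [char_ids, tag_ids, length] triple per sentence, split on "。".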
class NERBatchManager(object):

    def __init__(self, data, batch_size):
        self.MAX_LEN = 300
        self.char_to_id, self.id_to_seg = _load_map_file("data/map.zip", "char_map", "ner_map")
        self.id_to_tag = {int(k): v for k, v in self.id_to_seg.items()}
        self.tag_to_id = {v: int(k) for k, v in self.id_to_seg.items()}
        self.batch_data = self.sort_and_pad(data, batch_size)
        self.len_data = len(self.batch_data)
        print(self.len_data)

    def sort_and_pad(self, data, batch_size):
        # Sort by sentence length so each batch pads to a similar length.
        num_batch = int(math.ceil(len(data) / batch_size))
        sorted_data = sorted(data, key=lambda x: len(x[0]))
        batch_data = list()
        for i in range(num_batch):
            _item = self.pad_data(sorted_data[i * batch_size: (i + 1) * batch_size])
            if _item is not None:
                batch_data.append(_item)
        return batch_data

    def pad_data(self, data):
        # Pad every sentence in the batch to the batch maximum; batches whose
        # longest sentence exceeds 1000 characters are dropped entirely.
        batch_strings = []
        batch_targets = []
        batch_lengths = []
        max_length = max([len(sentence[0]) for sentence in data])
        if max_length > 1000:
            return None
        for line in data:
            text_ids, text_labels, text_len = line
            padding = [0] * (max_length - len(text_ids))
            padding_l = [self.tag_to_id.get("O")] * (max_length - len(text_labels))
            batch_strings.append(text_ids + padding)
            batch_targets.append(text_labels + padding_l)
            batch_lengths.append(text_len)
        return [np.array(batch_strings), np.array(batch_targets), np.array(batch_lengths)]

    def iter_batch(self, shuffle=False):
        if shuffle:
            random.shuffle(self.batch_data)
        for idx in range(self.len_data):
            yield self.batch_data[idx]
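# Minimal usage sketch (mirrors how train() below consumes the manager):
#   ner_bm = NERBatchManager(data, batch_size=96)
#   for x_batch, y_batch, length_batch in ner_bm.iter_batch(shuffle=True):
#       ...  # x_batch/y_batch are padded id matrices, length_batch the true lengths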
def train():
    '''
    list_sent_label = labelEntity()
    data = encode(list_sent_label)
    '''
    bilstm = BertCRF()
    list_context = getContext("ner_train.txt")
    '''
    data_train = readlabeldata("ner_train.txt", list_context)
    data_train_split = readlabeldata("ner_train_split.txt", list_context)
    data_label_split = readlabeldata("ner_label_split.txt", list_context)
    data_label = readlabeldata("ner_label.txt", list_context)
    data_entity_found = readlabeldata("company_found.txt", list_context, keep_prob=0.33)
    data_entity_cleaned = readlabeldata("cleanedEntity.txt", list_context)
    data = data_train + data_entity_found + data_train_split + data_entity_cleaned + data_label_split + data_label
    '''
    data = readlabeldata("selffool_train.txt", list_context)
    # data = readlabeldata("ner_label.txt", [])

    # Hold out the last test_len sentences for evaluation; the rest are batched for training.
    test_len = 128
    ner_bm_train = NERBatchManager(data[:-test_len], batch_size=96)
    test_x = []
    test_y = []
    test_length = []
    max_length = 0
    # Pad the test set to the longest held-out sentence shorter than 300 characters;
    # longer sentences are skipped below.
    for line in data[-test_len:]:
        text_ids, text_labels, text_len = line
        if max_length < 300 and text_len > max_length and text_len < 300:
            max_length = text_len
    print("test max_length", max_length)
    for line in data[-test_len:]:
        text_ids, text_labels, text_len = line
        if text_len > max_length:
            continue
        padding = [0] * (max_length - len(text_ids))
        padding_l = [ner_bm_train.tag_to_id.get("O")] * (max_length - len(text_labels))
        test_x.append(text_ids + padding)
        test_y.append(text_labels + padding_l)
        test_length.append(text_len)
    test_x = np.array(test_x)
    test_y = np.array(test_y)
    test_length = np.array(test_length)

    with bilstm.sess as sess:
        input_x, y_target, lengths, dropout, logits, trans, loss, train_op = bilstm.getNodes()
        bilstm.initVariables()
        # bilstm.restore()
        epochs = 100
        saver = tf.train.Saver(max_to_keep=epochs)
        with codecs.open("log.txt", "w", encoding="utf8") as f:
            test_index = 0
            for i in range(epochs):
                _batch = 0
                for x_batch, y_batch, length_batch in ner_bm_train.iter_batch(shuffle=True):
                    print(np.shape(x_batch))
                    _batch += 1
                    feed_dict = {
                        input_x: x_batch,
                        y_target: y_batch,
                        lengths: length_batch.reshape(-1, ),
                        dropout: 0.75
                    }
                    _, _loss, _logits, _trans = sess.run([train_op, loss, logits, trans], feed_dict=feed_dict)
                    acc = getAcc(y_batch, _logits, _trans, length_batch.reshape(-1, ))
                    print("epochs:", i, "batch:", _batch, _loss, acc)
                    f.write("epochs:" + str(i) + "/batch:" + str(_batch) + "/loss:" + str(_loss) + "/acc:" + str(acc) + "\n")
                # Evaluate on the held-out set after every epoch, then checkpoint.
                print("testing")
                feed_dict = {
                    input_x: test_x,
                    y_target: test_y,
                    lengths: test_length.reshape(-1, ),
                    dropout: 1.0
                }
                _loss, _logits, _trans = sess.run([loss, logits, trans], feed_dict=feed_dict)
                acc = getAcc(test_y, _logits, _trans, test_length.reshape(-1, ))
                print("---test", "epochs:", i, acc)
                f.write("---test" + "epochs:" + str(test_index) + "-" + str(i) + "/loss:" + str(_loss) + "/acc:" + str(acc) + "\n")
                print("============", "epochs:", str(i), "loss:", str(_loss), "done")
                saver.save(sess, 'model/' + str(test_index) + "-" + str(i) + '/model.ckpt')
                # saver.save(sess, "/model/model-epoch:"+str(i)+"-loss:"+str(_loss)[0:6]+"-acc:"+str(acc)[0:6]+".ckpt", global_step=1)


if __name__ == "__main__":
    train()
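# Usage note (paths as referenced above): with data/map.zip and selffool_train.txt
# available next to this script, running `python retrain.py` trains for up to 100
# epochs, logging metrics to log.txt and saving checkpoints under
# model/<test_index>-<epoch>/model.ckpt.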