retrain.py

'''
Created on January 15, 2019
@author: User
'''
import os
import sys
sys.path.append(os.path.abspath("../../"))
import tensorflow as tf
# crf_log_likelihood and viterbi_decode are used below; they are assumed to be
# provided by the bi_lstm_crf star import.
# from tensorflow.contrib.crf import crf_log_likelihood
import json
import numpy as np
from BiddingKG.dl.foolnltk.label import readlabeldata, getContext
from sklearn.metrics import accuracy_score
from bi_lstm_crf import *
from zipfile import ZipFile
import math
import random
import codecs

OOV_STR = "<OOV>"
def _load_map_file(path, char_map_name, id_map_name):
    # Load the character map and the NER tag map from all_map.json inside the zip.
    with ZipFile(path) as myzip:
        with myzip.open('all_map.json') as myfile:
            content = myfile.readline()
            content = content.decode()
            data = json.loads(content)
            return data.get(char_map_name), data.get(id_map_name)
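
# Usage sketch (assumption about the data/map.zip layout, not verified here):
# the zip is expected to hold a single-line all_map.json whose "char_map" maps
# characters to integer ids and whose "ner_map" maps id strings to tag names, e.g.
#   char_to_id, id_to_tag = _load_map_file("data/map.zip", "char_map", "ner_map")
#   char_to_id.get("招")   # -> some integer id
#   id_to_tag.get("0")     # -> a tag string such as "O"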
def loss_layer(project_logits, y_target, trans, lengths, batch_size, num_steps, num_tags=25):
    # CRF log-likelihood loss with an extra synthetic start tag (index num_tags).
    with tf.variable_scope("crf_loss1"):
        small = -1000.0
        # Logits for the prepended start position: only the start tag is allowed.
        start_logits = tf.concat(
            [small * tf.ones(shape=[batch_size, 1, num_tags]), tf.zeros(shape=[batch_size, 1, 1])],
            axis=-1)
        # Give every real time step a heavily negative score for the start tag.
        pad_logits = tf.cast(small * tf.ones([batch_size, num_steps, 1]), tf.float32)
        logits = tf.concat([project_logits, pad_logits], axis=-1)
        logits = tf.concat([start_logits, logits], axis=1)
        # Prepend the start tag id to every target sequence.
        targets = tf.concat(
            [tf.cast(num_tags * tf.ones([batch_size, 1]), tf.int32), y_target], axis=-1)
        log_likelihood, trans = crf_log_likelihood(
            inputs=logits,
            tag_indices=targets,
            transition_params=trans,
            sequence_lengths=lengths + 1)
        return tf.reduce_mean(-log_likelihood)
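
# Note on the CRF layout above (explanatory comment, not original code): the CRF
# runs over num_tags + 1 labels. A synthetic start tag (index num_tags) is
# prepended to every sequence, and real time steps receive a -1000 score for that
# tag so it is never predicted mid-sentence; getAcc() below mirrors the same
# layout when Viterbi-decoding (25 real tags + 1 start tag).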
def getAcc(y_batch, logits, trans, lengths):
    # Viterbi-decode each sentence (with the synthetic start tag prepended,
    # matching loss_layer) and compute tag-level accuracy against the gold labels.
    index = 0
    small = -1000.0
    start = np.asarray([[small] * 25 + [0]])
    preds = []
    true_tags = []
    for score, length in zip(logits, lengths):
        score = score[:length]
        pad = small * np.ones([length, 1])
        logit = np.concatenate([score, pad], axis=1)
        logit = np.concatenate([start, logit], axis=0)
        path, _ = viterbi_decode(logit, trans)
        preds += path[1:]
        index += 1
    for y, length in zip(y_batch, lengths):
        y = y.tolist()
        true_tags += y[:length]
    acc = accuracy_score(np.reshape(true_tags, (-1)), np.reshape(preds, (-1)))
    return acc
def encode(list_sent_label):
    # Turn (character, tag) pairs into [char_ids, tag_ids, length] triples,
    # splitting each document into sentences on the full stop "。".
    data = []
    char_to_id, id_to_seg = _load_map_file("data/map.zip", "char_map", "ner_map")
    id_to_tag = {int(k): v for k, v in id_to_seg.items()}
    tag_to_id = {v: int(k) for k, v in id_to_seg.items()}
    list_sent = []
    list_label = []
    for sent_label in list_sent_label:
        _sent = []
        _label = []
        for item in sent_label:
            if item[0] == "。":
                _sent.append(item[0])
                _label.append(item[1])
                list_sent.append(_sent)
                list_label.append(_label)
                _sent = []
                _label = []
            else:
                _sent.append(item[0])
                _label.append(item[1])
        # Keep the trailing sentence if the document does not end with "。".
        if sent_label[-1][0] != "。":
            list_sent.append(_sent)
            list_label.append(_label)
    for _sent, _label in zip(list_sent, list_label):
        sent_ids = [char_to_id.get(w, char_to_id.get(OOV_STR)) for w in _sent]
        sent_tags = [tag_to_id.get(t) for t in _label]
        data.append([sent_ids, sent_tags, len(_sent)])
    return data
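
# Expected input to encode() (illustrative sketch; the tag names below are
# placeholders and must exist in ner_map for the lookup to succeed): one list of
# (character, tag) pairs per document, e.g.
#   list_sent_label = [[("招", "B_tag"), ("标", "I_tag"), ("。", "O")]]
#   data = encode(list_sent_label)
#   # each element of data is [char_ids, tag_ids, sentence_length]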
class NERBatchManager(object):
    def __init__(self, data, batch_size):
        self.MAX_LEN = 300
        self.char_to_id, self.id_to_seg = _load_map_file("data/map.zip", "char_map", "ner_map")
        self.id_to_tag = {int(k): v for k, v in self.id_to_seg.items()}
        self.tag_to_id = {v: int(k) for k, v in self.id_to_seg.items()}
        self.batch_data = self.sort_and_pad(data, batch_size)
        self.len_data = len(self.batch_data)
        print(self.len_data)

    def sort_and_pad(self, data, batch_size):
        # Sort by sentence length so each batch only pads to its own maximum.
        num_batch = int(math.ceil(len(data) / batch_size))
        sorted_data = sorted(data, key=lambda x: len(x[0]))
        batch_data = list()
        for i in range(num_batch):
            _item = self.pad_data(sorted_data[i * batch_size: (i + 1) * batch_size])
            if _item is not None:
                batch_data.append(_item)
        return batch_data

    def pad_data(self, data):
        batch_strings = []
        batch_targets = []
        batch_lengths = []
        max_length = max([len(sentence[0]) for sentence in data])
        # Drop batches whose longest sentence exceeds 1000 characters.
        if max_length > 1000:
            return None
        for line in data:
            text_ids, text_labels, text_len = line
            padding = [0] * (max_length - len(text_ids))
            padding_l = [self.tag_to_id.get("O")] * (max_length - len(text_labels))
            batch_strings.append(text_ids + padding)
            batch_targets.append(text_labels + padding_l)
            batch_lengths.append(text_len)
        return [np.array(batch_strings), np.array(batch_targets), np.array(batch_lengths)]

    def iter_batch(self, shuffle=False):
        if shuffle:
            random.shuffle(self.batch_data)
        for idx in range(self.len_data):
            yield self.batch_data[idx]
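
def _example_batch_manager_usage():
    # Illustrative sketch only (never called): the assumed contract of
    # NERBatchManager is that each yielded batch is a triple of numpy arrays
    # (char_ids, tag_ids, lengths), padded to the longest sentence in the batch.
    # It mirrors the way train() below builds its training batches.
    list_context = getContext("ner_train.txt")
    data = readlabeldata("selffool_train.txt", list_context)
    manager = NERBatchManager(data, batch_size=96)
    for x_batch, y_batch, length_batch in manager.iter_batch(shuffle=True):
        print(np.shape(x_batch), np.shape(y_batch), np.shape(length_batch))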
def train():
    '''
    list_sent_label = labelEntity()
    data = encode(list_sent_label)
    '''
    bilstm = BertCRF()
    list_context = getContext("ner_train.txt")
    '''
    data_train = readlabeldata("ner_train.txt",list_context)
    data_train_split = readlabeldata("ner_train_split.txt",list_context)
    data_label_split = readlabeldata("ner_label_split.txt",list_context)
    data_label = readlabeldata("ner_label.txt",list_context)
    data_entity_found = readlabeldata("company_found.txt",list_context,keep_prob=0.33)
    data_entity_cleaned = readlabeldata("cleanedEntity.txt",list_context)
    data = data_train + data_entity_found+data_train_split +data_entity_cleaned+data_label_split+ data_label
    '''
    data = readlabeldata("selffool_train.txt", list_context)
    # data = readlabeldata("ner_label.txt",[])
    # Hold out the last test_len sentences as a fixed test set.
    test_len = 128
    ner_bm_train = NERBatchManager(data[:-test_len], batch_size=96)
    test_x = []
    test_y = []
    test_length = []
    max_length = 0
    # Pad the test batch to the longest test sentence shorter than 300 characters.
    for line in data[-test_len:]:
        text_ids, text_labels, text_len = line
        if max_length < 300 and text_len > max_length and text_len < 300:
            max_length = text_len
    print("test_len", max_length)
    for line in data[-test_len:]:
        text_ids, text_labels, text_len = line
        if text_len > max_length:
            continue
        padding = [0] * (max_length - len(text_ids))
        padding_l = [ner_bm_train.tag_to_id.get("O")] * (max_length - len(text_labels))
        test_x.append(text_ids + padding)
        test_y.append(text_labels + padding_l)
        test_length.append(text_len)
    test_x = np.array(test_x)
    test_y = np.array(test_y)
    test_length = np.array(test_length)
    with bilstm.sess as sess:
        input_x, y_target, lengths, dropout, logits, trans, loss, train_op = bilstm.getNodes()
        bilstm.initVariables()
        # bilstm.restore()
        epochs = 100
        saver = tf.train.Saver(max_to_keep=epochs)
        with codecs.open("log.txt", "w", encoding="utf8") as f:
            test_index = 0
            for i in range(epochs):
                _batch = 0
                for x_batch, y_batch, length_batch in ner_bm_train.iter_batch(shuffle=True):
                    print(np.shape(x_batch))
                    _batch += 1
                    feed_dict = {
                        input_x: x_batch,
                        y_target: y_batch,
                        lengths: length_batch.reshape(-1, ),
                        dropout: 0.75
                    }
                    _, _loss, _logits, _trans = sess.run([train_op, loss, logits, trans], feed_dict=feed_dict)
                    acc = getAcc(y_batch, _logits, _trans, length_batch.reshape(-1, ))
                    print("epochs:", i, "batch:", _batch, _loss, acc)
                    f.write("epochs:" + str(i) + "/batch:" + str(_batch) + "/loss:" + str(_loss) + "/acc:" + str(acc) + "\n")
                # Evaluate on the held-out test batch after every epoch.
                print("testing")
                feed_dict = {
                    input_x: test_x,
                    y_target: test_y,
                    lengths: test_length.reshape(-1, ),
                    dropout: 1.0
                }
                _loss, _logits, _trans = sess.run([loss, logits, trans], feed_dict=feed_dict)
                acc = getAcc(test_y, _logits, _trans, test_length.reshape(-1, ))
                print("---test", "epochs:", i, acc)
                f.write("---test" + "epochs:" + str(test_index) + "-" + str(i) + "/loss:" + str(_loss) + "/acc:" + str(acc) + "\n")
                print("============", "epochs:", str(i), "loss:", str(_loss) + "done")
                # Save a checkpoint per epoch.
                saver.save(sess, 'model/' + str(test_index) + "-" + str(i) + '/model.ckpt')
                # saver.save(sess,"/model/model-epoch:"+str(i)+"-loss:"+str(_loss)[0:6]+"-acc:"+str(acc)[0:6]+".ckpt",global_step=1)

if __name__ == "__main__":
    train()