train.py

'''
Created on 2019-04-22
@author: User
'''
import sys
import os
sys.path.append(os.path.abspath("../../.."))
from BiddingKG.dl.common.Utils import *
from keras.callbacks import ModelCheckpoint
from BiddingKG.dl.common.models import *
import pandas as pd
import keras
import numpy as np

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
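
# Note: CUDA_VISIBLE_DEVICES="" hides every GPU from TensorFlow, so training
# and export below run on CPU; set it to a device id (e.g. "0") to use a GPU.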

def loadTrainData(percent=0.9, line=False):
    """Load the pickled samples and split them into train/test sets.

    Roughly ``1-percent`` of each label's samples are routed to the test set,
    so the test split follows the label distribution. With ``line=True`` the
    before/center/after windows are concatenated into a single sequence
    instead of being returned as three separate model inputs.
    """
    # files = ["id_token_text_begin_end_label.pk","id_token_text_begin_end_label.pk1","id_token_text_begin_end_label-selffool.pk1"]
    # files = ["id_token_text_begin_end_label.pk","id_token_text_begin_end_label.pk1"]
    files = ["id_token_text_begin_end_label-moreTrue.pk"]
    data_x = []
    data_y = []
    # data_id = []
    test_x = []
    test_y = []
    test_id = []
    # _,_,_,_,_,test_id_before = load("all_data_selffool.pk_line")
    # test_id_before = set(test_id_before)
    dict_label_item = dict()
    # count the id distribution per label
    for file in files:
        data = load(file)
        for row in data:
            _id = row[0]
            label = int(row[5])
            if label not in dict_label_item:
                dict_label_item[label] = set()
            dict_label_item[label].add(_id)
    # test-set budget per label: roughly (1-percent) of that label's samples
    dict_label_num = dict()
    for _key in dict_label_item.keys():
        dict_label_num[_key] = int(len(dict_label_item[_key])*(1-percent))
    for file in files:
        data = load(file)
        _count = 0
        for row in data:
            # item_x = embedding_word(spanWindow(tokens=row[1],begin_index=row[3],end_index=row[4],size=100,center_include=True,word_flag=True), shape=(3,100,60))
            _span = spanWindow(tokens=row[1], begin_index=row[3], end_index=row[4], size=10, center_include=True, word_flag=True, text=row[2])
            item_x = encodeInput(_span, word_len=50, word_flag=True, userFool=False)
            if line:
                item_x = item_x[0]+item_x[1]+item_x[2]
            item_y = np.zeros([6])
            label = int(row[5])
            print(_span, label)
            _count += 1
            if label not in [0, 1, 2, 3, 4, 5]:
                continue
            item_y[label] = 1
            if np.random.random() > 0.5 and dict_label_num[label] > 0:
                dict_label_num[label] -= 1
                test_x.append(item_x)
                test_y.append(item_y)
                test_id.append(row[0])
            else:
                data_x.append(item_x)
                data_y.append(item_y)
                # data_id.append(row[0])
            # if np.random.random()>percent:
            # # if row[0] not in test_id_before:
            #     data_x.append(item_x)
            #     data_y.append(item_y)
            #     # data_id.append(row[0])
            # else:
            #     test_x.append(item_x)
            #     test_y.append(item_y)
            #     test_id.append(row[0])
    print(np.shape(np.array(data_x)), np.shape(np.array(test_x)))
    print(dict_label_num)
    if line:
        return np.array(data_x), np.array(data_y), np.array(test_x), np.array(test_y), None, test_id
    else:
        return np.transpose(np.array(data_x), (1, 0, 2)), np.array(data_y), np.transpose(np.array(test_x), (1, 0, 2)), np.array(test_y), None, test_id
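
# Shape sketch for the non-`line` return above, with dummy data. It assumes
# encodeInput yields three id sequences (before/center/after) of word_len=50,
# so data_x is (N, 3, 50); the (1, 0, 2) transpose then yields one array per
# model input. The actual dtype and id values come from encodeInput.
def _transpose_shape_demo():
    dummy = np.zeros((8, 3, 50))            # 8 samples, 3 windows, 50 ids each
    as_inputs = np.transpose(dummy, (1, 0, 2))
    assert as_inputs.shape == (3, 8, 50)    # as_inputs[0] feeds model.input[0], etc.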

def train():
    """Train the role-classification model, checkpointing on loss and val_loss."""
    # data_pk = "all_data_selffool_before-10.pk"
    data_pk = "all_data_selffool_moretrue-10.pk"
    # data_pk = "all_data_selffool_all-10.pk"
    if os.path.exists(data_pk):
        train_x, train_y, test_x, test_y, _, test_id = load(data_pk)
    else:
        train_x, train_y, test_x, test_y, _, test_id = loadTrainData()
        # cache the split so later runs skip re-encoding
        save((train_x, train_y, test_x, test_y, _, test_id), data_pk)
    with tf.Session(graph=tf.Graph()).as_default() as sess:
        with sess.graph.as_default():
            # dict_key_value = load("dict_key_value.pk")
            # model = getBiLSTMModel(input_shape=(3,50,256), vocab=fool_char_to_id.keys(), embedding_weights=dict_key_value["bert/embeddings/word_embeddings:0"], classes=6)
            vocab, matrix = getVocabAndMatrix(getModel_word())
            # model = getBiLSTMModel(input_shape=(3,50,60), vocab=vocab, embedding_weights=matrix, classes=6)
            model = getBiLSTMModel_entity(input_shape=(3, 50, 60), vocab=vocab, embedding_weights=matrix, classes=6)
            # model = getTextCNNModel(input_shape=(2,50,60), vocab=vocab, embedding_weights=matrix, classes=6)
            '''
            for k,v in dict_key_value.items():
                if re.search("encoder",k) is not None:
                    sess.run(tf.assign(sess.graph.get_tensor_by_name(k[13:]),v))
                    print(k)
            '''
            # model = getTextCNNModel(input_shape=(3,50,60), vocab=vocab, embedding_weights=weights, classes=6)
            # model.load_weights("log/ep044-loss0.142-val_loss0.200-f1_score0.934.h5",skip_mismatch=True,by_name=True)
            # warm-start from an earlier checkpoint
            model.load_weights("log/min_val_loss_ep027-loss0.112-val_loss0.109-f1_score0.963.h5")
            # model.summary()
            # print("11111111111",sess.run(sess.graph.get_tensor_by_name("encoder/layer_0/attention/self/query/kernel:0")))
            callback = ModelCheckpoint(filepath="log/"+"min_val_loss_ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1_score{val_f1_score:.3f}.h5", monitor="val_loss", save_best_only=True, save_weights_only=True, mode="min")
            callback1 = ModelCheckpoint(filepath="log/"+"min_loss_ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1_score{val_f1_score:.3f}.h5", monitor="loss", save_best_only=True, save_weights_only=True, mode="min")
            history_model = model.fit(x=[train_x[0], train_x[1], train_x[2]], y=train_y, validation_data=([test_x[0], test_x[1], test_x[2]], test_y), epochs=600, batch_size=96, shuffle=True, callbacks=[callback, callback1])
            # history_model = model.fit(x=[train_x[0],train_x[1]],y=train_y,validation_data=([test_x[0],test_x[1]],test_y),epochs=600,batch_size=128,shuffle=True,callbacks=[callback,callback1])
            # history_model = model.fit(x=train_x,y=train_y,validation_data=(test_x,test_y),epochs=600,batch_size=128,shuffle=True,callbacks=[callback])
            # print("2222222222222",sess.run(sess.graph.get_tensor_by_name("encoder/layer_0/attention/self/query/kernel:0")))
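
# Hedged evaluation sketch (a helper that is not part of the original script):
# it mirrors the commented predict call in val() below, assuming `model` is one
# of the three-input models above and `test_x` is in the (3, N, 50) layout.
def _predict_roles(model, test_x):
    return np.argmax(model.predict([test_x[0], test_x[1], test_x[2]]), axis=-1)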

def test():
    # a small before/center/after span from a Chinese bidding notice
    # ("预算金额...元" = budget amount, "中标金额" = winning-bid amount)
    _span = [':预算金额1000000元,中标金额', '1df元', ';']
    _input = encodeInput(_span, word_len=50, word_flag=True, userFool=True)
    print(_input)

def statis():
    """Tally correct/wrong predictions per label for the word- and char-level
    models ("词" = word, "字" = char; "正确"/"错误" = correct/wrong)."""
    df = pd.read_excel("测试数据_role-biws-biw0.xls")
    result = {"正确-词": 0,
              "错误-词": 0,
              "正确-字": 0,
              "错误-字": 0}
    for i in range(6):
        result["正确-词"+str(i)] = 0
        result["错误-词"+str(i)] = 0
        result["正确-字"+str(i)] = 0
        result["错误-字"+str(i)] = 0
    for label_ws, prob_ws, label_w, prob_w, label_true in zip(df["list_newlabel"], df["list_newprob"], df["list_newlabel_cnn"], df["list_newprob_cnn"], df["label_true"]):
        if int(label_ws) == int(label_true):
            key = "正确-词"
            result[key] += 1
            result[key+str(int(label_ws))] += 1
        else:
            key = "错误-词"
            result[key] += 1
            result[key+str(int(label_ws))] += 1
        if int(label_w) == int(label_true):
            key = "正确-字"
            result[key] += 1
            result[key+str(int(label_w))] += 1
        else:
            key = "错误-字"
            result[key] += 1
            result[key+str(int(label_w))] += 1
    data = []
    for key in result.keys():
        data.append([key, result[key]])
    data.sort(key=lambda x: x[0])
    for item in data:
        print(item)
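
# Follow-up note: overall accuracy per model can be derived from the counters
# above, e.g. word-level accuracy is
#   result["正确-词"] / (result["正确-词"] + result["错误-词"])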

def val():
    """Run the model on the test split and export the misclassified spans."""
    data_pk = "all_data_selffool.pk_line"
    train_x, train_y, test_x, test_y, _, test_id = load(data_pk)
    vocab, matrix = getVocabAndMatrix(getModel_word())
    model = getBiLSTMModel(input_shape=(1, 150, 60), vocab=vocab, embedding_weights=matrix, classes=6)
    model.load_weights("log/ep064-loss0.585-val_loss0.634-f1_score0.927.h5")
    # predict_y = np.argmax(model.predict([test_x[0],test_x[1],test_x[2]]),-1)
    predict_y = np.argmax(model.predict(test_x), -1)
    dict_notTrue = dict()
    for _y, Y, _id in zip(predict_y, np.argmax(test_y, -1), test_id):
        if _y != Y:
            dict_notTrue[_id] = [_y, Y]
    token_data = load("id_token_text_begin_end_label-selffool.pk1")
    test_before = []
    test_center = []
    test_after = []
    test_label = []
    test_predict = []
    for item in token_data:
        if item[0] in dict_notTrue:
            token = item[1]
            text = item[2]
            begin = item[3]
            end = item[4]
            predict, label = dict_notTrue[item[0]]
            _span = spanWindow(tokens=token, begin_index=begin, end_index=end, size=10, center_include=True, word_flag=True, text=text)
            before, center, after = _span
            test_before.append(before)
            test_center.append(center)
            test_after.append(after)
            test_label.append(label)
            test_predict.append(predict)
    data = {"test_before": test_before, "test_center": test_center, "test_after": test_after, "test_label": test_label, "test_predict": test_predict}
    df = pd.DataFrame(data)
    df.to_excel("val_bert_position.xls", columns=["test_before", "test_center", "test_after", "test_label", "test_predict"])

def get_savedmodel():
    """Export the trained weights as a TensorFlow SavedModel for serving."""
    with tf.Session(graph=tf.Graph()).as_default() as sess:
        with sess.graph.as_default():
            vocab, matrix = getVocabAndMatrix(getModel_word(), Embedding_size=60)
            # model = getBiLSTMModel(input_shape=(3,50,60), vocab=vocab, embedding_weights=matrix, classes=6)
            model = getBiLSTMModel_entity(input_shape=(3, 50, 60), vocab=vocab, embedding_weights=matrix, classes=6)
            # model = getTextCNNModel(input_shape=(2,50,60), vocab=vocab, embedding_weights=matrix, classes=6)
            # filepath = "log/ep001-loss0.087-val_loss0.172-f1_score0.944.h5"
            filepath = "log/min_val_loss_ep034-loss0.070-val_loss0.068-f1_score0.975.h5"
            model.load_weights(filepath)
            tf.saved_model.simple_save(sess,
                                       "role_savedmodel/",
                                       inputs={"input0": model.input[0],
                                               "input1": model.input[1],
                                               "input2": model.input[2]},
                                       outputs={"outputs": model.output})
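
# Hedged inference sketch for the exported SavedModel. simple_save registers
# the signature under the default "serving_default" key with the input/output
# names given above; reading tensor names from the signature avoids
# hard-coding Keras-generated graph names. `feed0/1/2` are hypothetical
# placeholders for arrays shaped like the encoded before/center/after windows.
def _load_and_predict(feed0, feed1, feed2):
    with tf.Session(graph=tf.Graph()) as sess:
        meta = tf.saved_model.loader.load(sess, tags=["serve"], export_dir="role_savedmodel")
        sig = meta.signature_def["serving_default"]
        fetch = sess.graph.get_tensor_by_name(sig.outputs["outputs"].name)
        feeds = {sess.graph.get_tensor_by_name(sig.inputs["input%d" % i].name): x
                 for i, x in enumerate([feed0, feed1, feed2])}
        return sess.run(fetch, feed_dict=feeds)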

def get_tensorboard():
    """Dump the SavedModel graph to an event file for TensorBoard."""
    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(sess, export_dir="role_savedmodel", tags=["serve"])
        writer = tf.summary.FileWriter(graph=sess.graph, logdir="log2")
        writer.close()  # flush the event file
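
# Usage note: after get_tensorboard() has written the event file, inspect the
# graph with:
#   tensorboard --logdir log2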

def relabel():
    """Export label-3/4 spans to Excel for manual relabelling."""
    list_id = []
    list_before = []
    list_center = []
    list_after = []
    list_label = []
    files = ["id_token_text_begin_end_label.pk", "id_token_text_begin_end_label.pk1"]
    for file in files:
        data = load(file)
        _count = 0
        for row in data:
            # item_x = embedding_word(spanWindow(tokens=row[1],begin_index=row[3],end_index=row[4],size=100,center_include=True,word_flag=True), shape=(3,100,60))
            _span = spanWindow(tokens=row[1], begin_index=row[3], end_index=row[4], size=15, center_include=True, word_flag=True, text=row[2])
            _label = row[5]
            if int(_label) in [3, 4]:
                list_id.append(row[0])
                list_before.append(_span[0])
                list_center.append(_span[1])
                list_after.append(_span[2])
                list_label.append(str(_label))
    df = pd.DataFrame({"list_id": list_id,
                       "list_before": list_before,
                       "list_center": list_center,
                       "list_after": list_after,
                       "list_label": list_label})
    df.to_excel("relabel_1.xls", columns=["list_id", "list_before", "list_center", "list_after", "list_label"])

def generate_data():
    """Merge the relabelled rows back into the selffool data and save the
    combined set as id_token_text_begin_end_label-moreTrue.pk."""
    files = ["id_token_text_begin_end_label.pk", "id_token_text_begin_end_label.pk1"]
    data = load("id_token_text_begin_end_label-selffool.pk1")
    df = pd.read_excel("relabel_1.xls")
    set_id = set(df["list_id"])
    for file in files:
        temp_data = load(file)
        for row in temp_data:
            if row[0] in set_id:
                data.append(row)
    save(data, "id_token_text_begin_end_label-moreTrue.pk")

if __name__ == "__main__":
    # loadTrainData()
    # train()
    # relabel()
    # generate_data()
    # test()
    # statis()
    # val()
    get_savedmodel()
    # get_tensorboard()