'''
Created on 2019-04-22
@author: User
'''
import sys
import os
sys.path.append(os.path.abspath("../../.."))
from BiddingKG.dl.common.Utils import *
from keras.callbacks import ModelCheckpoint
from BiddingKG.dl.common.models import *
import pandas as pd
import keras
import numpy as np
# tf.Session / tf.saved_model are used below; import explicitly in case the
# star imports above do not expose tf
import tensorflow as tf

# run on CPU only: hide all CUDA devices from TensorFlow
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ""

def loadTrainData(percent=0.9, line=False):
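    '''
    Load pickled samples and split them into train and test sets.

    Each row appears to be (id, tokens, text, begin_index, end_index, label).
    Roughly (1 - percent) of the ids per label are routed to the test set.
    With line=True the three context windows are concatenated into one
    sequence; otherwise the arrays are transposed so that the before/center/
    after windows can be fed as three separate model inputs.
    '''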
    # files = ["id_token_text_begin_end_label.pk","id_token_text_begin_end_label.pk1","id_token_text_begin_end_label-selffool.pk1"]
    # files = ["id_token_text_begin_end_label.pk","id_token_text_begin_end_label.pk1"]
    files = ["id_token_text_begin_end_label-moreTrue.pk"]
    data_x = []
    data_y = []
    # data_id = []
    test_x = []
    test_y = []
    test_id = []
    # _,_,_,_,_,test_id_before = load("all_data_selffool.pk_line")
    # test_id_before = set(test_id_before)
    dict_label_item = dict()
    # collect the label distribution: the set of sample ids per label
    for file in files:
        data = load(file)
        for row in data:
            id = row[0]
            label = int(row[5])
            if label not in dict_label_item:
                dict_label_item[label] = set()
            dict_label_item[label].add(id)
    # number of test samples to draw per label
    dict_label_num = dict()
    for _key in dict_label_item.keys():
        dict_label_num[_key] = int(len(dict_label_item[_key]) * (1 - percent))
    for file in files:
        data = load(file)
        _count = 0
        for row in data:
            # item_x = embedding_word(spanWindow(tokens=row[1],begin_index=row[3],end_index=row[4],size=100,center_include=True,word_flag=True), shape=(3,100,60))
            _span = spanWindow(tokens=row[1], begin_index=row[3], end_index=row[4], size=10, center_include=True, word_flag=True, text=row[2])
            item_x = encodeInput(_span, word_len=50, word_flag=True, userFool=False)
            if line:
                item_x = item_x[0] + item_x[1] + item_x[2]
            item_y = np.zeros([6])
            label = int(row[5])
            print(_span, label)
            _count += 1
            if label not in [0, 1, 2, 3, 4, 5]:
                continue
            item_y[label] = 1
            # randomly route samples to the test set until the per-label quota is used up
            if np.random.random() > 0.5 and dict_label_num[label] > 0:
                dict_label_num[label] -= 1
                test_x.append(item_x)
                test_y.append(item_y)
                test_id.append(row[0])
            else:
                data_x.append(item_x)
                data_y.append(item_y)
                # data_id.append(row[0])
            # if np.random.random()>percent:
            # # if row[0] not in test_id_before:
            #     data_x.append(item_x)
            #     data_y.append(item_y)
            #     # data_id.append(row[0])
            # else:
            #     test_x.append(item_x)
            #     test_y.append(item_y)
            #     test_id.append(row[0])
    print(np.shape(np.array(data_x)), np.shape(np.array(test_x)))
    print(dict_label_num)
    if line:
        return np.array(data_x), np.array(data_y), np.array(test_x), np.array(test_y), None, test_id
    else:
        # (samples, 3, word_len) -> (3, samples, word_len) so the three windows can be indexed as model inputs
        return np.transpose(np.array(data_x), (1, 0, 2)), np.array(data_y), np.transpose(np.array(test_x), (1, 0, 2)), np.array(test_y), None, test_id

def train():
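    '''
    Train the 6-class BiLSTM entity-role model.

    Reuses the cached train/test split in data_pk when present, otherwise
    rebuilds it via loadTrainData() and caches it. Two ModelCheckpoint
    callbacks save the best weights by val_loss and by loss under log/.
    '''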
    # data_pk = "all_data_selffool_before-10.pk"
    data_pk = "all_data_selffool_moretrue-10.pk"
    # data_pk = "all_data_selffool_all-10.pk"
    if os.path.exists(data_pk):
        train_x, train_y, test_x, test_y, _, test_id = load(data_pk)
    else:
        train_x, train_y, test_x, test_y, _, test_id = loadTrainData()
        save((train_x, train_y, test_x, test_y, _, test_id), data_pk)
    with tf.Session(graph=tf.Graph()).as_default() as sess:
        with sess.graph.as_default():
            # dict_key_value = load("dict_key_value.pk")
            # model = getBiLSTMModel(input_shape=(3,50,256), vocab=fool_char_to_id.keys(), embedding_weights=dict_key_value["bert/embeddings/word_embeddings:0"], classes=6)
            vocab, matrix = getVocabAndMatrix(getModel_word())
            # model = getBiLSTMModel(input_shape=(3,50,60), vocab=vocab, embedding_weights=matrix, classes=6)
            model = getBiLSTMModel_entity(input_shape=(3, 50, 60), vocab=vocab, embedding_weights=matrix, classes=6)
            # model = getTextCNNModel(input_shape=(2,50,60), vocab=vocab, embedding_weights=matrix, classes=6)
            '''
            for k,v in dict_key_value.items():
                if re.search("encoder",k) is not None:
                    sess.run(tf.assign(sess.graph.get_tensor_by_name(k[13:]),v))
                    print(k)
            '''
            # model = getTextCNNModel(input_shape=(3,50,60), vocab=vocab, embedding_weights=weights, classes=6)
            # model.load_weights("log/ep044-loss0.142-val_loss0.200-f1_score0.934.h5",skip_mismatch=True,by_name=True)
            model.load_weights("log/min_val_loss_ep027-loss0.112-val_loss0.109-f1_score0.963.h5")
            # model.summary()
            # print("11111111111",sess.run(sess.graph.get_tensor_by_name("encoder/layer_0/attention/self/query/kernel:0")))
            callback = ModelCheckpoint(filepath="log/" + "min_val_loss_ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1_score{val_f1_score:.3f}.h5", monitor="val_loss", save_best_only=True, save_weights_only=True, mode="min")
            callback1 = ModelCheckpoint(filepath="log/" + "min_loss_ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1_score{val_f1_score:.3f}.h5", monitor="loss", save_best_only=True, save_weights_only=True, mode="min")
            history_model = model.fit(x=[train_x[0], train_x[1], train_x[2]], y=train_y, validation_data=([test_x[0], test_x[1], test_x[2]], test_y), epochs=600, batch_size=96, shuffle=True, callbacks=[callback, callback1])
            # history_model = model.fit(x=[train_x[0],train_x[1]],y=train_y,validation_data=([test_x[0],test_x[1]],test_y),epochs=600,batch_size=128,shuffle=True,callbacks=[callback,callback1])
            # history_model = model.fit(x=train_x,y=train_y,validation_data=(test_x,test_y),epochs=600,batch_size=128,shuffle=True,callbacks=[callback])
            # print("2222222222222",sess.run(sess.graph.get_tensor_by_name("encoder/layer_0/attention/self/query/kernel:0")))

def test():
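    '''Quick manual sanity check of encodeInput() on a hand-written span.'''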
    _span = [':预算金额1000000元,中标金额', '1df元', ';']
    _input = encodeInput(_span, word_len=50, word_flag=True, userFool=True)
    print(_input)
    print(len(_input))
    print(len(_input[0]))
    print(len(_input[1]))
    print(len(_input[2]))

def statis():
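    '''
    Compare two sets of predictions in 测试数据_role-biws-biw0.xls against the
    gold labels and print overall and per-label hit/miss counts. The result
    keys translate as: 正确/错误 = correct/incorrect; 词/字 appear to distinguish
    the word-level and character-level (cnn) predictions.
    '''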
    df = pd.read_excel("测试数据_role-biws-biw0.xls")
    result = {"正确-词": 0,
              "错误-词": 0,
              "正确-字": 0,
              "错误-字": 0}
    for i in range(6):
        result["正确-词" + str(i)] = 0
        result["错误-词" + str(i)] = 0
        result["正确-字" + str(i)] = 0
        result["错误-字" + str(i)] = 0
    for label_ws, prob_ws, label_w, prob_w, label_true in zip(df["list_newlabel"], df["list_newprob"], df["list_newlabel_cnn"], df["list_newprob_cnn"], df["label_true"]):
        if int(label_ws) == int(label_true):
            key = "正确-词"
            result[key] += 1
            result[key + str(int(label_ws))] += 1
        else:
            key = "错误-词"
            result[key] += 1
            result[key + str(int(label_ws))] += 1
        if int(label_w) == int(label_true):
            key = "正确-字"
            result[key] += 1
            result[key + str(int(label_w))] += 1
        else:
            key = "错误-字"
            result[key] += 1
            result[key + str(int(label_w))] += 1
    data = []
    for key in result.keys():
        data.append([key, result[key]])
    data.sort(key=lambda x: x[0])
    for item in data:
        print(item)

def val():
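    '''
    Evaluate a line-mode BiLSTM checkpoint on the cached test split, collect
    the misclassified samples, and dump their context windows together with
    the gold and predicted labels to val_bert_position.xls for inspection.
    '''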
    data_pk = "all_data_selffool.pk_line"
    train_x, train_y, test_x, test_y, _, test_id = load(data_pk)
    vocab, matrix = getVocabAndMatrix(getModel_word())
    model = getBiLSTMModel(input_shape=(1, 150, 60), vocab=vocab, embedding_weights=matrix, classes=6)
    model.load_weights("log/ep064-loss0.585-val_loss0.634-f1_score0.927.h5")
    # predict_y = np.argmax(model.predict([test_x[0],test_x[1],test_x[2]]),-1)
    predict_y = np.argmax(model.predict(test_x), -1)
    dict_notTrue = dict()
    for _y, Y, _id in zip(predict_y, np.argmax(test_y, -1), test_id):
        if _y != Y:
            dict_notTrue[_id] = [_y, Y]
    token_data = load("id_token_text_begin_end_label-selffool.pk1")
    test_before = []
    test_center = []
    test_after = []
    test_label = []
    test_predict = []
    for item in token_data:
        if item[0] in dict_notTrue:
            token = item[1]
            text = item[2]
            begin = item[3]
            end = item[4]
            predict, label = dict_notTrue[item[0]]
            _span = spanWindow(tokens=token, begin_index=begin, end_index=end, size=10, center_include=True, word_flag=True, text=text)
            before, center, after = _span
            test_before.append(before)
            test_center.append(center)
            test_after.append(after)
            test_label.append(label)
            test_predict.append(predict)
    data = {"test_before": test_before, "test_center": test_center, "test_after": test_after, "test_label": test_label, "test_predict": test_predict}
    df = pd.DataFrame(data)
    df.to_excel("val_bert_position.xls", columns=["test_before", "test_center", "test_after", "test_label", "test_predict"])

def get_savedmodel():
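    '''Export the trained role classifier to role_savedmodel/ as a TensorFlow SavedModel.'''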
    with tf.Session(graph=tf.Graph()).as_default() as sess:
        with sess.graph.as_default():
            vocab, matrix = getVocabAndMatrix(getModel_word(), Embedding_size=60)
            # model = getBiLSTMModel(input_shape=(3,50,60), vocab=vocab, embedding_weights=matrix, classes=6)
            model = getBiLSTMModel_entity(input_shape=(3, 50, 60), vocab=vocab, embedding_weights=matrix, classes=6)
            # model = getTextCNNModel(input_shape=(2,50,60), vocab=vocab, embedding_weights=matrix, classes=6)
            # filepath = "log/ep001-loss0.087-val_loss0.172-f1_score0.944.h5"
            filepath = "log/min_val_loss_ep034-loss0.070-val_loss0.068-f1_score0.975.h5"
            model.load_weights(filepath)
            tf.saved_model.simple_save(sess,
                                       "role_savedmodel/",
                                       inputs={"input0": model.input[0],
                                               "input1": model.input[1],
                                               "input2": model.input[2]},
                                       outputs={"outputs": model.output}
                                       )

def get_tensorboard():
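    '''Reload the exported SavedModel and write its graph to log2/ for TensorBoard.'''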
    with tf.Session(graph=tf.Graph()) as sess:
        tf.saved_model.loader.load(sess, export_dir="role_savedmodel", tags=["serve"])
        writer = tf.summary.FileWriter(graph=sess.graph, logdir="log2")

def relabel():
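    '''
    Dump all samples labeled 3 or 4, together with their context windows,
    to relabel_1.xls for manual re-annotation.
    '''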
    list_id = []
    list_before = []
    list_center = []
    list_after = []
    list_label = []
    files = ["id_token_text_begin_end_label.pk", "id_token_text_begin_end_label.pk1"]
    for file in files:
        data = load(file)
        _count = 0
        for row in data:
            # item_x = embedding_word(spanWindow(tokens=row[1],begin_index=row[3],end_index=row[4],size=100,center_include=True,word_flag=True), shape=(3,100,60))
            _span = spanWindow(tokens=row[1], begin_index=row[3], end_index=row[4], size=15, center_include=True, word_flag=True, text=row[2])
            _label = row[5]
            if int(_label) in [3, 4]:
                list_id.append(row[0])
                list_before.append(_span[0])
                list_center.append(_span[1])
                list_after.append(_span[2])
                list_label.append(str(_label))
    df = pd.DataFrame({"list_id": list_id,
                       "list_before": list_before,
                       "list_center": list_center,
                       "list_after": list_after,
                       "list_label": list_label})
    df.to_excel("relabel_1.xls", columns=["list_id", "list_before", "list_center", "list_after", "list_label"])

def generate_data():
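    '''
    Merge the re-annotated samples listed in relabel_1.xls back into the
    selffool dataset and save the combined data as
    id_token_text_begin_end_label-moreTrue.pk.
    '''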
    file_before = "D:\\myProject\\traindata\\"
    files = ["id_token_text_begin_end_label.pk", "id_token_text_begin_end_label.pk1"]
    data = load(file_before + "id_token_text_begin_end_label-selffool.pk1")
    df = pd.read_excel(file_before + "relabel_1.xls")
    set_id = set(df["list_id"])
    for file in files:
        temp_data = load(file_before + file)
        for row in temp_data:
            if row[0] in set_id:
                # print(row)
                data.append(row)
    save(data, file_before + "id_token_text_begin_end_label-moreTrue.pk")

if __name__ == "__main__":
    # loadTrainData()
    train()
    # relabel()
    # generate_data()
    test()
    # statis()
    # val()
    # get_savedmodel()
    # get_tensorboard()
    pass