train.py

'''
Created on 2019-03-26
@author: User
'''
import sys
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
sys.path.append(os.path.abspath("../../.."))
from BiddingKG.dl.common.models import *
from keras.callbacks import ModelCheckpoint
from keras import models
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
from BiddingKG.dl.common.Utils import *
import tensorflow as tf
from generateData import *
def train():
    # Train the TextCNN form classifier on the spreadsheet data and checkpoint
    # the best model (by validation loss) after each epoch.
    model = getTextCNNModel()
    train_x, train_y = getData("train.xls")
    test_x, test_y = getData("test.xls")
    callback = ModelCheckpoint('log/' + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1{val_f1_score:.4f}.h5',
                               monitor="val_loss", verbose=1, save_best_only=True, mode="min")
    model.fit(x=train_x, y=train_y, batch_size=96, epochs=400, callbacks=[callback], shuffle=True,
              validation_data=(test_x, test_y))
def train1():
    # Train the BiLSTM model; reuse the cached training data if the pickle exists.
    data_pk = "./traindata/all_data.pk"
    if os.path.exists(data_pk):
        train_x, train_y, test_x, test_y, test_text = load(data_pk)
    else:
        # train_x, train_y, test_x, test_y, test_text = getTrainData()
        # save((train_x, train_y, test_x, test_y, test_text), data_pk)
        train_x, train_y, test_x, test_y = getTrainData_jsonTable()
    with tf.Session(graph=tf.Graph()) as sess:
        with sess.graph.as_default():
            vocab, matrix = getVocabAndMatrix(getModel_word())
            model = getBiLSTMModel(input_shape=(1, 50, 60), vocab=vocab, embedding_weights=matrix, classes=2)
            # model = getTextCNNModel(input_shape=(1,30,60), vocab=vocab, embedding_weights=matrix, classes=2)
            callback = ModelCheckpoint('log/' + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1{val_f1_score:.4f}.h5',
                                       monitor="val_loss", verbose=1, save_best_only=True, save_weights_only=True, mode="min")
            model.fit(x=train_x, y=train_y, batch_size=128, epochs=400, callbacks=[callback], shuffle=True,
                      validation_data=(test_x, test_y))
def vali():
    # Dump the correctly and wrongly classified test samples to two Excel
    # sheets for manual review.
    data_pk = "all_data.pk"
    train_x, train_y, test_x, test_y, test_text = load(data_pk)
    model = models.load_model("../../dl_dev/form/log/loss_ep106-loss0.008-val_loss0.134-f10.9768.h5",
                              custom_objects={"precision": precision, "recall": recall, "f1_score": f1_score,
                                              "Attention": Attention})
    predict = model.predict(test_x)
    predict_y = np.argmax(predict, 1)
    for keep_correct, filename in ((True, "vali_true.xls"), (False, "vali_wrong.xls")):
        data = []
        for y, y_, text, prob in zip(np.argmax(test_y, 1), predict_y, test_text, predict):
            if (y == y_) == keep_correct:
                data.append([text[0], text[1], y, y_, prob[y_]])
        data.sort(key=lambda x: x[2])
        df = pd.DataFrame({"list_filename": [item[0] for item in data],
                           "list_text": [item[1] for item in data],
                           "list_label": [item[2] for item in data],
                           "list_predict": [item[3] for item in data],
                           "list_prob": [item[4] for item in data]})
        df.to_excel(filename, columns=["list_filename", "list_text", "list_label", "list_predict", "list_prob"])
def test(list_text):
    # Encode the given texts, load the trained TextCNN weights, and return
    # the raw class probabilities.
    x = []
    for text in list_text:
        x.append(encoding(text))
    x = np.array(x)
    # x = np.expand_dims(encoding(text), 0)
    # test_x, test_y = getData("test.xls")
    model = getTextCNNModel()
    model.load_weights("log/ep082-loss0.044-val_loss0.126-f10.9592.h5")
    a = time.time()
    predict_y = model.predict(x)
    print("cost", time.time() - a)
    # model.save("model/model_form.model.hdf5")
    return predict_y
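
# Illustrative call to test(); the second pipe-separated row is hypothetical
# sample data, modeled on the commented example in __main__ below:
# predict = test(["序号|项目名称|中选人", "1|xx工程|xx公司"])
# print(np.argmax(predict, 1))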
def getBestThreshold():
    # Sweep the confidence threshold from 0.5 to 1.0 in steps of 0.001 and
    # plot the F1 score at each step.
    def getAccuracyRecall(predict, threshold, test_y):
        nums = 0
        counts = 0
        for item, _max, y in zip(predict, np.argmax(predict, 1), np.argmax(test_y, 1)):
            if item[_max] > threshold:
                if _max == y:
                    nums += 1
                counts += 1
        if counts == 0 or nums == 0:
            # No confident (or no correct) predictions at this threshold.
            return 0
        precision = nums / counts
        recall = nums / len(test_y)
        return 2 * ((precision * recall) / (precision + recall))
        # return precision, recall

    model = getTextCNNModel()
    model.load_weights("model/model_form.model.hdf5")
    test_x, test_y = getData("test.xls")
    predict_y = model.predict(test_x)
    threshold = 0.5
    x = []
    y = []
    while threshold < 1:
        x.append(threshold)
        t0 = getAccuracyRecall(predict_y, threshold, test_y)
        y.append(t0)
        print(threshold, t0)
        threshold += 0.001
    plt.plot(x, y)
    plt.show()
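
# A vectorized alternative sketch (not part of the original file) for the
# threshold sweep above: it computes the same F1 score as getAccuracyRecall
# with numpy array operations instead of a Python loop.
def f1_at_threshold(predict, test_y, threshold):
    _max = np.argmax(predict, 1)
    # Mask of samples whose top-class probability exceeds the threshold.
    confident = predict[np.arange(len(predict)), _max] > threshold
    correct = (_max == np.argmax(test_y, 1)) & confident
    counts = confident.sum()
    if counts == 0 or correct.sum() == 0:
        return 0.0
    precision = correct.sum() / counts
    recall = correct.sum() / len(test_y)
    return 2 * precision * recall / (precision + recall)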
def save_form_model():
    # Export the trained BiLSTM form classifier as a TensorFlow SavedModel.
    with tf.Session(graph=tf.Graph()).as_default() as sess:
        with sess.graph.as_default():
            vocab, matrix = getVocabAndMatrix(getModel_word())
            model = getBiLSTMModel(input_shape=(1, 50, 60), vocab=vocab, embedding_weights=matrix, classes=2)
            model.load_weights(filepath="../../dl_dev/form/log/ep029-loss0.044-val_loss0.057-f10.9788.h5")
            tf.saved_model.simple_save(sess,
                                       "./form_savedmodel/",
                                       inputs={"inputs": model.input},
                                       outputs={"outputs": model.output})
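
# A minimal sketch (not part of the original file) of how the SavedModel
# exported above could be loaded back for inference. It assumes the default
# serving tag that tf.saved_model.simple_save attaches to the exported graph.
def load_form_savedmodel():
    sess = tf.Session(graph=tf.Graph())
    with sess.graph.as_default():
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
                                   "./form_savedmodel/")
    return sess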
from BiddingKG.dl.form.generateData import getTrainData_jsonTable
from BiddingKG.dl.form.model import get_context_form_model
def train_context():
    # Train the context-aware form model with a raw TensorFlow training loop.
    vocab, vocab_matrix = getVocabAndMatrix(getModel_word(), Embedding_size=60)
    sess = tf.Session(graph=tf.Graph())
    epochs = 100
    batch_size = 1000
    with sess.graph.as_default():
        node_dict = get_context_form_model(vocab_len=len(vocab_matrix), char_dim=60, lstm_dim=24,
                                           context_dim=36, res_dim=16, residual_stacks=5)
        sess.run(tf.global_variables_initializer())
        # Initialize the embedding layer with the pretrained word vectors.
        sess.run(tf.assign(node_dict["embedding"], np.array(vocab_matrix)))
        train_x, train_y, _ = getTrainData_jsonTable(begin=0, end=20000)
        test_x, test_y, test_text = getTrainData_jsonTable(begin=-2000, end=-1, return_text=True)
        saver = tf.train.Saver(max_to_keep=epochs)
        for _e in range(epochs):
            # Training pass.
            batch_begin = 0
            avg_loss = 0
            avg_precision = 0
            _count = 0
            while batch_begin < len(train_x):
                feed_dict = {node_dict["input"]: train_x[batch_begin:batch_begin + batch_size],
                             node_dict["label"]: train_y[batch_begin:batch_begin + batch_size]}
                _, _loss, _precision, _recall = sess.run([node_dict["train"], node_dict["loss"],
                                                          node_dict["precision"], node_dict["recall"]],
                                                         feed_dict=feed_dict)
                _count += 1
                avg_loss += _loss
                avg_precision += _precision
                if _count % 10 == 0:
                    print("training %d epoch %d/%d loss:%5f precision:%5f" % (
                        _e, batch_begin, len(train_x), avg_loss / _count, avg_precision / _count))
                batch_begin += batch_size
            train_loss = avg_loss / _count
            train_precision = avg_precision / _count
            # Validation pass (the train op is left out of the run list).
            batch_begin = 0
            avg_loss = 0
            avg_precision = 0
            _count = 0
            while batch_begin < len(test_x):
                feed_dict = {node_dict["input"]: test_x[batch_begin:batch_begin + batch_size],
                             node_dict["label"]: test_y[batch_begin:batch_begin + batch_size]}
                _loss, _precision, _recall = sess.run([node_dict["loss"], node_dict["precision"],
                                                       node_dict["recall"]], feed_dict=feed_dict)
                _count += 1
                avg_loss += _loss
                avg_precision += _precision
                if _count % 10 == 0:
                    print("testing %d epoch %d/%d loss:%5f precision:%5f" % (
                        _e, batch_begin, len(test_x), avg_loss / _count, avg_precision / _count))
                batch_begin += batch_size
            # Checkpoint every epoch with train/val metrics in the filename.
            saver.save(sess, "./traindata/log/epoch%d-train_loss%5f-train_precision%5f-val_loss%5f-val_precision%5f.ckpt" % (
                _e, train_loss, train_precision, avg_loss / _count, avg_precision / _count))
def predict_context():
    # Restore a context-model checkpoint on CPU and print the misclassified
    # test samples.
    import os
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    model_file = "../../dl_dev/form/traindata/log/epoch18-train_loss0.075330-train_precision0.947963-val_loss0.088985-val_precision0.947385.ckpt"
    sess = tf.Session(graph=tf.Graph())
    with sess.graph.as_default():
        vocab, vocab_matrix = getVocabAndMatrix(getModel_word(), Embedding_size=60)
        node_dict = get_context_form_model(vocab_len=len(vocab_matrix), char_dim=60, lstm_dim=24,
                                           context_dim=36, res_dim=16, residual_stacks=5)
        saver = tf.train.Saver()
        saver.restore(sess, model_file)
        # tf.saved_model.simple_save(session=sess,
        #                            export_dir="./form_context_savedmodel",
        #                            inputs={"inputs": node_dict["input"]},
        #                            outputs={"outputs": node_dict["output"]})
        # return
        test_x, test_y, test_text = getTrainData_jsonTable(begin=-2000, end=-1, return_text=True)
        batch_begin = 0
        batch_size = 2000
        while batch_begin < len(test_x):
            feed_dict = {node_dict["input"]: test_x[batch_begin:batch_begin + batch_size]}
            batch_out = sess.run([node_dict["output"]], feed_dict=feed_dict)
            batch_label = test_y[batch_begin:batch_begin + batch_size]
            batch_text = test_text[batch_begin:batch_begin + batch_size]
            for _out, _label, _text in zip(batch_out[0], batch_label, batch_text):
                if np.argmax(_out) != np.argmax(_label):
                    print("======")
                    print(_text)
                    print(_text[4])
                    print(_out)
                    print(_label)
                    print("============")
            batch_begin += batch_size
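
# A minimal companion sketch (not part of the original file): tally overall
# accuracy over a batch instead of printing each misclassified sample. It
# assumes outputs and labels are one-hot arrays, as predict_context uses.
def context_accuracy(outputs, labels):
    hits = sum(1 for _out, _label in zip(outputs, labels)
               if np.argmax(_out) == np.argmax(_label))
    return hits / len(labels)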
if __name__ == "__main__":
    # train()
    # print(test(["序号|项目名称|中选人"]))
    # getBestThreshold()
    # train1()
    # vali()
    # save_form_model()
    train_context()
    # predict_context()