train.py

'''
Created on 2019-03-26

@author: User
'''
import sys
import os
sys.path.append(os.path.abspath("../.."))
from BiddingKG.dl.common.models import *
from keras.callbacks import ModelCheckpoint
from keras import models
import numpy as np
import pandas as pd
import time
from BiddingKG.dl.common.Utils import *
import tensorflow as tf
import matplotlib.pyplot as plt
from generateData import *
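
# Trains the TextCNN form classifier on rows read from train.xls / test.xls and
# checkpoints the weights with the best validation loss. getTextCNNModel() and
# getData() come from the wildcard-imported project modules; getData() is assumed
# to return encoded inputs and one-hot labels.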
def train():
    model = getTextCNNModel()
    train_x, train_y = getData("train.xls")
    test_x, test_y = getData("test.xls")
    callback = ModelCheckpoint('log/' + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1{val_f1_score:.4f}.h5',
                               monitor="val_loss", verbose=1, save_best_only=True, mode="min")
    model.fit(x=train_x, y=train_y, batch_size=96, epochs=400, callbacks=[callback], shuffle=True,
              validation_data=(test_x, test_y))
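
# BiLSTM variant of the training routine: the prepared train/test split is cached
# in all_data.pk (load()/save() are assumed to be pickle helpers from the common
# Utils module) and the model is built inside its own tf.Graph/tf.Session.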
def train1():
    data_pk = "all_data.pk"
    if os.path.exists(data_pk):
        train_x, train_y, test_x, test_y, test_text = load(data_pk)
    else:
        train_x, train_y, test_x, test_y, test_text = getTrainData()
        save((train_x, train_y, test_x, test_y, test_text), data_pk)
    with tf.Session(graph=tf.Graph()) as sess:
        with sess.graph.as_default():
            print("loading vocabulary and embedding matrix")
            vocab, matrix = getVocabAndMatrix(getModel_word())
            model = getBiLSTMModel(input_shape=(1, 50, 60), vocab=vocab, embedding_weights=matrix, classes=2)
            # model = getTextCNNModel(input_shape=(1, 30, 60), vocab=vocab, embedding_weights=matrix, classes=2)
            print("model built, starting training")
            callback = ModelCheckpoint('log/' + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1{val_f1_score:.4f}.h5',
                                       monitor="val_loss", verbose=1, save_best_only=True, save_weights_only=True, mode="min")
            model.fit(x=train_x, y=train_y, batch_size=128, epochs=400, callbacks=[callback], shuffle=True,
                      validation_data=(test_x, test_y))
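
# Runs a checkpointed model over the cached test set and writes two spreadsheets:
# vali_true.xls with correctly classified rows and vali_wrong.xls with the
# misclassified ones (filename, text, gold label, prediction, predicted probability).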
def vali():
    data_pk = "all_data.pk"
    train_x, train_y, test_x, test_y, test_text = load(data_pk)
    model = models.load_model("log/loss_ep106-loss0.008-val_loss0.134-f10.9768.h5",
                              custom_objects={"precision": precision, "recall": recall, "f1_score": f1_score,
                                              "Attention": Attention})
    predict = model.predict(test_x)
    predict_y = np.argmax(predict, 1)
    # correctly classified rows -> vali_true.xls
    list_filename = []
    list_text = []
    list_label = []
    list_predict = []
    list_prob = []
    data = []
    for y, y_, text, prob in zip(np.argmax(test_y, 1), predict_y, test_text, predict):
        if y == y_:
            data.append([text[0], text[1], y, y_, prob[y_]])
    data.sort(key=lambda x: x[2])
    for item in data:
        list_filename.append(item[0])
        list_text.append(item[1])
        list_label.append(item[2])
        list_predict.append(item[3])
        list_prob.append(item[4])
    df = pd.DataFrame(
        {"list_filename": list_filename, "list_text": list_text, "list_label": list_label,
         "list_predict": list_predict, "list_prob": list_prob})
    df.to_excel("vali_true.xls", columns=["list_filename", "list_text", "list_label", "list_predict", "list_prob"])
    # misclassified rows -> vali_wrong.xls
    list_filename = []
    list_text = []
    list_label = []
    list_predict = []
    list_prob = []
    data = []
    for y, y_, text, prob in zip(np.argmax(test_y, 1), predict_y, test_text, predict):
        if y != y_:
            data.append([text[0], text[1], y, y_, prob[y_]])
    data.sort(key=lambda x: x[2])
    for item in data:
        list_filename.append(item[0])
        list_text.append(item[1])
        list_label.append(item[2])
        list_predict.append(item[3])
        list_prob.append(item[4])
    df = pd.DataFrame(
        {"list_filename": list_filename, "list_text": list_text, "list_label": list_label,
         "list_predict": list_predict, "list_prob": list_prob})
    df.to_excel("vali_wrong.xls", columns=["list_filename", "list_text", "list_label", "list_predict", "list_prob"])
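
# Ad-hoc inference helper: encodes a list of table-row strings (encoding() is
# assumed to convert a string into the model's fixed-size character-id input),
# loads checkpointed TextCNN weights and returns the raw class probabilities.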
def test(list_text):
    x = []
    for text in list_text:
        x.append(encoding(text))
    x = np.array(x)
    # x = np.expand_dims(encoding(text), 0)
    # test_x, test_y = getData("test.xls")
    model = getTextCNNModel()
    model.load_weights("log/ep082-loss0.044-val_loss0.126-f10.9592.h5")
    a = time.time()
    predict_y = model.predict(x)
    print("cost", time.time() - a)
    # model.save("model/model_form.model.hdf5")
    return predict_y
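
# Sweeps the prediction-confidence threshold from 0.5 towards 1.0 on the test set
# and plots the resulting F1 score, to help choose a cutoff below which
# predictions should be discarded.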
def getBestThreshold():
    # Returns the F1 score computed over predictions whose top-class probability exceeds the threshold.
    def getAccuracyRecall(predict, threshold, test_y):
        nums = 0
        counts = 0
        for item, _max, y in zip(predict, np.argmax(predict, 1), np.argmax(test_y, 1)):
            if item[_max] > threshold:
                if _max == y:
                    nums += 1
                counts += 1
        if counts == 0 or nums == 0:
            return 0
        precision = nums / counts
        recall = nums / len(test_y)
        return 2 * ((precision * recall) / (precision + recall))
        # return precision, recall

    model = getTextCNNModel()
    model.load_weights("model/model_form.model.hdf5")
    test_x, test_y = getData("test.xls")
    predict_y = model.predict(test_x)
    threshold = 0.5
    x = []
    y = []
    while threshold < 1:
        x.append(threshold)
        t0 = getAccuracyRecall(predict_y, threshold, test_y)
        y.append(t0)
        print(threshold, t0)
        threshold += 0.001
    plt.plot(x, y)
    plt.show()
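
# Rebuilds the BiLSTM form model, loads the selected checkpoint weights and
# exports a TensorFlow SavedModel to ./form_savedmodel/ for serving.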
def save_form_model():
    with tf.Session(graph=tf.Graph()).as_default() as sess:
        with sess.graph.as_default():
            vocab, matrix = getVocabAndMatrix(getModel_word())
            model = getBiLSTMModel(input_shape=(1, 50, 60), vocab=vocab, embedding_weights=matrix, classes=2)
            model.load_weights(filepath="log/ep029-loss0.044-val_loss0.057-f10.9788.h5")
            tf.saved_model.simple_save(sess,
                                       "./form_savedmodel/",
                                       inputs={"inputs": model.input},
                                       outputs={"outputs": model.output})

if __name__ == "__main__":
    # train()
    # print(test(["序号|项目名称|中选人"]))
    # getBestThreshold()
    # train1()
    # vali()
    save_form_model()