train.py

'''
Created on 2019-04-11
@author: User
'''
import os
import sys
import re
import h5py
sys.path.append(os.path.abspath("../../.."))
import pandas as pd
import gensim
import numpy as np
import math
from keras.callbacks import ModelCheckpoint
from BiddingKG.dl.common.Utils import *
from BiddingKG.dl.common.models import getBiLSTMModel
import tensorflow as tf
from keras.models import load_model
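
# Note (added): helpers such as getModel_word, getVocabAndMatrix, encodeInput,
# fool_char_to_id, load and save are not defined in this file; they are assumed
# to come from the wildcard import of BiddingKG.dl.common.Utils above.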

def embedding(datas, shape):
    '''
    @summary: look up the word vector for each character
    @param:
        datas: list of strings
        shape: shape of the result
    @return: array, the word embeddings with the given shape
    '''
    model_w2v = getModel_word()
    embed = np.zeros(shape)
    length = shape[1]
    out_index = 0
    #print(datas)
    for data in datas:
        index = 0
        for item in str(data)[-shape[1]:]:
            if index >= length:
                break
            if item in model_w2v.vocab:
                embed[out_index][index] = model_w2v[item]
                index += 1
            else:
                #embed[out_index][index] = model_w2v['unk']
                index += 1
        out_index += 1
    return embed
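
# --- Hedged usage sketch (added) ------------------------------------------------
# Assumes the character vectors returned by getModel_word() are 60-dimensional,
# matching the input_shape=(3, 50, 60) used for the BiLSTM model below.
def _demo_embedding():
    spans = ["预算金额1000000元", "中标金额"]
    embed = embedding(spans, shape=(len(spans), 50, 60))
    print(embed.shape)  # expected (2, 50, 60); characters missing from the vocabulary stay zero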

def labeling(label, out_len=3):
    out = np.zeros((out_len))
    out[label] = 1
    return out
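
# Hedged example (added): labeling() one-hot encodes a class index, e.g.
#   labeling(1)            -> array([0., 1., 0.])
#   labeling(0, out_len=3) -> array([1., 0., 0.])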

def getTrainData(percent=0.9):
    df = pd.read_excel("批量.xls")
    train_x = []
    train_y = []
    test_x = []
    test_y = []
    for before, text, after, label, turn in zip(df["list_before"], df["list_text"], df["list_after"], df["list_label"], df["turn"]):
        before = str(before) if str(before) != "nan" else ""
        text = str(text)
        after = str(after) if str(after) != "nan" else ""
        the_label = None
        if math.isnan(turn):
            the_label = int(label)
        else:
            the_label = int(turn)
        if the_label not in [0, 1, 2]:
            print(after, text)
            continue
        x = encodeInput([before, text, after], word_len=50, word_flag=True, userFool=False)
        y = labeling(the_label)
        if np.random.random() < percent:
            train_x.append(x)
            train_y.append(y)
        else:
            test_x.append(x)
            test_y.append(y)
    return np.transpose(np.array(train_x), (1, 0, 2)), np.array(train_y), np.transpose(np.array(test_x), (1, 0, 2)), np.array(test_y)
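
# --- Hedged note (added) --------------------------------------------------------
# encodeInput([before, text, after], word_len=50, ...) is assumed to return one
# (3, 50) array per sample, so the transpose above reorders (n_samples, 3, 50)
# into (3, n_samples, 50): train_x[0]/[1]/[2] line up with the three model inputs
# used in train() below, and train_y has shape (n_samples, 3).
def _demo_train_data_shapes():
    # sketch, assumes "批量.xls" is present next to this script
    train_x, train_y, test_x, test_y = getTrainData(percent=0.9)
    print(train_x.shape, train_y.shape)  # expected: (3, n_train, 50) (n_train, 3)
    print(test_x.shape, test_y.shape)    # expected: (3, n_test, 50)  (n_test, 3)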

def train():
    pk_file = "traindata/all_data.pk"
    if os.path.exists(pk_file):
        train_x, train_y, test_x, test_y = load(pk_file)
    else:
        train_x, train_y, test_x, test_y = getTrainData()
        save([train_x, train_y, test_x, test_y], pk_file)
    #model_w = getModel_word()
    #vocab,weights = getVocabAndMatrix(model_w, Embedding_size=60)
    #model = getTextCNNModel(input_shape=(3,50), vocab=vocab, embedding_weights=weights, classes=6)
    '''
    with tf.Session(graph=tf.get_default_graph()) as sess:
        with sess.graph.as_default():
    '''
    with tf.Session() as sess:
        dict_key_value = load("dict_key_value.pk")
        vocab, matrix = getVocabAndMatrix(getModel_word())
        model = getBiLSTMModel(input_shape=(3,50,60), vocab=vocab, embedding_weights=matrix, classes=3)
        # model = getBiLSTMModel(input_shape=(3,50,60), vocab=fool_char_to_id.keys(), embedding_weights=dict_key_value["bert/embeddings/word_embeddings:0"], classes=3)
        # for k,v in dict_key_value.items():
        #     if re.search("encoder",k) is not None:
        #         sess.run(tf.assign(sess.graph.get_tensor_by_name(k[13:]),v))
        #         print(k)
        # sess.run(tf.assign(sess.graph.get_tensor_by_name("char_embeding/embeddings:0"),dict_key_value["bert/embeddings/word_embeddings:0"]))
        #model = models.getTextCNNModel()
        print("loading weights")
        # model.load_weights("log/ep378-loss0.178-val_loss0.117-f1_score0.965.h5", by_name=True, skip_mismatch=True)
        callback = ModelCheckpoint(filepath="log/" + "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}-f1_score{val_f1_score:.3f}.h5",
                                   monitor="val_loss", save_best_only=True, save_weights_only=True, mode="min")
        model.fit(x=[train_x[0], train_x[1], train_x[2]], y=train_y,
                  batch_size=128, epochs=600, callbacks=[callback],
                  validation_data=([test_x[0], test_x[1], test_x[2]], test_y))
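
# Note (added): the checkpoint filename above interpolates {val_f1_score}, so the
# model returned by getBiLSTMModel is assumed to be compiled with an f1_score
# metric; without it, ModelCheckpoint raises a KeyError when formatting the path.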

def test():
    _span = [':预算金额1000000元,中标金额', '1df元', ';']
    _input = encodeInput(_span, word_len=50, word_flag=True, userFool=True)
    print(_input)
    graph = tf.get_default_graph()
    with graph.as_default():
        sess = tf.Session(graph=graph)
        with sess.as_default():
            dict_key_value = load("dict_key_value.pk")
            model = getBiLSTMModel(input_shape=(3,50,60), vocab=fool_char_to_id.keys(), embedding_weights=None, classes=3)
            for k, v in dict_key_value.items():
                if re.search("encoder", k) is not None:
                    sess.run(tf.assign(sess.graph.get_tensor_by_name(k[13:]), v))
                    print(k)
            model.load_weights("log/ep019-loss0.177-val_loss0.115-f1_score0.968.h5", by_name=True, skip_mismatch=True)
            print(model.predict([np.array([_input[0]]), np.array([_input[1]]), np.array([_input[2]])]))

def get_savedModel():
    sess = tf.Session(graph=tf.Graph())
    with sess.as_default():
        with sess.graph.as_default():
            vocab, matrix = getVocabAndMatrix(getModel_word())
            model = getBiLSTMModel(input_shape=(3,50,60), vocab=vocab, embedding_weights=matrix, classes=3)
            model.load_weights(filepath="log/ep009-loss0.057-val_loss0.076-f1_score0.978.h5")
            tf.saved_model.simple_save(session=sess,
                                       export_dir="../../dl/money/money_savedmodel",
                                       inputs={"input0": model.input[0],
                                               "input1": model.input[1],
                                               "input2": model.input[2]},
                                       outputs={"outputs": model.output})

def tensorboard_model():
    with tf.Session(graph=tf.Graph()).as_default() as sess:
        with sess.graph.as_default():
            tf.saved_model.loader.load(sess, tags=["serve"], export_dir="../../dl/money/money_savedmodel1")
            tf.summary.FileWriter(graph=sess.graph, logdir="log2")
  138. if __name__=="__main__":
  139. train()
  140. # test()
  141. # get_savedModel()
  142. # tensorboard_model()