main.py

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/1/13 0013 14:03
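"""
Entry points for the product NER model in this file:
  - train():         train Product_Model on the pickled labelled data and keep the best-F1 checkpoint
  - evaluate_line(): restore the latest checkpoint and tag sentences typed on the command line
  - predict():       load the frozen graph model/product.pb and decode tags with the CRF transition matrix
"""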
from BiddingKG.dl.product.product_model import Product_Model
from BiddingKG.dl.product.data_util import BatchManager, get_label_data, id_to_tag, input_from_line, decode, result_to_json
import numpy as np
import tensorflow as tf
import random
import pickle
import os
def train():
    # Optionally rebuild the train/dev split from the labelled data:
    # all_data = get_label_data()
    # random.shuffle(all_data)
    # train_data = all_data[:int(len(all_data)*0.85)]
    # dev_data = all_data[int(len(all_data)*0.85):]
    # with open('data/train_data2.pkl', 'wb') as f:
    #     pickle.dump(train_data, f)
    # with open('data/dev_data2.pkl', 'wb') as f:
    #     pickle.dump(dev_data, f)
    with open('data/train_data2.pkl', 'rb') as f:
        train_data = pickle.load(f)
    with open('data/dev_data2.pkl', 'rb') as f:
        dev_data = pickle.load(f)
    train_manager = BatchManager(train_data, batch_size=128)
    dev_manager = BatchManager(dev_data, batch_size=64)
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    steps_per_epoch = train_manager.len_data
    ckpt_path = "model"
    with tf.Session(config=tf_config) as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        # Optionally resume from an existing checkpoint:
        # ckpt = tf.train.get_checkpoint_state(ckpt_path)
        # if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        #     model.saver.restore(sess, ckpt.model_checkpoint_path)
        #     print("Restored existing model weights from", ckpt.model_checkpoint_path)
        print('Preparing training data')
        loss = []
        min_loss = 1000
        max_f1 = 0
        for i in range(100):
            print('epochs:', i)
            # model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
            # break
            for batch in train_manager.iter_batch(shuffle=True):
                # print('batch:', len(batch))
                # step, batch_loss = model.run_step(sess, True, batch)
                step, batch_loss = model.run_step(sess, 'train', batch)
                loss.append(batch_loss)
                if step % 10 == 0:
                    iteration = step // steps_per_epoch + 1
                    print('iter:{} step:{} loss:{}'.format(iteration, step, np.mean(loss)))
            # Evaluate every 5 epochs, and every epoch from epoch 50 on; keep the best-F1 checkpoint.
            if i >= 50 or i % 5 == 0:
                f1, precision, recall, evl_loss = model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
                print('f1:%.4f, precision:%.4f, recall:%.4f, evl_loss:%.4f' % (f1, precision, recall, evl_loss))
                if max_f1 < f1:
                    model.saver.save(sess, os.path.join(ckpt_path, "ner2.ckpt"))
                    print("model saved, best f1 is %.4f" % f1)
                    max_f1 = f1
            # if np.mean(loss) < min_loss:
            #     min_loss = np.mean(loss)
            #     model.saver.save(sess, os.path.join(ckpt_path, "ner.ckpt"))
            #     print("model saved, loss is:", min_loss)
            loss = []
def evaluate_line():
    ckpt_path = "model"
    with tf.Session() as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print('Checkpoint file:', ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
            print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)
        while True:
            line = input("Enter a test sentence: ")
            result = model.evaluate_line(sess, line)
            print(result)
def predict():
    pb_path = "model/product.pb"
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, 'rb') as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name='')  # note: do not add a name prefix here
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for node in output_graph_def.node:
                print(node.name)
            # Look up the input/output tensors of the frozen graph by name.
            char_input = sess.graph.get_tensor_by_name("CharInputs:0")
            length = sess.graph.get_tensor_by_name("Sum:0")
            dropout = sess.graph.get_tensor_by_name("Dropout:0")
            logit = sess.graph.get_tensor_by_name("logits/Reshape:0")
            tran = sess.graph.get_tensor_by_name("crf_loss/transitions:0")
            while True:
                line = input("Enter a test sentence: ")
                _, chars, tags = input_from_line(line)
                print(chars)
                lengths, scores, tran_ = sess.run([length, logit, tran],
                                                  feed_dict={char_input: np.asarray(chars),
                                                             dropout: 1.0})
                # Decode the best tag path with the CRF transition matrix.
                batch_paths = decode(scores, lengths, tran_)
                tags = batch_paths[0]  # batch_paths[0][:lengths] would be wrong
                result = result_to_json(line, tags)
                print(result)
if __name__ == "__main__":
    # train()
    # evaluate_line()
    predict()