train.py

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/12/14 11:08
import os
import re
print('about to import tf')
import tensorflow as tf
print('about to import np')
import numpy as np
print('about to import Product_model')
from BiddingKG.dl.product.product_model import Product_Model
print('about to import max_len')
from BiddingKG.dl.product.data_tfrecord import max_len
# from BiddingKG.dl.common.Utils import viterbi_decode
print('about to set CUDA environment')
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# max_len = 500
def read_tfRecord(file_tfRecord):
    '''
    Parse one serialized example from a TFRecord file.
    :param file_tfRecord: path to the tfrecord file
    :return: (text_len, word_ids, tag_ids) tensors for a single example
    '''
    queue = tf.train.string_input_producer([file_tfRecord])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            "word_ids": tf.FixedLenFeature([max_len], tf.int64),
            "text_len": tf.FixedLenFeature([1], tf.int64),
            "tag_ids": tf.FixedLenFeature([max_len], tf.int64)
        }
    )
    word_ids = tf.cast(features['word_ids'], tf.int64)
    tag_ids = tf.cast(features['tag_ids'], tf.int64)
    text_len = tf.cast(features['text_len'], tf.int64)
    return text_len, word_ids, tag_ids
def get_batch_record(filename, batch_size):
    '''Build shuffled batches from the parsed tfrecord tensors.'''
    text_len, word_ids, tag_ids = read_tfRecord(filename)
    text_len, word_ids, tag_ids = tf.train.shuffle_batch(
        [text_len, word_ids, tag_ids],
        batch_size=batch_size,
        capacity=200 + batch_size * 3,
        min_after_dequeue=1,
        num_threads=5)
    text_len = tf.squeeze(text_len, axis=1)  # [batch, 1] -> [batch]
    return text_len, word_ids, tag_ids
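
# Hedged alternative sketch: the queue-based input pipeline above is the old
# TF1 API; an equivalent tf.data pipeline (available since TF 1.4) could look
# like the function below. It is kept here only as a migration reference and
# is not used by the training loop.
def make_dataset_sketch(file_tfRecord, batch_size):
    def _parse(serialized_example):
        features = tf.parse_single_example(
            serialized_example,
            features={
                "word_ids": tf.FixedLenFeature([max_len], tf.int64),
                "text_len": tf.FixedLenFeature([1], tf.int64),
                "tag_ids": tf.FixedLenFeature([max_len], tf.int64)
            })
        return features['text_len'], features['word_ids'], features['tag_ids']
    dataset = tf.data.TFRecordDataset([file_tfRecord])
    return dataset.map(_parse).shuffle(200 + batch_size * 3).batch(batch_size)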
def total_sample(file_name):
    '''
    Return the total number of samples in a tfrecord file.
    :param file_name: tfrecord file name
    :return: sample count
    '''
    sample_num = 0
    for record in tf.python_io.tf_record_iterator(file_name):
        sample_num += 1
    return sample_num
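
# Hedged reference sketch: model.decode(scores, lengths, trans) below presumably
# performs Viterbi decoding over the CRF emission scores with the learned
# transition matrix (cf. the commented-out viterbi_decode import above). The
# real implementation lives in Product_Model / BiddingKG.dl.common.Utils; this
# minimal NumPy version is only an illustration and is not called anywhere.
def viterbi_decode_sketch(score, trans):
    '''
    :param score: [seq_len, num_tags] emission scores for one sequence
    :param trans: [num_tags, num_tags] transition scores
    :return: highest-scoring tag id path as a list
    '''
    dp = np.zeros_like(score)                        # best score ending at each tag
    backpointers = np.zeros(score.shape, dtype=np.int32)
    dp[0] = score[0]
    for t in range(1, score.shape[0]):
        # candidate[i, j]: best path ending at tag i, then transitioning to tag j
        candidate = dp[t - 1][:, None] + trans + score[t][None, :]
        backpointers[t] = np.argmax(candidate, axis=0)
        dp[t] = np.max(candidate, axis=0)
    best_path = [int(np.argmax(dp[-1]))]
    for t in range(score.shape[0] - 1, 0, -1):
        best_path.append(int(backpointers[t][best_path[-1]]))
    return best_path[::-1]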

if __name__ == "__main__":
    print('entering main')
    filename = os.path.dirname(__file__)+'/data/train_data/maxlen_500_addunk_product_reason.tfrecode'
    filename_dev = os.path.dirname(__file__)+'/data/test_data/maxlen_500_addunk_product_reason.tfrecode'
    assert os.path.exists(filename)
    assert os.path.exists(filename_dev)
    print('confirmed both data files exist')
    batch_size = 100
    # id_to_tag = {0: 'O', 1: 'B', 2: 'I', 3: 'E'}
    tag2index = {'S': 0, 'B-pro': 1, 'I-pro': 2, 'E-pro': 3, 'B-rea': 4, 'I-rea': 5, 'E-rea': 6}
    id_to_tag = {v: k for k, v in tag2index.items()}
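    # Note: id 0 is labelled 'S' but appears to play the "outside"/non-entity
    # role here (cf. the commented-out mapping above, where 0 was 'O').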
    def get_ner(id_tag):
        '''Extract (start, end, type) entity spans from a string of tag ids.'''
        ner = set()
        for it in re.finditer('12*3', id_tag):
            ner.add((it.start(), it.end(), 'pro'))
        for it in re.finditer('45*6', id_tag):
            ner.add((it.start(), it.end(), 'rea'))
        return ner
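    # Example: tag ids [0, 1, 2, 2, 3, 0] join to the string "012230";
    # '12*3' matches the span [1, 5), so get_ner returns {(1, 5, 'pro')}.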
    print('about to build the model')
    model = Product_Model()
    # saver = tf.train.Saver()
    print('model built')
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        print('variables initialized')
        text_len, word_ids, tag_ids = get_batch_record(filename, batch_size=batch_size)
        print('get_batch_record')
        text_len_dev, word_ids_dev, tag_ids_dev = get_batch_record(filename_dev, batch_size=batch_size)
        print('get_batch_record_dev')
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        print('total_sample(filename)', total_sample(filename))
        total_num = total_sample(filename)
        batch_num = total_num // batch_size
        batch_num_dev = total_sample(filename_dev) // batch_size
        num = 0
        max_f1 = 0
        # model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model', 'ner_epoch10_f10.6875_loss1.5230.ckpt'))
        # print('model restored successfully')
        for epoch in range(50):
            for batch in range(batch_num):
                text_len_, word_ids_, tag_ids_ = sess.run([text_len, word_ids, tag_ids])
                # print(text_len_.shape, word_ids_.shape, tag_ids_.shape)
                loss_, train_op_, globel_step_ = sess.run(
                    [model.loss, model.train_op, model.globel_step],
                    feed_dict={model.char_inputs: word_ids_,
                               # model.lengths: text_len_,
                               model.dropout: 0.5,
                               model.targets: tag_ids_})
                num += text_len_.shape[0]
                if batch % 100 == 0:
                    print('loss_:', loss_, '\tglobel_step_:', globel_step_)
                    print('samples trained: %d, total samples: %d' % (num, total_num))
            # evaluate on the dev set after each epoch
            trans = model.trans.eval()
            loss = []
            pred_num = 0
            gold_num = 0
            equal_num = 0
            for batch in range(batch_num_dev):
                text_len_dev_, word_ids_dev_, tag_ids_dev_ = sess.run([text_len_dev, word_ids_dev, tag_ids_dev])
                lengths, scores, batch_loss = sess.run(
                    [model.lengths, model.logits, model.loss],
                    feed_dict={model.char_inputs: word_ids_dev_,
                               # model.lengths: text_len_dev_,
                               # assumption: model.dropout is a keep-prob, so 1.0 disables
                               # dropout for evaluation (the original fed the training value 0.5)
                               model.dropout: 1.0,
                               model.targets: tag_ids_dev_})
                loss.append(batch_loss)
                batch_paths = model.decode(scores, lengths, trans)
                for i in range(len(tag_ids_dev_)):
                    gold = [str(x) for x in tag_ids_dev_[i][:lengths[i]]]
                    pred = [str(x) for x in batch_paths[i][:lengths[i]]]
                    gold_ner = get_ner("".join(gold))
                    pred_ner = get_ner("".join(pred))
                    # print('gold entities:', gold_ner)
                    # print('predicted entities:', pred_ner)
                    pred_num += len(pred_ner)
                    gold_num += len(gold_ner)
                    equal_num += len(gold_ner & pred_ner)
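            # Entity-level micro-averaged metrics over the whole dev set, e.g.
            # pred_num=8, gold_num=10, equal_num=7 gives precision 0.875,
            # recall 0.700 and f1 ~ 0.778 (illustrative numbers only).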
            precision = equal_num / (pred_num + 1e-10)
            recall = equal_num / (gold_num + 1e-10)
            f1 = 2 * precision * recall / (precision + recall + 1e-10)
            val_loss = np.mean(loss)
            print('epoch: %d, f1:%.4f, precision:%.4f, recall:%.4f, val_loss:%.4f' % (epoch, f1, precision, recall, val_loss))
            if f1 > max_f1:
                max_f1 = f1
                model.saver.save(sess, os.path.join(os.path.dirname(__file__)+'/model', "ner_epoch%d_f1%.4f_loss%.4f.ckpt" % (epoch, f1, val_loss)))
                print('saved model, max_f1:%.4f' % f1)
        coord.request_stop()
        coord.join(threads)