train.py

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/12/14 11:08
import os
import re
import time
import logging

logging.basicConfig(level=logging.DEBUG)
print('about to import tensorflow')
import tensorflow as tf
print('about to import numpy')
import numpy as np
print('about to import Product_Model')
from BiddingKG.dl.product.product_model import Product_Model
print('about to import max_len')
from BiddingKG.dl.product.data_tfrecord import max_len
# from BiddingKG.dl.common.Utils import viterbi_decode
print('setting CUDA environment')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# max_len = 500
batch_size = 256
# Minimum number of elements kept in the shuffle queue after a dequeue;
# larger values give better shuffling at the cost of memory.
MIN_AFTER_DEQUEUE = batch_size * 500

def read_tfRecord(sess, file_list):
    # Build a queue-based reader over the given TFRecord files and parse one example.
    # queue = tf.train.string_input_producer([file_tfRecord])
    queue = tf.train.string_input_producer(file_list)
    reader = tf.TFRecordReader()
    filename_, serialized_example = reader.read(queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            "word_ids": tf.FixedLenFeature([max_len], tf.int64),
            "text_len": tf.FixedLenFeature([1], tf.int64),
            "tag_ids": tf.FixedLenFeature([max_len], tf.int64)
        }
    )
    word_ids = tf.cast(features['word_ids'], tf.int64)
    tag_ids = tf.cast(features['tag_ids'], tf.int64)
    text_len = tf.cast(features['text_len'], tf.int64)
    return text_len, word_ids, tag_ids

def get_batch_record(sess, file_list, batch_size):
    # Read single examples and group them into shuffled batches.
    text_len, word_ids, tag_ids = read_tfRecord(sess, file_list)
    text_len, word_ids, tag_ids = tf.train.shuffle_batch(
        [text_len, word_ids, tag_ids],
        batch_size=batch_size,
        capacity=MIN_AFTER_DEQUEUE + batch_size * 3,
        min_after_dequeue=MIN_AFTER_DEQUEUE,
        num_threads=8)
    # Drop the trailing [1] dimension so text_len has shape [batch_size].
    text_len = tf.squeeze(text_len, axis=1)
    return text_len, word_ids, tag_ids
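
# Note: the queue-based pipeline above (string_input_producer / TFRecordReader /
# shuffle_batch) is TF 1.x-style and only starts producing data once
# tf.train.start_queue_runners() has been called inside a session, as train() does below.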

def total_sample(file_name):
    '''
    Return the total number of samples in a TFRecord file.
    :param file_name: tfrecord file name
    :return: number of records
    '''
    sample_num = 0
    for record in tf.python_io.tf_record_iterator(file_name):
        sample_num += 1
    return sample_num

def train():
    logging.info('entering main')
    # filename = os.path.dirname(__file__)+'/data/train_data/maxlen_500_addunk_product_reason.tfrecode'
    # filename_dev = os.path.dirname(__file__)+'/data/test_data/maxlen_500_addunk_product_reason.tfrecode'
    # print('os.path.dirname(__file__): ', os.path.dirname(__file__))
    # print('filename path :', filename)
    # assert os.path.exists(filename)
    # assert os.path.exists(filename_dev)
    file_list = []
    file_list_dev = []
    train1 = os.path.dirname(__file__) + '/data/train_data/ProductAndReason_2023-02-24_train1.tfrecode'
    train2 = os.path.dirname(__file__) + '/data/train_data/ProductAndReason_2023-02-24_train2.tfrecode'
    dev1 = os.path.dirname(__file__) + '/data/test_data/ProductAndReason_2023-02-24_dev.tfrecode'
    # train1 = os.path.dirname(__file__)+'/data/train_data/ProductAndReason_2023-03-30_remove_punctuation_train1.tfrecode'
    # train2 = os.path.dirname(__file__)+'/data/train_data/ProductAndReason_2023-03-30_remove_punctuation_train2.tfrecode'
    # dev1 = os.path.dirname(__file__)+'/data/test_data/ProductAndReason_2023-03-30_remove_punctuation_dev.tfrecode'
    # print('filename path :', train1, os.path.exists(train1))
    file_list.append(train1)
    file_list.append(train2)
    file_list_dev.append(dev1)
    print('checking that the files exist')
    print('filename path :', train1, os.path.exists(train1))
    # batch_size = 512
    # id_to_tag = {0: 'O', 1: 'B', 2: 'I', 3: 'E'}
    tag2index = {'S': 0, 'B-pro': 1, 'I-pro': 2, 'E-pro': 3, 'B-rea': 4, 'I-rea': 5, 'E-rea': 6}
    id_to_tag = {v: k for k, v in tag2index.items()}

    def get_ner(id_tag):
        # Extract entity spans from a string of tag ids:
        # '1' followed by '2's then '3' marks a product ('pro'),
        # '4' followed by '5's then '6' marks a reason ('rea').
        ner = set()
        for it in re.finditer('12*3', id_tag):
            ner.add((it.start(), it.end(), 'pro'))
        for it in re.finditer('45*6', id_tag):
            ner.add((it.start(), it.end(), 'rea'))
        return ner
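    # Worked example (not from the original code): with this tag scheme,
    # get_ner('0122300456') returns {(1, 5, 'pro'), (7, 10, 'rea')}:
    # the 'pro' span covers the ids '1223' and the 'rea' span covers '456'.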
    print('about to build the model')
    model = Product_Model()
    # saver = tf.train.Saver()
    print('model built')

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        print('variables initialized')
        text_len, word_ids, tag_ids = get_batch_record(sess, file_list, batch_size=batch_size)
        print('get_batch_record')
        text_len_dev, word_ids_dev, tag_ids_dev = get_batch_record(sess, file_list_dev, batch_size=batch_size)
        print('get_batch_record_dev')
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        total_num = sum([total_sample(filename) for filename in file_list])
        logging.info('total_train_num: %d' % total_num)
        batch_num = total_num // batch_size
        batch_num_dev = sum([total_sample(filename_dev) for filename_dev in file_list_dev]) // batch_size
        num = 0
        l = []
        max_f1 = 0.79
        # model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model', 'ner_epoch10_f10.6875_loss1.5230.ckpt'))
        # model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model', 'ner_epoch0_f10.7740_loss1.2526.ckpt'))
        # Resume from the best checkpoint so far (epoch 16, f1 0.8000).
        model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model', 'ner_epoch16_f10.8000_loss1.0775.ckpt'))
        print('model restored')
        for epoch in range(20, 50):
            t1 = time.time()
            for batch in range(batch_num):
                text_len_, word_ids_, tag_ids_ = sess.run([text_len, word_ids, tag_ids])
                # print(text_len_.shape, word_ids_.shape, tag_ids_.shape)
                loss_, train_op_, globel_step_ = sess.run(
                    [model.loss, model.train_op, model.globel_step],
                    feed_dict={model.char_inputs: word_ids_,
                               # model.lengths: text_len_,
                               model.dropout: 0.5,
                               model.targets: tag_ids_})
                if batch % 100 == 0:
                    logging.info('loss_:%.4f,\tglobel_step_: %d' % (loss_, globel_step_))
                    print('elapsed time:', time.time() - t1)
                num += text_len_.shape[0]
                # print('samples trained: %d, total samples: %d' % (num, total_num))

            # Evaluate on the dev set at the end of each epoch.
            results = []
            trans = model.trans.eval()
            Precision = []
            Recall = []
            F1 = []
            loss = []
            pred_num = 0
            gold_num = 0
            equal_num = 0
            for batch in range(batch_num_dev):
                text_len_dev_, word_ids_dev_, tag_ids_dev_ = sess.run([text_len_dev, word_ids_dev, tag_ids_dev])
                # NOTE: dropout is also fed as 0.5 at evaluation time here.
                lengths, scores, batch_loss = sess.run(
                    [model.lengths, model.logits, model.loss],
                    feed_dict={model.char_inputs: word_ids_dev_,
                               # model.lengths: text_len_dev_,
                               model.dropout: 0.5,
                               model.targets: tag_ids_dev_})
                loss.append(batch_loss)
                batch_paths = model.decode(scores, lengths, trans)
                for i in range(len(tag_ids_dev_)):
                    result = []
                    gold = [str(x) for x in tag_ids_dev_[i][:lengths[i]]]
                    pred = [str(x) for x in batch_paths[i][:lengths[i]]]
                    gold_ner = get_ner("".join(gold))
                    pred_ner = get_ner("".join(pred))
                    # print('gold entities:', gold_ner)
                    # print('predicted entities:', pred_ner)
                    pred_num += len(pred_ner)
                    gold_num += len(gold_ner)
                    equal_num += len(gold_ner & pred_ner)

            # Micro-averaged precision/recall/F1 over all dev entities;
            # the small epsilon guards against division by zero.
            precision = equal_num / (pred_num + 1e-10)
            recall = equal_num / (gold_num + 1e-10)
            f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
            val_loss = np.mean(loss)
            logging.info('epoch: %d, f1:%.4f, acc:%.4f, recall:%.4f, val_loss:%.4f' % (epoch, f1, precision, recall, val_loss))
            if f1 > max_f1:
                max_f1 = f1
                model.saver.save(sess, os.path.join(os.path.dirname(__file__)+'/model', "ner_epoch%d_f1%.4f_loss%.4f.ckpt" % (epoch, f1, val_loss)))
                logging.info('save model, max_f1:%.4f' % f1)

        coord.request_stop()
        coord.join(threads)

if __name__ == "__main__":
    train()