- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- # @Author : bidikeji
- # @Time : 2021/12/14 11:08
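- """
- Training script for the BiddingKG product/reason sequence-labelling model (Product_Model).
- It reads pre-serialized TFRecord examples (word_ids, text_len, tag_ids) through the
- TensorFlow 1.x queue-based input pipeline, trains in epochs, evaluates span-level
- precision/recall/F1 on a dev set after each epoch, and saves a checkpoint whenever
- the dev F1 improves.
- """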
- import os
- import re
- import time
- import logging
- logging.basicConfig(level=logging.DEBUG)
- print('about to import tf')
- import tensorflow as tf
- print('about to import np')
- import numpy as np
- print('about to import Product_Model')
- from BiddingKG.dl.product.product_model import Product_Model
- print('about to import max_len')
- from BiddingKG.dl.product.data_tfrecord import max_len
- # from BiddingKG.dl.common.Utils import viterbi_decode
- print('about to set up the CUDA environment')
- os.environ['CUDA_VISIBLE_DEVICES'] = '0'
- # max_len = 500
- batch_size = 256
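- # keep at least 500 batches buffered in the shuffle queue so dequeued examples are well mixed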
- MIN_AFTER_DEQUEUE = batch_size*500
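- # TF 1.x queue-based reader: string_input_producer feeds file names to a TFRecordReader,
- # and each serialized example is parsed into fixed-length tensors (word_ids, text_len, tag_ids).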
- def read_tfRecord(sess, file_list):
- # queue = tf.train.string_input_producer([file_tfRecord])
- queue = tf.train.string_input_producer(file_list)
- reader = tf.TFRecordReader()
- filename_, serialized_example = reader.read(queue)
- features = tf.parse_single_example(
- serialized_example,
- features = {
- "word_ids": tf.FixedLenFeature([max_len], tf.int64),
- "text_len":tf.FixedLenFeature([1], tf.int64),
- "tag_ids": tf.FixedLenFeature([max_len], tf.int64)
- }
- )
- word_ids = tf.cast(features['word_ids'], tf.int64)
- tag_ids = tf.cast(features['tag_ids'], tf.int64)
- text_len = tf.cast(features['text_len'], tf.int64)
- return text_len, word_ids , tag_ids
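- # Batch the single-example tensors with tf.train.shuffle_batch; text_len comes back with
- # shape [batch_size, 1], so it is squeezed to [batch_size] before being returned.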
- def get_batch_record(sess,file_list, batch_size):
- text_len, word_ids, tag_ids = read_tfRecord(sess, file_list)
- text_len, word_ids, tag_ids = tf.train.shuffle_batch([text_len, word_ids , tag_ids],
- batch_size=batch_size,
- capacity=MIN_AFTER_DEQUEUE+batch_size*3,
- min_after_dequeue=MIN_AFTER_DEQUEUE,
- num_threads=8)
- text_len = tf.squeeze(text_len, axis=1)
- return text_len, word_ids, tag_ids
- def total_sample(file_name):
- '''
- Return the total number of samples in a TFRecord file.
- :param file_name: path to the tfrecord file
- :return: sample count
- '''
- sample_num = 0
- for record in tf.python_io.tf_record_iterator(file_name):
- sample_num += 1
- return sample_num
- def train():
- logging.info('entering train')
- # filename = os.path.dirname(__file__)+'/data/train_data/maxlen_500_addunk_product_reason.tfrecode'
- # filename_dev = os.path.dirname(__file__)+'/data/test_data/maxlen_500_addunk_product_reason.tfrecode'
- # print('os.path.dirname(__file__): ', os.path.dirname(__file__))
- # print('filename path :', filename)
- # assert os.path.exists(filename)
- # assert os.path.exists(filename_dev)
- file_list = []
- file_list_dev = []
- train1 = os.path.dirname(__file__)+'/data/train_data/ProductAndReason_2023-02-24_train1.tfrecode'
- train2 = os.path.dirname(__file__)+'/data/train_data/ProductAndReason_2023-02-24_train2.tfrecode'
- dev1 = os.path.dirname(__file__)+'/data/test_data/ProductAndReason_2023-02-24_dev.tfrecode'
- # train1 = os.path.dirname(__file__)+'/data/train_data/ProductAndReason_2023-03-30_remove_punctuation_train1.tfrecode'
- # train2 = os.path.dirname(__file__)+'/data/train_data/ProductAndReason_2023-03-30_remove_punctuation_train2.tfrecode'
- # dev1 = os.path.dirname(__file__)+'/data/test_data/ProductAndReason_2023-03-30_remove_punctuation_dev.tfrecode'
- # print('filename path :', train1, os.path.exists(train1))
- file_list.append(train1)
- file_list.append(train2)
- file_list_dev.append(dev1)
- print('checking that the input files exist')
- print('filename path :', train1, os.path.exists(train1))
- # batch_size = 512
- # id_to_tag = {0: 'O', 1: 'B', 2: 'I', 3: 'E'}
- tag2index = {'S': 0, 'B-pro': 1, 'I-pro': 2, 'E-pro': 3, 'B-rea': 4, 'I-rea': 5, 'E-rea': 6}
- id_to_tag = {v:k for k,v in tag2index.items()}
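- # get_ner recovers entity spans from a tag-id sequence rendered as a string of digits:
- # '12*3' matches B-pro I-pro* E-pro and '45*6' matches B-rea I-rea* E-rea.
- # This only works because every tag id is a single digit (0-6).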
- def get_ner(id_tag):
- ner = set()
- for it in re.finditer('12*3', id_tag):
- ner.add((it.start(), it.end(), 'pro'))
- for it in re.finditer('45*6', id_tag):
- ner.add((it.start(), it.end(), 'rea'))
- return ner
- print('about to build the model')
- model = Product_Model()
- # saver = tf.train.Saver()
- print('model built')
- with tf.Session() as sess:
- init_op = tf.global_variables_initializer()
- sess.run(init_op)
- print('variables initialized')
- text_len, word_ids, tag_ids = get_batch_record(sess, file_list, batch_size=batch_size)
- print('get_batch_record')
- text_len_dev, word_ids_dev, tag_ids_dev = get_batch_record(sess, file_list_dev, batch_size=batch_size)
- print('get_batch_record_dev')
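- # the queue-based input pipeline needs a Coordinator and running queue-runner threads
- # before sess.run() can dequeue any batches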
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(coord=coord)
- total_num = sum([total_sample(filename) for filename in file_list])
- logging.info('total_train_num: %d'%total_num)
- batch_num = total_num//batch_size
- batch_num_dev = sum([total_sample(filename_dev) for filename_dev in file_list_dev])//batch_size
- num = 0
- l = []
- max_f1 = 0.79
- # model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model','ner_epoch10_f10.6875_loss1.5230.ckpt'))
- # model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model','ner_epoch0_f10.7740_loss1.2526.ckpt'))
- model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model','ner_epoch16_f10.8000_loss1.0775.ckpt'))
- print('model restored successfully')
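- # training resumes from the restored checkpoint above; the epoch counter therefore
- # continues from 20 rather than starting at 0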
- for epoch in range(20,50):
- t1 = time.time()
- for batch in range(batch_num):
- text_len_, word_ids_, tag_ids_ = sess.run([text_len, word_ids, tag_ids])
- # print(text_len_.shape, word_ids_.shape, tag_ids_.shape)
- loss_, train_op_, globel_step_ = sess.run([model.loss, model.train_op, model.globel_step],
- feed_dict={model.char_inputs:word_ids_,
- # model.lengths:text_len_,
- model.dropout:0.5,
- model.targets:tag_ids_})
- if batch % 100==0:
- logging.info('loss_:%.4f,\tglobel_step_: %d'%(loss_, globel_step_))
- print('elapsed:', time.time()-t1)
- num += text_len_.shape[0]
- # print('samples trained: %d, total samples: %d'%(num, total_num))
- results = []
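- # model.trans (presumably the CRF transition matrix) is fetched once per evaluation pass
- # and passed to model.decode for Viterbi decoding of each dev batch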
- trans = model.trans.eval()
- Precision = []
- Recall = []
- F1 = []
- loss = []
- pred_num = 0
- gold_num = 0
- equal_num = 0
- for batch in range(batch_num_dev):
- text_len_dev_, word_ids_dev_, tag_ids_dev_ = sess.run([text_len_dev, word_ids_dev, tag_ids_dev])
- lengths, scores, batch_loss = sess.run([model.lengths, model.logits, model.loss], feed_dict={model.char_inputs:word_ids_dev_,
- # model.lengths:text_len_dev_,
- model.dropout:0.5,
- model.targets:tag_ids_dev_})
- loss.append(batch_loss)
- batch_paths = model.decode(scores, lengths, trans)
- for i in range(len(tag_ids_dev_)):
- result = []
- gold = [str(x) for x in tag_ids_dev_[i][:lengths[i]]]
- pred = [str(x) for x in batch_paths[i][:lengths[i]]]
- gold_ner = get_ner("".join(gold))
- pred_ner = get_ner("".join(pred))
- # print('gold entities:', gold_ner)
- # print('predicted entities:', pred_ner)
- pred_num += len(pred_ner)
- gold_num += len(gold_ner)
- equal_num += len(gold_ner & pred_ner)
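- # micro-averaged precision/recall/F1 over all dev batches; the small epsilon avoids
- # division by zero when no entities are predicted or labelled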
- precision = equal_num / (pred_num + 1e-10)
- recall = equal_num / (gold_num + 1e-10)
- f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
- val_loss = np.mean(loss)
- logging.info('epoch: %d, f1:%.4f, precision:%.4f, recall:%.4f, val_loss:%.4f'%(epoch, f1, precision, recall, val_loss))
- if f1>max_f1:
- max_f1 = f1
- model.saver.save(sess, os.path.join(os.path.dirname(__file__)+'/model', "ner_epoch%d_f1%.4f_loss%.4f.ckpt"%(epoch,f1, val_loss)))
- logging.info('save model, max_f1:%.4f' %f1)
- coord.request_stop()
- coord.join(threads)
- if __name__ == "__main__":
- train()