#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/12/14 11:08
import os
import re
import time
import logging

logging.basicConfig(level=logging.DEBUG)

print('about to import tensorflow')
import tensorflow as tf
print('about to import numpy')
import numpy as np
print('about to import Product_Model')
from BiddingKG.dl.product.product_model import Product_Model
print('about to import max_len')
from BiddingKG.dl.product.data_tfrecord import max_len
# from BiddingKG.dl.common.Utils import viterbi_decode

print('setting CUDA environment')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# max_len = 500
batch_size = 256
MIN_AFTER_DEQUEUE = batch_size * 500


def read_tfRecord(sess, file_list):
    """Build a single-example reader pipeline over the given TFRecord files."""
    queue = tf.train.string_input_producer(file_list)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            "word_ids": tf.FixedLenFeature([max_len], tf.int64),
            "text_len": tf.FixedLenFeature([1], tf.int64),
            "tag_ids": tf.FixedLenFeature([max_len], tf.int64)
        }
    )
    word_ids = tf.cast(features['word_ids'], tf.int64)
    tag_ids = tf.cast(features['tag_ids'], tf.int64)
    text_len = tf.cast(features['text_len'], tf.int64)
    return text_len, word_ids, tag_ids


def get_batch_record(sess, file_list, batch_size):
    """Shuffle-batch the single-example pipeline into [batch_size, ...] tensors."""
    text_len, word_ids, tag_ids = read_tfRecord(sess, file_list)
    text_len, word_ids, tag_ids = tf.train.shuffle_batch(
        [text_len, word_ids, tag_ids],
        batch_size=batch_size,
        capacity=MIN_AFTER_DEQUEUE + batch_size * 3,
        min_after_dequeue=MIN_AFTER_DEQUEUE,
        num_threads=8)
    text_len = tf.squeeze(text_len, axis=1)  # axis= replaces the deprecated squeeze_dims=
    return text_len, word_ids, tag_ids


def total_sample(file_name):
    '''
    Return the total number of samples in a TFRecord file.
    :param file_name: TFRecord file name
    :return: sample count
    '''
    sample_num = 0
    for _ in tf.python_io.tf_record_iterator(file_name):
        sample_num += 1
    return sample_num


def train():
    logging.info('entered train()')
    # filename = os.path.dirname(__file__)+'/data/train_data/maxlen_500_addunk_product_reason.tfrecode'
    # filename_dev = os.path.dirname(__file__)+'/data/test_data/maxlen_500_addunk_product_reason.tfrecode'
    # print('os.path.dirname(__file__): ', os.path.dirname(__file__))
    # print('filename path :', filename)
    # assert os.path.exists(filename)
    # assert os.path.exists(filename_dev)

    file_list = []
    file_list_dev = []
    train1 = os.path.dirname(__file__) + '/data/train_data/ProductAndReason_2023-02-24_train1.tfrecode'
    train2 = os.path.dirname(__file__) + '/data/train_data/ProductAndReason_2023-02-24_train2.tfrecode'
    dev1 = os.path.dirname(__file__) + '/data/test_data/ProductAndReason_2023-02-24_dev.tfrecode'
    # train1 = os.path.dirname(__file__)+'/data/train_data/ProductAndReason_2023-03-30_remove_punctuation_train1.tfrecode'
    # train2 = os.path.dirname(__file__)+'/data/train_data/ProductAndReason_2023-03-30_remove_punctuation_train2.tfrecode'
    # dev1 = os.path.dirname(__file__)+'/data/test_data/ProductAndReason_2023-03-30_remove_punctuation_dev.tfrecode'
    file_list.append(train1)
    file_list.append(train2)
    file_list_dev.append(dev1)
    print('making sure the data files exist')
    print('filename path :', train1, os.path.exists(train1))

    # batch_size = 512
    # id_to_tag = {0: 'O', 1: 'B', 2: 'I', 3: 'E'}
    tag2index = {'S': 0, 'B-pro': 1, 'I-pro': 2, 'E-pro': 3, 'B-rea': 4, 'I-rea': 5, 'E-rea': 6}
    id_to_tag = {v: k for k, v in tag2index.items()}

    def get_ner(id_tag):
        """Extract (start, end, type) entity spans from a string of tag ids."""
        ner = set()
        for it in re.finditer('12*3', id_tag):
            ner.add((it.start(), it.end(), 'pro'))
        for it in re.finditer('45*6', id_tag):
            ner.add((it.start(), it.end(), 'rea'))
        return ner
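    # Illustrative example (not from the original script): with tag2index above,
    # the tag-id string '0123000456' decodes to {(1, 4, 'pro'), (7, 10, 'rea')},
    # since '12*3' matches B-pro I-pro* E-pro and '45*6' matches B-rea I-rea* E-rea.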
    print('building the model')
    model = Product_Model()
    # saver = tf.train.Saver()
    print('model built')

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        print('variables initialized')
        text_len, word_ids, tag_ids = get_batch_record(sess, file_list, batch_size=batch_size)
        print('get_batch_record')
        text_len_dev, word_ids_dev, tag_ids_dev = get_batch_record(sess, file_list_dev, batch_size=batch_size)
        print('get_batch_record_dev')
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        total_num = sum([total_sample(filename) for filename in file_list])
        logging.info('total_train_num: %d' % total_num)
        batch_num = total_num // batch_size
        batch_num_dev = sum([total_sample(filename_dev) for filename_dev in file_list_dev]) // batch_size
        num = 0
        max_f1 = 0.79
        # model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model', 'ner_epoch10_f10.6875_loss1.5230.ckpt'))
        # model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model', 'ner_epoch0_f10.7740_loss1.2526.ckpt'))
        model.saver.restore(sess, os.path.join(os.path.dirname(__file__) + '/model', 'ner_epoch16_f10.8000_loss1.0775.ckpt'))
        print('model restored')

        for epoch in range(20, 50):
            t1 = time.time()
            for batch in range(batch_num):
                text_len_, word_ids_, tag_ids_ = sess.run([text_len, word_ids, tag_ids])
                # print(text_len_.shape, word_ids_.shape, tag_ids_.shape)
                loss_, train_op_, globel_step_ = sess.run(
                    [model.loss, model.train_op, model.globel_step],
                    feed_dict={model.char_inputs: word_ids_,
                               # model.lengths: text_len_,
                               model.dropout: 0.5,
                               model.targets: tag_ids_})
                if batch % 100 == 0:
                    logging.info('loss_:%.4f,\tglobel_step_: %d' % (loss_, globel_step_))
                    print('elapsed:', time.time() - t1)
                num += text_len_.shape[0]
                # print('samples trained: %d, total samples: %d' % (num, total_num))

            # evaluate on the dev set after each epoch
            trans = model.trans.eval()
            loss = []
            pred_num = 0
            gold_num = 0
            equal_num = 0
            for batch in range(batch_num_dev):
                text_len_dev_, word_ids_dev_, tag_ids_dev_ = sess.run([text_len_dev, word_ids_dev, tag_ids_dev])
                lengths, scores, batch_loss = sess.run(
                    [model.lengths, model.logits, model.loss],
                    feed_dict={model.char_inputs: word_ids_dev_,
                               # model.lengths: text_len_dev_,
                               # assuming model.dropout is a keep probability, 1.0
                               # disables dropout for evaluation (the original fed 0.5)
                               model.dropout: 1.0,
                               model.targets: tag_ids_dev_})
                loss.append(batch_loss)
                batch_paths = model.decode(scores, lengths, trans)
                for i in range(len(tag_ids_dev_)):
                    gold = [str(x) for x in tag_ids_dev_[i][:lengths[i]]]
                    pred = [str(x) for x in batch_paths[i][:lengths[i]]]
                    gold_ner = get_ner("".join(gold))
                    pred_ner = get_ner("".join(pred))
                    # print('gold entities:', gold_ner)
                    # print('predicted entities:', pred_ner)
                    pred_num += len(pred_ner)
                    gold_num += len(gold_ner)
                    equal_num += len(gold_ner & pred_ner)
            precision = equal_num / (pred_num + 1e-10)
            recall = equal_num / (gold_num + 1e-10)
            f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
            val_loss = np.mean(loss)
            logging.info('epoch: %d, f1:%.4f, precision:%.4f, recall:%.4f, val_loss:%.4f' % (epoch, f1, precision, recall, val_loss))
            if f1 > max_f1:
                max_f1 = f1
                model.saver.save(sess, os.path.join(os.path.dirname(__file__) + '/model',
                                                    "ner_epoch%d_f1%.4f_loss%.4f.ckpt" % (epoch, f1, val_loss)))
                logging.info('save model, max_f1:%.4f' % f1)

        coord.request_stop()
        coord.join(threads)


if __name__ == "__main__":
    train()
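
# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the original pipeline): how a
# record matching the schema expected by read_tfRecord() might be written. The
# real writer presumably lives in BiddingKG.dl.product.data_tfrecord; the names
# `word_ids`, `tag_ids` and `text_len` below are placeholders.
#
#   writer = tf.python_io.TFRecordWriter('sample.tfrecode')
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'word_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=word_ids)),    # padded to max_len
#       'text_len': tf.train.Feature(int64_list=tf.train.Int64List(value=[text_len])),  # true length
#       'tag_ids':  tf.train.Feature(int64_list=tf.train.Int64List(value=tag_ids)),     # padded to max_len
#   }))
#   writer.write(example.SerializeToString())
#   writer.close()
# ---------------------------------------------------------------------------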