#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time   : 2021/12/14 11:08
import os
import re

print('about to import tensorflow')
import tensorflow as tf
print('about to import numpy')
import numpy as np
print('about to import Product_Model')
from BiddingKG.dl.product.product_model import Product_Model
print('about to import max_len')
from BiddingKG.dl.product.data_tfrecord import max_len
# from BiddingKG.dl.common.Utils import viterbi_decode

print('setting CUDA environment')
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# max_len = 500


def read_tfRecord(sess, file_tfRecord):
    """Read one serialized example from the TFRecord file and parse its fixed-length features."""
    queue = tf.train.string_input_producer([file_tfRecord])
    reader = tf.TFRecordReader()
    filename_, serialized_example = reader.read(queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            "word_ids": tf.FixedLenFeature([max_len], tf.int64),
            "text_len": tf.FixedLenFeature([1], tf.int64),
            "tag_ids": tf.FixedLenFeature([max_len], tf.int64)
        }
    )
    word_ids = tf.cast(features['word_ids'], tf.int64)
    tag_ids = tf.cast(features['tag_ids'], tf.int64)
    text_len = tf.cast(features['text_len'], tf.int64)
    return text_len, word_ids, tag_ids


def get_batch_record(sess, filename, batch_size):
    """Build shuffled batches of (text_len, word_ids, tag_ids) tensors from a TFRecord file."""
    text_len, word_ids, tag_ids = read_tfRecord(sess, filename)
    text_len, word_ids, tag_ids = tf.train.shuffle_batch(
        [text_len, word_ids, tag_ids],
        batch_size=batch_size,
        capacity=200 + batch_size * 3,
        min_after_dequeue=1,
        num_threads=5)
    text_len = tf.squeeze(text_len, squeeze_dims=1)  # [batch, 1] -> [batch]
    return text_len, word_ids, tag_ids


def total_sample(file_name):
    '''
    Return the total number of samples in a TFRecord file.
    :param file_name: TFRecord file name
    :return: sample count
    '''
    sample_num = 0
    for record in tf.python_io.tf_record_iterator(file_name):
        sample_num += 1
    return sample_num
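
# How entity spans are recovered during evaluation (see get_ner() in __main__ below):
# gold and predicted tag ids are joined into a string and matched with regular
# expressions, so a product entity is any "B-pro (I-pro)* E-pro" run, i.e. the
# pattern '12*3' over the id string, and a reason entity is '45*6'. A worked
# example with the tag2index mapping used below:
#   re.finditer('12*3', '0123000')  ->  one match with span (1, 4), recorded as (1, 4, 'pro')
#   re.finditer('45*6', '0455600')  ->  one match with span (1, 5), recorded as (1, 5, 'rea')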

if __name__ == "__main__":
    print('entering main')
    filename = os.path.dirname(__file__) + '/data/train_data/maxlen_500_addunk_product_reason.tfrecode'
    filename_dev = os.path.dirname(__file__) + '/data/test_data/maxlen_500_addunk_product_reason.tfrecode'
    assert os.path.exists(filename)
    assert os.path.exists(filename_dev)
    print('input files exist')

    batch_size = 100
    # id_to_tag = {0: 'O', 1: 'B', 2: 'I', 3: 'E'}
    tag2index = {'S': 0, 'B-pro': 1, 'I-pro': 2, 'E-pro': 3, 'B-rea': 4, 'I-rea': 5, 'E-rea': 6}
    id_to_tag = {v: k for k, v in tag2index.items()}

    def get_ner(id_tag):
        """Extract (start, end, label) entity spans from a string of tag ids."""
        ner = set()
        for it in re.finditer('12*3', id_tag):   # B-pro (I-pro)* E-pro
            ner.add((it.start(), it.end(), 'pro'))
        for it in re.finditer('45*6', id_tag):   # B-rea (I-rea)* E-rea
            ner.add((it.start(), it.end(), 'rea'))
        return ner

    print('building model')
    model = Product_Model()
    # saver = tf.train.Saver()
    print('model built')

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        print('variables initialized')

        text_len, word_ids, tag_ids = get_batch_record(sess, filename, batch_size=batch_size)
        print('get_batch_record')
        text_len_dev, word_ids_dev, tag_ids_dev = get_batch_record(sess, filename_dev, batch_size=batch_size)
        print('get_batch_record_dev')

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        total_num = total_sample(filename)
        print('total_sample(filename)', total_num)
        batch_num = total_num // batch_size
        batch_num_dev = total_sample(filename_dev) // batch_size
        num = 0
        l = []
        max_f1 = 0
        # model.saver.restore(sess, os.path.join(os.path.dirname(__file__)+'/model', 'ner_epoch10_f10.6875_loss1.5230.ckpt'))
        # print('model restored')

        for epoch in range(50):
            # training pass
            for batch in range(batch_num):
                text_len_, word_ids_, tag_ids_ = sess.run([text_len, word_ids, tag_ids])
                # print(text_len_.shape, word_ids_.shape, tag_ids_.shape)
                loss_, train_op_, globel_step_ = sess.run(
                    [model.loss, model.train_op, model.globel_step],
                    feed_dict={model.char_inputs: word_ids_,
                               # model.lengths: text_len_,
                               model.dropout: 0.5,
                               model.targets: tag_ids_})
                if batch % 100 == 0:
                    print('loss_:', loss_, '\tglobel_step_:', globel_step_)
                num += text_len_.shape[0]
                print('samples trained: %d, total samples: %d' % (num, total_num))

            # evaluation pass over the dev set
            results = []
            trans = model.trans.eval()
            Precision = []
            Recall = []
            F1 = []
            loss = []
            pred_num = 0
            gold_num = 0
            equal_num = 0
            for batch in range(batch_num_dev):
                text_len_dev_, word_ids_dev_, tag_ids_dev_ = sess.run([text_len_dev, word_ids_dev, tag_ids_dev])
                lengths, scores, batch_loss = sess.run(
                    [model.lengths, model.logits, model.loss],
                    feed_dict={model.char_inputs: word_ids_dev_,
                               # model.lengths: text_len_dev_,
                               model.dropout: 0.5,
                               model.targets: tag_ids_dev_})
                loss.append(batch_loss)
                batch_paths = model.decode(scores, lengths, trans)
                for i in range(len(tag_ids_dev_)):
                    result = []
                    gold = [str(x) for x in tag_ids_dev_[i][:lengths[i]]]
                    pred = [str(x) for x in batch_paths[i][:lengths[i]]]
                    gold_ner = get_ner("".join(gold))
                    pred_ner = get_ner("".join(pred))
                    # print('gold entities:', gold_ner)
                    # print('predicted entities:', pred_ner)
                    pred_num += len(pred_ner)
                    gold_num += len(gold_ner)
                    equal_num += len(gold_ner & pred_ner)

            # entity-level precision / recall / F1 over the whole dev set
            precision = equal_num / (pred_num + 1e-10)
            recall = equal_num / (gold_num + 1e-10)
            f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
            val_loss = np.mean(loss)
            print('epoch: %d, f1:%.4f, precision:%.4f, recall:%.4f, val_loss:%.4f' % (epoch, f1, precision, recall, val_loss))
            if f1 > max_f1:
                max_f1 = f1
                model.saver.save(sess, os.path.join(os.path.dirname(__file__) + '/model',
                                                    "ner_epoch%d_f1%.4f_loss%.4f.ckpt" % (epoch, f1, val_loss)))
                print('save model, max_f1:%.4f' % f1)

        coord.request_stop()
        coord.join(threads)
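
# To resume training or evaluate a previously saved checkpoint, the commented-out
# model.saver.restore(...) call above can be re-enabled after sess.run(init_op),
# pointing at one of the checkpoints written by model.saver.save(). Sketch only;
# the checkpoint file name depends on which epoch was actually saved:
#
#   model.saver.restore(sess, os.path.join(os.path.dirname(__file__) + '/model',
#                                          'ner_epoch10_f10.6875_loss1.5230.ckpt'))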