#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/1/13 14:03
from BiddingKG.dl.product.product_model import Product_Model
from BiddingKG.dl.product.data_util import BatchManager, get_label_data, id_to_tag, input_from_line, decode, result_to_json
import numpy as np
import tensorflow as tf
import random
import pickle
import os

def train():
    # Rebuild and persist the 85/15 train/dev split (uncomment to regenerate):
    # all_data = get_label_data()
    # random.shuffle(all_data)
    # train_data = all_data[:int(len(all_data) * 0.85)]
    # dev_data = all_data[int(len(all_data) * 0.85):]
    # with open('data/train_data2.pkl', 'wb') as f:
    #     pickle.dump(train_data, f)
    # with open('data/dev_data2.pkl', 'wb') as f:
    #     pickle.dump(dev_data, f)
    with open('data/train_data2.pkl', 'rb') as f:
        train_data = pickle.load(f)
    with open('data/dev_data2.pkl', 'rb') as f:
        dev_data = pickle.load(f)
    train_manager = BatchManager(train_data, batch_size=128)
    dev_manager = BatchManager(dev_data, batch_size=64)

    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    steps_per_epoch = train_manager.len_data
    ckpt_path = "model"
    with tf.Session(config=tf_config) as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        # Resume from an existing checkpoint (uncomment to restore):
        # ckpt = tf.train.get_checkpoint_state(ckpt_path)
        # if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        #     model.saver.restore(sess, ckpt.model_checkpoint_path)
        #     print("Restored model from", ckpt.model_checkpoint_path)
        print('Preparing training data')
        loss = []
        min_loss = 1000  # for the alternative loss-based checkpointing below
        max_f1 = 0
        for i in range(100):
            print('epoch:', i)
            # model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)  # evaluate only
            # break
            for batch in train_manager.iter_batch(shuffle=True):
                step, batch_loss = model.run_step(sess, 'train', batch)
                loss.append(batch_loss)
                if step % 10 == 0:
                    iteration = step // steps_per_epoch + 1
                    print('iter:{} step:{} loss:{}'.format(iteration, step, np.mean(loss)))
            # Evaluate every 5 epochs at first, then every epoch from epoch 50 on.
            if i >= 50 or i % 5 == 0:
                f1, precision, recall, evl_loss = model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
                print('f1:%.4f, precision:%.4f, recall:%.4f, evl_loss:%.4f' % (f1, precision, recall, evl_loss))
                if max_f1 < f1:
                    model.saver.save(sess, os.path.join(ckpt_path, "ner2.ckpt"))
                    print("model saved, best f1 is %.4f" % f1)
                    max_f1 = f1
                # Alternative: checkpoint on lowest mean training loss instead of best f1.
                # if np.mean(loss) < min_loss:
                #     min_loss = np.mean(loss)
                #     model.saver.save(sess, os.path.join(ckpt_path, "ner.ckpt"))
                #     print("model saved, loss is:", min_loss)
            loss = []

def evaluate_line():
    ckpt_path = "model"
    with tf.Session() as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            print('Checkpoint file:', ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)
            print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)
        while True:
            line = input("Enter a test sentence: ")
            result = model.evaluate_line(sess, line)
            print(result)
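

def export_pb():
    # A minimal sketch, not part of the original script: freeze the trained
    # checkpoint into the model/product.pb file that predict() below expects.
    # The output node names are taken from the tensors predict() fetches;
    # the helper name export_pb itself is an assumption.
    ckpt_path = "model"
    pb_path = "model/product.pb"
    with tf.Session() as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
            model.saver.restore(sess, ckpt.model_checkpoint_path)
        # Bake all variables into constants so the graph file is self-contained.
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(),
            ["Sum", "logits/Reshape", "crf_loss/transitions"])
        with tf.gfile.GFile(pb_path, "wb") as f:
            f.write(frozen_graph.SerializeToString())
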
def predict():
    pb_path = "model/product.pb"
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, 'rb') as f:
            output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name='')  # note: the name prefix must stay empty here
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())  # no-op for a frozen graph
            for node in output_graph_def.node:
                print(node.name)
            char_input = sess.graph.get_tensor_by_name("CharInputs:0")
            length = sess.graph.get_tensor_by_name("Sum:0")
            dropout = sess.graph.get_tensor_by_name("Dropout:0")
            logit = sess.graph.get_tensor_by_name("logits/Reshape:0")
            tran = sess.graph.get_tensor_by_name("crf_loss/transitions:0")
            while True:
                line = input("Enter a test sentence: ")
                _, chars, _ = input_from_line(line)
                print(chars)
                lengths, scores, tran_ = sess.run(
                    [length, logit, tran],
                    feed_dict={char_input: np.asarray(chars), dropout: 1.0})
                batch_paths = decode(scores, lengths, tran_)
                tags = batch_paths[0]  # batch_paths[0][:lengths] is wrong: lengths is an array, not a scalar
                result = result_to_json(line, tags)
                print(result)

if __name__ == "__main__":
    # train()
    # evaluate_line()
    predict()