#!/usr/bin/python3 # -*- coding: utf-8 -*- # @Author : bidikeji # @Time : 2021/1/13 0013 14:03 from BiddingKG.dl.product.product_model import Product_Model from BiddingKG.dl.product.data_util import BatchManager, get_label_data, id_to_tag, input_from_line, decode, result_to_json import numpy as np import tensorflow as tf import random import pickle import os def train(): # all_data = get_label_data() # random.shuffle(all_data) # train_data = all_data[:int(len(all_data)*0.85)] # dev_data = all_data[int(len(all_data)*0.85):] # with open('data/train_data2.pkl', 'wb') as f: # pickle.dump(train_data, f) # with open('data/dev_data2.pkl', 'wb') as f: # pickle.dump(dev_data, f) with open('data/train_data2.pkl', 'rb') as f: train_data = pickle.load(f) with open('data/dev_data2.pkl', 'rb') as f: dev_data = pickle.load(f) train_manager = BatchManager(train_data, batch_size=128) dev_manager = BatchManager(dev_data, batch_size=64) tf_config = tf.ConfigProto() tf_config.gpu_options.allow_growth = True steps_per_epoch = train_manager.len_data ckpt_path = "model" with tf.Session(config=tf_config) as sess: model = Product_Model() sess.run(tf.global_variables_initializer()) # ckpt = tf.train.get_checkpoint_state(ckpt_path) # if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): # model.saver.restore(sess, ckpt.model_checkpoint_path) # print("从文件加载原来模型数据",ckpt.model_checkpoint_path) print('准备训练数据') loss = [] mix_loss = 1000 max_f1 = 0 for i in range(100): print('epochs:',i) # model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag) # break for batch in train_manager.iter_batch(shuffle=True): # print('batch:',len(batch)) # step, batch_loss = model.run_step(sess, True, batch) step, batch_loss = model.run_step(sess, 'train', batch) loss.append(batch_loss) if step % 10 == 0: iteration = step // steps_per_epoch + 1 print('iter:{} step:{} loss:{}'.format(iteration, step, np.mean(loss))) if i >= 50 or i%5==0: f1, precision, recall, evl_loss = model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag) print('f1:%.4f, precision:%.4f, recall:%.4f, evl_loss:%.4f' % (f1, precision, recall, evl_loss)) if max_f1 < f1: model.saver.save(sess, os.path.join(ckpt_path, "ner2.ckpt")) print("model save .bast f1 is %.4f" % f1) max_f1 = f1 # if np.mean(loss)