#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: bidikeji
@time: 2023/3/27 10:19

Batch-evaluate the product NER model (``Product_Model``) on a labelled test
sheet and report set-level precision / recall / F1.

When run as a script:
1. ``predict()`` restores a checkpoint, tags the ``text`` column of
   ``data/df_test_pred.xlsx`` and writes the predictions back to the same
   file in a ``predict_new`` column.
2. ``统计准确率()`` ("compute accuracy") compares the predicted entities with
   the gold ``lbset`` column and prints precision, recall and F1.
"""
from BiddingKG.dl.product.product_model import Product_Model
import json
import os
import re
import time
import pandas as pd
import tensorflow as tf

# Hide all GPUs so TensorFlow runs on CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Directory containing this script. abspath() matters: on Python < 3.9,
# running ``python this_file.py`` from its own directory makes __file__ a bare
# file name, so dirname() returns '' and '' + '/data/...' would point at the
# filesystem root instead of the script's directory.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Test sheet: read by predict() and overwritten with its predictions.
_TEST_XLSX = os.path.join(_BASE_DIR, 'data', 'df_test_pred.xlsx')
# Checkpoint retrained after cleaning the training data. Other checkpoints
# that were tried in the same directory:
#   ner_epoch16_f10.8000_loss1.0775.ckpt
#   ner_epoch7_f10.7998_loss1.0508.ckpt
#   ner_epoch18_f10.8000_loss1.1276.ckpt  (new)
#   ner_epoch5_f10.6855_loss1.3800.ckpt   (old)
_CKPT = os.path.join(_BASE_DIR, 'model', 'ner_epoch22_f10.7923_loss1.1039.ckpt')


def _tag_text(model, sess, text):
    """Tag *text* sentence by sentence and return the concatenated results.

    The text is split on the Chinese full stop '。' and each piece is passed to
    ``model.evaluate_line``; ``result[0][1]`` of every call is appended to one
    flat list. The structure of those items is defined by ``Product_Model``
    (``统计准确率`` treats each item's first element as the entity text).
    """
    entities = []
    # str() guards against non-text cells (e.g. numbers) in the sheet.
    for sentence in str(text).split('。'):
        result = model.evaluate_line(sess, sentence)
        entities.extend(result[0][1])
    return entities


def predict():
    """Run the model over the test sheet and store its predictions.

    Restores ``_CKPT``, tags every row's ``text`` column and writes the
    JSON-encoded predictions into a ``predict_new`` column, overwriting
    ``data/df_test_pred.xlsx``.

    Returns:
        pandas.DataFrame: the test sheet (NaN cells filled with '') with the
        ``predict_new`` column added.
    """
    with tf.Session() as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess, _CKPT)
        t1 = time.time()
        # Debug output: the graph tensors the model exposes.
        print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)

        df = pd.read_excel(_TEST_XLSX)
        print('公告数量:', len(df))
        df.fillna('', inplace=True)
        df.reset_index(drop=True, inplace=True)

        df['predict_new'] = [
            json.dumps(_tag_text(model, sess, text), ensure_ascii=False)
            for text in df['text']
        ]
        df.to_excel(_TEST_XLSX, index=False)
        print('耗时: ', time.time()-t1)
    return df


def _parse_json_list(value):
    """Decode a JSON list stored in a sheet cell; empty or non-text cells give []."""
    # predict() fills missing cells with '', and json.loads('') would raise.
    if not isinstance(value, str) or not value.strip():
        return []
    return json.loads(value)


def _safe_div(num, den):
    """Return num / den, or 0.0 when the denominator is zero."""
    return num / den if den else 0.0


def 统计准确率(df):
    """Print micro-averaged set-level precision, recall and F1 for *df*.

    Per row, the predicted set holds the first element of every item in the
    JSON ``predict_new`` column, and the gold set is the JSON list in
    ``lbset``. Counts are summed over all rows before dividing.

    Adds helper columns to *df* in place: ``pr`` (predicted set), ``lb`` (gold
    set) and ``pos`` (1 when the two sets are identical, else 0).

    Ratios with a zero denominator are reported as 0 instead of raising
    ZeroDivisionError.

    Returns:
        tuple[float, float, float]: (precision, recall, f1).
    """
    df['pr'] = df['predict_new'].apply(lambda x: {it[0] for it in _parse_json_list(x)})
    df['lb'] = df['lbset'].apply(lambda x: set(_parse_json_list(x)))
    df['pos'] = [1 if p == l else 0 for p, l in zip(df['pr'], df['lb'])]

    eq = lb = pr = 0
    for pred, label in zip(df['pr'], df['lb']):
        lb += len(label)
        pr += len(pred)
        eq += len(pred & label)
    # 'acc' is precision (correct / predicted); the printed label keeps the
    # original wording.
    acc = _safe_div(eq, pr)
    recall = _safe_div(eq, lb)
    f1 = _safe_div(acc * recall * 2, acc + recall)
    print('准确率:%.4f,召回率:%.4f,F1:%.4f'%(acc, recall, f1))
    # Previously recorded result: precision 0.6489, recall 0.8402, F1 0.7323
    return acc, recall, f1


if __name__ == "__main__":
    df = predict()
    统计准确率(df)