12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- @author: bidikeji
- @time: 2023/3/27 10:19
- """
- from BiddingKG.dl.product.product_model import Product_Model
- import os
- import re
- import time
- import pandas as pd
- import tensorflow as tf
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
def predict():
    """Run the product NER model over the test spreadsheet and save the result.

    A fresh ``Product_Model`` is built, a fixed checkpoint from the ``model``
    directory next to this file is restored into it, and
    ``data/df_test_pred.xlsx`` (also next to this file) is loaded. Each row's
    ``text`` is split on the Chinese full stop, every piece is passed to
    ``model.evaluate_line``, and ``result[0][1]`` of every call is collected
    into one list per row (presumably the recognised entities -- confirm
    against ``Product_Model.evaluate_line``). The per-row lists are stored
    JSON-encoded in a ``predict_new`` column and the spreadsheet is written
    back to the same path.

    Returns:
        The loaded DataFrame with the added ``predict_new`` column.
    """
    import json

    # Anchor paths on the absolute location of this file: with a relative
    # ``__file__`` (script started from its own directory) ``dirname`` is ''
    # and naive string concatenation would point at the filesystem root.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    checkpoint_file = os.path.join(
        base_dir, 'model', 'ner_epoch22_f10.7923_loss1.1039.ckpt')
    data_file = os.path.join(base_dir, 'data', 'df_test_pred.xlsx')

    with tf.Session() as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        # Checkpoint obtained by retraining after the data was cleaned up.
        model.saver.restore(sess, checkpoint_file)

        t1 = time.time()
        print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)

        df = pd.read_excel(data_file)
        print('公告数量:', len(df))
        # Empty cells become '' so that ``text.split`` below always sees a str.
        df.fillna('', inplace=True)
        # A clean 0..n-1 index keeps the Series assignment below aligned.
        df.reset_index(drop=True, inplace=True)

        rs = []
        for text in df['text']:
            tmp = []
            # Evaluate sentence by sentence rather than the whole document.
            for line in text.split('。'):
                result = model.evaluate_line(sess, line)
                tmp.extend(result[0][1])
            rs.append(json.dumps(tmp, ensure_ascii=False))

        df['predict_new'] = pd.Series(rs)
        df.to_excel(data_file, index=False)
        print('耗时: ', time.time() - t1)
    return df
def 统计准确率(df):
    """Print and return micro-averaged precision, recall and F1 of the predictions.

    Args:
        df: DataFrame with a ``predict_new`` column (JSON list of sequences;
            the first element of each is taken as the predicted item) and an
            ``lbset`` column (JSON list of gold labels). The frame is
            modified in place: ``pr`` (predicted set), ``lb`` (label set)
            and ``pos`` (1 when both sets are equal, else 0) are added.

    Returns:
        A ``(precision, recall, f1)`` tuple of floats. A metric whose
        denominator is zero (no predictions, no labels, or no overlap at
        all) is reported as 0.0 instead of raising ``ZeroDivisionError``.
    """
    import json

    df['pr'] = df['predict_new'].apply(lambda x: {it[0] for it in json.loads(x)})
    df['lb'] = df['lbset'].apply(lambda x: set(json.loads(x)))
    # Row-level exact-match flag; built from a list so an empty frame works too.
    df['pos'] = [int(pred == label) for pred, label in zip(df['pr'], df['lb'])]

    # Micro-average: pool the counts over all rows before dividing.
    eq = sum(len(pred & label) for pred, label in zip(df['pr'], df['lb']))
    pr = sum(len(pred) for pred in df['pr'])
    lb = sum(len(label) for label in df['lb'])

    acc = eq / pr if pr else 0.0
    recall = eq / lb if lb else 0.0
    f1 = acc * recall * 2 / (acc + recall) if (acc + recall) else 0.0
    print('准确率:%.4f,召回率:%.4f,F1:%.4f' % (acc, recall, f1))
    return acc, recall, f1
if __name__ == '__main__':
    # Script entry point: generate predictions for the test spreadsheet,
    # then report precision / recall / F1 for them.
    predictions = predict()
    统计准确率(predictions)
|